├── .devcontainer.json ├── .dockerignore ├── .github └── workflows │ ├── ci.yaml │ └── publish.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── NOTICE.md ├── README.md ├── analysis ├── almost_scaled_dot_product_attention │ ├── .gitignore │ ├── almost_scaled_dot_product_attention.ipynb │ └── demo_transformer.py ├── benchmarking_compiled_unit_scaled_ops.ipynb ├── emb_lr_analysis.ipynb ├── empirical_op_scaling.ipynb └── sgd_scaling.ipynb ├── dev ├── docs ├── Makefile ├── _static │ ├── animation.html │ └── scales.png ├── _templates │ ├── custom-class-template.rst │ └── custom-module-template.rst ├── api_reference.rst ├── conf.py ├── index.rst ├── limitations.rst ├── make.bat ├── posts │ └── almost_scaled_dot_product_attention.md ├── u-muP_slides.pdf └── user_guide.rst ├── examples ├── .gitignore ├── demo.ipynb ├── how_to_scale_op.ipynb └── scale_analysis.py ├── pyproject.toml ├── requirements-dev.txt ├── setup.cfg └── unit_scaling ├── __init__.py ├── _internal_utils.py ├── _modules.py ├── analysis.py ├── constraints.py ├── core ├── __init__.py └── functional.py ├── docs.py ├── formats.py ├── functional.py ├── optim.py ├── parameter.py ├── scale.py ├── tests ├── __init__.py ├── conftest.py ├── core │ ├── __init__.py │ └── test_functional.py ├── helper.py ├── test_analysis.py ├── test_constraints.py ├── test_docs.py ├── test_formats.py ├── test_functional.py ├── test_modules.py ├── test_optim.py ├── test_parameter.py ├── test_scale.py ├── test_utils.py └── transforms │ ├── __init__.py │ ├── test_compile.py │ ├── test_simulate_format.py │ ├── test_track_scales.py │ └── test_unit_scale.py ├── transforms ├── __init__.py ├── _compile.py ├── _simulate_format.py ├── _track_scales.py ├── _unit_scale.py └── utils.py └── utils.py /.devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "build": { 3 | "dockerfile": "Dockerfile" 4 | }, 5 | "workspaceFolder": "/home/developer/unit-scaling", 6 | "customizations": { 7 | "vscode": { 8 | "extensions": [ 9 | "ms-python.python", 10 | "ms-toolsai.jupyter" 11 | ], 12 | "settings": { 13 | "terminal.integrated.defaultProfile.linux": "zsh", 14 | "terminal.integrated.profiles.linux": { "zsh": { "path": "/bin/zsh" } } 15 | } 16 | } 17 | }, 18 | "mounts": [ 19 | "source=${localEnv:HOME}/.ssh,target=/home/developer/.ssh,type=bind,readonly=true", 20 | "source=${localEnv:HOME}/.gitconfig,target=/home/developer/.gitconfig,type=bind,readonly=true", 21 | "source=${localWorkspaceFolder},target=/home/developer/unit-scaling,type=bind" 22 | ], 23 | "remoteUser": "developer" 24 | } 25 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !requirements-dev.txt 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: { branches: ["main"] } 5 | pull_request: 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | # Run everything on main, most-recent on PR builds 10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | ci: 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 10 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v3 20 | 21 | - name: Build Docker image 22 | run: docker build -t unit-scaling-dev:latest . 
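      # The install and CI steps below run inside the image built above, with the checkout bind-mounted at /home/developer/unit-scaling (the -v flag on each docker run); the docs-publish step runs on the host via an action.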
23 | 24 | - name: Local unit_scaling install 25 | run: docker run -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest pip install --user -e . 26 | 27 | - name: Run CI 28 | run: docker run -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest ./dev ci 29 | 30 | - name: Publish documentation 31 | if: ${{github.ref == 'refs/heads/main'}} 32 | uses: Cecilapp/GitHub-Pages-deploy@v3 33 | env: { GITHUB_TOKEN: "${{ github.token }}" } 34 | with: 35 | build_dir: docs/_build/html 36 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # Based on template at https://docs.github.com/en/enterprise-cloud@latest/actions/use-cases-and-examples/building-and-testing/building-and-testing-python?learn=continuous_integration#publishing-to-pypi 2 | 3 | name: Publish to PyPI 4 | 5 | on: 6 | release: 7 | types: [published] 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | release-build: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.x" 23 | 24 | - name: Build release distributions 25 | run: | 26 | python -m pip install build 27 | python -m build 28 | 29 | - name: Upload distributions 30 | uses: actions/upload-artifact@v4 31 | with: 32 | name: release-dists 33 | path: dist/ 34 | 35 | pypi-publish: 36 | runs-on: ubuntu-latest 37 | 38 | needs: 39 | - release-build 40 | 41 | permissions: 42 | # IMPORTANT: this permission is mandatory for trusted publishing 43 | id-token: write 44 | 45 | # Dedicated environments with protections for publishing are strongly recommended. 46 | environment: 47 | name: pypi 48 | url: https://pypi.org/p/unit-scaling 49 | 50 | steps: 51 | - name: Retrieve release distributions 52 | uses: actions/download-artifact@v4 53 | with: 54 | name: release-dists 55 | path: dist/ 56 | 57 | - name: Publish release distributions to PyPI 58 | uses: pypa/gh-action-pypi-publish@release/v1 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .coverage* 3 | .env 4 | .mypy_cache 5 | __pycache__ 6 | .pytest_cache 7 | .venv 8 | .venvs 9 | .vscode 10 | unit_scaling/_version.py 11 | 12 | /build 13 | /dist 14 | /local 15 | 16 | /docs/_build 17 | /docs/generated 18 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use PyTorch base image 2 | FROM pytorch/pytorch:latest 3 | 4 | # Install additional dependencies 5 | RUN apt-get update && apt-get install -y \ 6 | git \ 7 | vim \ 8 | sudo \ 9 | make \ 10 | g++ \ 11 | zsh \ 12 | && chsh -s /bin/zsh \ 13 | && apt-get clean && rm -rf /var/lib/apt/lists/* # cleanup (smaller image) 14 | 15 | # Configure a non-root user with sudo privileges 16 | ARG USERNAME=developer # Change this to preferred username 17 | ARG USER_UID=1001 18 | ARG USER_GID=$USER_UID 19 | RUN groupadd --gid $USER_GID $USERNAME \ 20 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 21 | && echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \ 22 | && chmod 0440 /etc/sudoers.d/$USERNAME 23 | USER $USERNAME 24 | 25 | # Set working directory 26 | WORKDIR /home/$USERNAME/unit-scaling 27 | 28 | # Puts pip 
install libs on $PATH & sets correct locale 29 | ENV PATH="$PATH:/home/$USERNAME/.local/bin" \ 30 | LC_ALL=C.UTF-8 31 | 32 | # Install Python dependencies 33 | COPY requirements-dev.txt . 34 | RUN pip install --user -r requirements-dev.txt 35 | 36 | # Creates basic .zshrc 37 | RUN sudo cp /etc/zsh/newuser.zshrc.recommended /home/$USERNAME/.zshrc 38 | 39 | CMD ["/bin/zsh"] 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Graphcore 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Graphcore Ltd. Licensed under the Apache 2.0 License. 2 | 3 | The included code is released under an Apache 2.0 license, (see [LICENSE](LICENSE)). 
4 | 5 | Our dependencies are (see [pyproject.toml](pyproject.toml)): 6 | 7 | | Component | About | License | 8 | | --- | --- | --- | 9 | | docstring-parser | Parse Python docstrings | MIT | 10 | | einops | Deep learning operations reinvented (for pytorch, tensorflow, jax and others) | MIT | 11 | | numpy | Array processing library | BSD 3-Clause | 12 | 13 | We also use additional Python dependencies for development/testing (see [requirements-dev.txt](requirements-dev.txt)). 14 | 15 | **This directory includes derived work from the following:** 16 | 17 | --- 18 | 19 | Sphinx: https://github.com/sphinx-doc/sphinx, licensed under: 20 | 21 | > Unless otherwise indicated, all code in the Sphinx project is licenced under the 22 | > two clause BSD licence below. 23 | > 24 | > Copyright (c) 2007-2023 by the Sphinx team (see AUTHORS file). 25 | > All rights reserved. 26 | > 27 | > Redistribution and use in source and binary forms, with or without 28 | > modification, are permitted provided that the following conditions are 29 | > met: 30 | > 31 | > * Redistributions of source code must retain the above copyright 32 | > notice, this list of conditions and the following disclaimer. 33 | > 34 | > * Redistributions in binary form must reproduce the above copyright 35 | > notice, this list of conditions and the following disclaimer in the 36 | > documentation and/or other materials provided with the distribution. 37 | > 38 | > THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 39 | > "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 40 | > LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 41 | > A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 42 | > HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 43 | > SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 44 | > LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 45 | > DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 46 | > THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 47 | > (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 48 | > OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 49 | 50 | this applies to: 51 | - `docs/_templates/custom-class-template.rst` (modified) 52 | - `docs/_templates/custom-module-template.rst` (modified) 53 | 54 | --- 55 | 56 | The Example: Basic Sphinx project for Read the Docs: https://github.com/readthedocs-examples/example-sphinx-basic, licensed under: 57 | 58 | > MIT License 59 | > 60 | > Copyright (c) 2022 Read the Docs Inc 61 | > 62 | > Permission is hereby granted, free of charge, to any person obtaining a copy 63 | > of this software and associated documentation files (the "Software"), to deal 64 | > in the Software without restriction, including without limitation the rights 65 | > to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 66 | > copies of the Software, and to permit persons to whom the Software is 67 | > furnished to do so, subject to the following conditions: 68 | > 69 | > The above copyright notice and this permission notice shall be included in all 70 | > copies or substantial portions of the Software. 71 | > 72 | > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 73 | > IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 74 | > FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 75 | > AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 76 | > LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 77 | > OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 78 | > SOFTWARE. 79 | 80 | this applies to: 81 | - `docs/conf.py` (modified) 82 | - `docs/make.bat` 83 | - `docs/Makefile` 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unit-Scaled Maximal Update Parameterization (u-μP) 2 | 3 | [![tests](https://github.com/graphcore-research/unit-scaling/actions/workflows/ci.yaml/badge.svg)](https://github.com/graphcore-research/unit-scaling/actions/workflows/ci-public.yaml) 4 | ![PyPI version](https://img.shields.io/pypi/v/unit-scaling) 5 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/graphcore-research/unit-scaling/blob/main/LICENSE) 6 | [![GitHub Repo stars](https://img.shields.io/github/stars/graphcore-research/unit-scaling)](https://github.com/graphcore-research/unit-scaling/stargazers) 7 | 8 | A library for unit scaling in PyTorch, based on the paper [u-μP: The Unit-Scaled Maximal Update Parametrization](https://arxiv.org/abs/2407.17465) and previous work [Unit Scaling: Out-of-the-Box Low-Precision Training](https://arxiv.org/abs/2303.11257). 9 | 10 | Documentation can be found at 11 | [https://graphcore-research.github.io/unit-scaling](https://graphcore-research.github.io/unit-scaling) and an example notebook at [examples/demo.ipynb](examples/demo.ipynb). 12 | 13 | **Note:** The library is currently in its _beta_ release. 14 | Some features have yet to be implemented and occasional bugs may be present. 15 | We're keen to help users with any problems they encounter. 16 | 17 | ## Installation 18 | 19 | To install the `unit-scaling` library, run: 20 | 21 | ```sh 22 | pip install unit-scaling 23 | ``` 24 | or for a local editable install (i.e. one which uses the files in this repo), run: 25 | 26 | ```sh 27 | pip install -e . 28 | ``` 29 | 30 | ## Development 31 | 32 | For development in this repository, we recommend using the provided docker container. 33 | This image can be built and entered interactively using: 34 | 35 | ```sh 36 | docker build -t unit-scaling-dev:latest . 37 | docker run -it --rm --user developer:developer -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest 38 | # To use git within the container, add `-v ~/.ssh:/home/developer/.ssh:ro -v ~/.gitconfig:/home/developer/.gitconfig:ro`. 39 | ``` 40 | 41 | For vscode users, this repo also contains a `.devcontainer.json` file, which enables the container to be used as a full-featured IDE (see the [Dev Container docs](https://code.visualstudio.com/docs/devcontainers/containers) for details on how to use this feature). 42 | 43 | Key development functionality is contained within the `./dev` script. This includes running unit tests, linting, formatting, documentation generation and more. Run `./dev --help` for the available options. Running `./dev` without arguments is equivalent to using the `--ci` option, which runs all of the available dev checks. This is the test used for GitHub CI. 44 | 45 | We encourage pull requests from the community. Please reach out to us with any questions about contributing. 46 | 47 | ## What is u-μP? 
48 | 49 | u-μP inserts scaling factors into the model to make activations, gradients and weights unit-scaled (RMS ≈ 1) at initialisation, and into optimiser learning rates to keep updates stable as models are scaled in width and depth. This results in hyperparameter transfer from small to large models and easy support for low-precision training. 50 | 51 | For a quick intro, see [examples/demo.ipynb](examples/demo.ipynb), for more depth see the [paper](https://arxiv.org/abs/2407.17465) and [library documentation](https://graphcore-research.github.io/unit-scaling/). 52 | 53 | ## What is unit scaling? 54 | 55 | For a demonstration of the library and an overview of how it works, see 56 | [Out-of-the-Box FP8 Training](https://github.com/graphcore-research/out-of-the-box-fp8-training/blob/main/out_of_the_box_fp8_training.ipynb) 57 | (a notebook showing how to unit-scale the nanoGPT model). 58 | 59 | For a more in-depth explanation, consult our paper 60 | [Unit Scaling: Out-of-the-Box Low-Precision Training](https://arxiv.org/abs/2303.11257). 61 | 62 | And for a practical introduction to using the library, see our [User Guide](https://graphcore-research.github.io/unit-scaling/user_guide.html). 63 | 64 | ## License 65 | 66 | Copyright (c) 2023 Graphcore Ltd. Licensed under the Apache 2.0 License. 67 | 68 | See [NOTICE.md](NOTICE.md) for further details. 69 | -------------------------------------------------------------------------------- /analysis/almost_scaled_dot_product_attention/.gitignore: -------------------------------------------------------------------------------- 1 | shakespeare.txt 2 | -------------------------------------------------------------------------------- /analysis/almost_scaled_dot_product_attention/demo_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
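"""Minimal byte-level transformer language model trained on Shakespeare, used by
the almost-scaled dot-product attention analysis in this directory.

Expects ``shakespeare.txt`` (see the URL below) in the working directory. Set
``CONFIG.fully_scaled_attention = True`` to apply the extra
``(sequence_length / e) ** 0.5`` output scale in attention (the "fully scaled"
variant); by default only the standard ``1 / sqrt(head_size)`` scaling is used."""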
2 | 3 | from itertools import islice 4 | import math 5 | from pathlib import Path 6 | from typing import * 7 | 8 | import einops 9 | import torch 10 | from torch import nn, Tensor 11 | import tqdm 12 | 13 | 14 | class Config(dict): 15 | def __init__(self, *args: Any, **kwargs: Any): 16 | super().__init__(*args, **kwargs) 17 | self.__dict__ = self 18 | 19 | 20 | CONFIG = Config( 21 | sequence_length=256, 22 | batch_size=16, 23 | hidden_size=256, 24 | head_size=64, 25 | depth=4, 26 | fully_scaled_attention=False, 27 | lr=2**-10, 28 | steps=5000, 29 | ) 30 | 31 | 32 | # https://www.gutenberg.org/cache/epub/100/pg100.txt 33 | DATA = torch.tensor(list(Path("shakespeare.txt").read_bytes())) 34 | 35 | 36 | def batches() -> Iterable[Tensor]: 37 | while True: 38 | offsets = torch.randint( 39 | len(DATA) - CONFIG.sequence_length - 1, (CONFIG.batch_size,) 40 | ) 41 | yield torch.stack([DATA[i : i + CONFIG.sequence_length + 1] for i in offsets]) 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | self.head_size = CONFIG.head_size 48 | self.n_heads = CONFIG.hidden_size // CONFIG.head_size 49 | self.qkv = nn.Linear(CONFIG.hidden_size, 3 * self.n_heads * self.head_size) 50 | self.proj = nn.Linear(self.n_heads * self.head_size, CONFIG.hidden_size) 51 | # Put the scale in a non-trainable parameter, to avoid recompilation 52 | self.out_scale = nn.Parameter( 53 | torch.tensor( 54 | (CONFIG.sequence_length / math.e) ** 0.5 55 | if CONFIG.fully_scaled_attention 56 | else 1.0 57 | ), 58 | requires_grad=False, 59 | ) 60 | 61 | def forward(self, x: Tensor) -> Tensor: 62 | s = x.shape[1] 63 | q, k, v = einops.rearrange( 64 | self.qkv(x), "b s (M n d) -> M b n s d", M=3, n=self.n_heads 65 | ) 66 | qk_scale = torch.tensor(self.head_size**-0.5, dtype=x.dtype, device=x.device) 67 | pre_a = torch.einsum("bnsd, bntd -> bnst", q, k) * qk_scale 68 | pre_a = pre_a + torch.triu( 69 | torch.full((s, s), -1e4, device=x.device), diagonal=1 70 | ) 71 | a = torch.softmax(pre_a, -1) 72 | out = torch.einsum("bnst, bntd -> bnsd", a, v) * self.out_scale 73 | return self.proj(einops.rearrange(out, "b n s d -> b s (n d)")) 74 | 75 | 76 | class FFN(nn.Module): 77 | def __init__(self): 78 | super().__init__() 79 | self.up = nn.Linear(CONFIG.hidden_size, 4 * CONFIG.hidden_size) 80 | self.down = nn.Linear(self.up.out_features, self.up.in_features) 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | return self.down(torch.nn.functional.gelu(self.up(x))) 84 | 85 | 86 | class PreNormResidual(nn.Module): 87 | def __init__(self, body: nn.Module): 88 | super().__init__() 89 | self.norm = nn.LayerNorm([CONFIG.hidden_size]) 90 | self.body = body 91 | 92 | def forward(self, x: Tensor) -> Tensor: 93 | return x + self.body(self.norm(x)) 94 | 95 | 96 | class AbsolutePositionalEncoding(nn.Module): 97 | def __init__(self): 98 | super().__init__() 99 | self.weight = nn.Parameter( 100 | torch.randn(CONFIG.sequence_length, CONFIG.hidden_size) 101 | ) 102 | 103 | def forward(self, x: Tensor) -> Tensor: 104 | return x + self.weight 105 | 106 | 107 | class Model(nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | self.model = nn.Sequential( 111 | nn.Embedding(256, CONFIG.hidden_size), 112 | AbsolutePositionalEncoding(), 113 | nn.LayerNorm([CONFIG.hidden_size]), 114 | *( 115 | nn.Sequential(PreNormResidual(Attention()), PreNormResidual(FFN())) 116 | for _ in range(CONFIG.depth) 117 | ), 118 | nn.LayerNorm([CONFIG.hidden_size]), 119 | nn.Linear(CONFIG.hidden_size, 256), 120 | ) 121 | 122 | def 
forward(self, indices: Tensor) -> Tensor: 123 | return nn.functional.cross_entropy( 124 | self.model(indices[:, :-1]).flatten(0, -2), indices[:, 1:].flatten() 125 | ) 126 | 127 | 128 | def train() -> Tensor: 129 | model = Model() 130 | opt = torch.optim.Adam(model.parameters(), lr=CONFIG.lr) 131 | losses = [] 132 | for batch in tqdm.tqdm(islice(batches(), CONFIG.steps), total=CONFIG.steps): 133 | opt.zero_grad() 134 | loss = model(batch) 135 | loss.backward() 136 | opt.step() 137 | losses.append(float(loss)) 138 | return torch.tensor(losses) 139 | -------------------------------------------------------------------------------- /analysis/emb_lr_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Analysis of the effect of the embedding LR update on the subsequent matmul\n", 8 | "\n", 9 | "I wanted to write this out in a notebook to make sure I understood the way in which the embedding update effects the subsequent matmul.\n", 10 | "\n", 11 | "No revelations unfortunately - it still seems as though our rule can't be justified this way (it is \"unnatural\"!). Under the \"no-alignment\" assumption the standard embedding LR breaks, but unfortunately our fix does nothing to help. Oh well." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch import randn\n", 22 | "from typing import Iterable" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "def rms(*xs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:\n", 32 | " if len(xs) == 1:\n", 33 | " return xs[0].pow(2).mean().sqrt()\n", 34 | " return tuple(rms(x) for x in xs)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Setup\n", 42 | "\n", 43 | "Toggle `full_alignment` and `umup_lr_rule` to see the effect. mup scaling is used by default." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "d = 2**11\n", 53 | "full_alignment = True\n", 54 | "umup_lr_rule = False\n", 55 | "\n", 56 | "w_lr = d ** -(1 if full_alignment else 0.5)\n", 57 | "e_lr = d ** -(0.5 if umup_lr_rule else 0)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Model & update\n", 65 | "\n", 66 | "Everything can be described in terms of these three tensors (a single embedding vector, weight matrix and a gradient vector). 
Note that I assume the gradient is unit-scale, and then just use the adam LR rules but under and SGD-like update (I appreciate this is a bit odd, but it's simple and the maths should work out)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "(tensor(0.9984), tensor(0.0221), tensor(0.9882))" 78 | ] 79 | }, 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "e1 = randn(d, 1)\n", 87 | "W1 = randn(d + 1, d) * d**-0.5\n", 88 | "g = randn(d + 1, 1)\n", 89 | "rms(\n", 90 | " e1, W1, g\n", 91 | ") # all \"well-scaled\", except the weight which is 1/sqrt(d) (this isn't unit scaling!)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Then we just run:" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "tensor(0.9953)" 110 | ] 111 | }, 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "x1 = W1 @ e1\n", 119 | "rms(x1) # well-scaled" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "((tensor(0.9977), tensor(0.0005)), 0.00048828125)" 131 | ] 132 | }, 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "u_e = W1.T @ g * e_lr\n", 140 | "u_W = g @ e1.T * w_lr\n", 141 | "(\n", 142 | " rms(u_e, u_W),\n", 143 | " 1 / d,\n", 144 | ") # the weight update is under-scaled (to be expected I think), though as a rank-1 matrix it has a much higher (O(1)) spectral norm! This means its effect doesn't \"go to zero\" in inf. width, though the rms does." 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "(tensor(0.9998), tensor(0.0221))" 156 | ] 157 | }, 158 | "execution_count": 7, 159 | "metadata": {}, 160 | "output_type": "execute_result" 161 | } 162 | ], 163 | "source": [ 164 | "e2 = e1 + u_e\n", 165 | "e2_std = e2.std()\n", 166 | "e2 /= e2_std # Why is `/ e2.std()` allowed/justified? Normally we'd have a much smaller weight update (scaled down by small LR constant), and then the original weight would be decayed a bit, keeping this at about rms=1. This re-scaling does something similar, though allows us to see the effect of the weight update scaling more clearly.\n", 167 | "W2 = W1 + u_W\n", 168 | "rms(\n", 169 | " e2, W2\n", 170 | ") # Update is well-scaled. Weight has barely changed from its 1/sqrt(d) starting point" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 8, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "tensor(1.7412)" 182 | ] 183 | }, 184 | "execution_count": 8, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "x2 = W2 @ e2\n", 191 | "rms(x2) # ~well-scaled. 
Certainly doesn't scale with a significant power of d" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "## Analysis\n", 199 | "\n", 200 | "Now we break this down into its constituent terms.\n", 201 | "\n", 202 | "First checking that they combine to the original" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "data": { 212 | "text/plain": [ 213 | "True" 214 | ] 215 | }, 216 | "execution_count": 9, 217 | "metadata": {}, 218 | "output_type": "execute_result" 219 | } 220 | ], 221 | "source": [ 222 | "torch.allclose(x2, (W1 + u_W) @ (e1 + u_e * e_lr) / e2_std, atol=1e-6)\n", 223 | "torch.allclose(x2, (W1 + g @ e1.T * w_lr) @ (e1 + W1.T @ g * e_lr) / e2_std, atol=1e-6)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 10, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "True" 235 | ] 236 | }, 237 | "execution_count": 10, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "# t1 = W1 @ e1 (== x1)\n", 244 | "t2 = W1 @ W1.T @ g * e_lr\n", 245 | "t3 = g @ e1.T * w_lr @ e1\n", 246 | "t4 = g @ e1.T * w_lr @ W1.T @ g * e_lr\n", 247 | "torch.allclose(x2, (x1 + t2 + t3 + t4) / e2_std, atol=1e-5)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### Weight @ emb_update (t2)\n", 255 | "\n", 256 | "This is well-scaled under the original emb lr rule, but not under our lr rule - which isn't a great sign for our approach" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 11, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "rms(W1, g), e_lr=((tensor(0.0221), tensor(0.9882)), 1)\n", 269 | "rms(W1 @ W1.T)=tensor(0.0312)\n", 270 | "rms(W1.T @ g)=tensor(0.9977)\n", 271 | "rms(W1 @ W1.T @ g * e_lr / e2_std)=tensor(0.9857)\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "print(f\"{rms(W1, g), e_lr=}\")\n", 277 | "print(f\"{rms(W1 @ W1.T)=}\")\n", 278 | "print(f\"{rms(W1.T @ g)=}\")\n", 279 | "print(f\"{rms(W1 @ W1.T @ g * e_lr / e2_std)=}\")" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "### Weight_update @ emb (t3)\n", 287 | "\n", 288 | "This is well-scaled under the original emb lr rule and our rule" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 12, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "rms(g, e1)=(tensor(0.9882), tensor(0.9984))\n", 301 | "rms(g @ e1.T)=tensor(0.9866)\n", 302 | "rms(e1.T @ e1 * w_lr)=tensor(0.9968)\n", 303 | "rms(g @ e1.T * w_lr @ e1)=tensor(0.9850)\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "print(f\"{rms(g, e1)=}\")\n", 309 | "print(f\"{rms(g @ e1.T)=}\")\n", 310 | "print(f\"{rms(e1.T @ e1 * w_lr)=}\")\n", 311 | "print(f\"{rms(g @ e1.T * w_lr @ e1)=}\")" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "### Weight_update @ emb_update (t4)\n", 319 | "\n", 320 | "This vanishes with width under the original emb lr and our rule. Probably a good thing?" 
321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 13, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "rms(g @ e1.T @ W1.T @ g)=tensor(46.5558)\n", 333 | "rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=tensor(0.0227)\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "print(f\"{rms(g @ e1.T @ W1.T @ g)=}\")\n", 339 | "print(f\"{rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=}\")" 340 | ] 341 | } 342 | ], 343 | "metadata": { 344 | "kernelspec": { 345 | "display_name": ".venv", 346 | "language": "python", 347 | "name": "python3" 348 | }, 349 | "language_info": { 350 | "codemirror_mode": { 351 | "name": "ipython", 352 | "version": 3 353 | }, 354 | "file_extension": ".py", 355 | "mimetype": "text/x-python", 356 | "name": "python", 357 | "nbconvert_exporter": "python", 358 | "pygments_lexer": "ipython3", 359 | "version": "3.11.9" 360 | } 361 | }, 362 | "nbformat": 4, 363 | "nbformat_minor": 2 364 | } 365 | -------------------------------------------------------------------------------- /dev: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 3 | 4 | """Dev task launcher.""" 5 | 6 | import argparse 7 | import datetime 8 | import os 9 | import subprocess 10 | import sys 11 | from pathlib import Path 12 | from typing import Any, Callable, Iterable, List, Optional, TypeVar 13 | 14 | # Utilities 15 | 16 | 17 | def run(command: Iterable[Any]) -> None: 18 | """Run a command, terminating on failure.""" 19 | cmd = [str(arg) for arg in command if arg is not None] 20 | print("$ " + " ".join(cmd), file=sys.stderr) 21 | environ = os.environ.copy() 22 | environ["PYTHONPATH"] = f"{os.getcwd()}:{environ.get('PYTHONPATH', '')}" 23 | exit_code = subprocess.call(cmd, env=environ) 24 | if exit_code: 25 | sys.exit(exit_code) 26 | 27 | 28 | T = TypeVar("T") 29 | 30 | 31 | def cli(*args: Any, **kwargs: Any) -> Callable[[T], T]: 32 | """Declare a CLI command / arguments for that command.""" 33 | 34 | def wrap(func: T) -> T: 35 | if not hasattr(func, "cli_args"): 36 | setattr(func, "cli_args", []) 37 | if args or kwargs: 38 | getattr(func, "cli_args").append((args, kwargs)) 39 | return func 40 | 41 | return wrap 42 | 43 | 44 | # Commands 45 | 46 | PYTHON_ROOTS = ["unit_scaling", "dev", "examples"] 47 | 48 | 49 | @cli("-k", "--filter") 50 | def tests(filter: Optional[str]) -> None: 51 | """run Python tests""" 52 | run( 53 | [ 54 | "python", 55 | "-m", 56 | "pytest", 57 | "unit_scaling", 58 | None if filter else "--cov=unit_scaling", 59 | *(["-k", filter] if filter else []), 60 | ] 61 | ) 62 | 63 | 64 | @cli("commands", nargs="*") 65 | def python(commands: List[Any]) -> None: 66 | """run Python with the current directory on PYTHONPATH, for development""" 67 | run(["python"] + commands) 68 | 69 | 70 | @cli() 71 | def lint() -> None: 72 | """run static analysis""" 73 | run(["python", "-m", "flake8", *PYTHON_ROOTS]) 74 | run(["python", "-m", "mypy", *PYTHON_ROOTS]) 75 | 76 | 77 | @cli("--check", action="store_true") 78 | def format(check: bool) -> None: 79 | """autoformat all sources""" 80 | run(["python", "-m", "black", "--check" if check else None, *PYTHON_ROOTS]) 81 | run(["python", "-m", "isort", "--check" if check else None, *PYTHON_ROOTS]) 82 | 83 | 84 | @cli() 85 | def copyright() -> None: 86 | """check for Graphcore copyright headers on relevant files""" 87 | command = ( 88 | f"find {' '.join(PYTHON_ROOTS)} -type 
f -not -name *.pyc -not -name *.json" 89 | " -not -name .gitignore -not -name *_version.py" 90 | " | xargs grep -L 'Copyright (c) 202. Graphcore Ltd[.] All rights reserved[.]'" 91 | ) 92 | print(f"$ {command}", file=sys.stderr) 93 | # Note: grep exit codes are not consistent between versions, so we don't use 94 | # check=True 95 | output = ( 96 | subprocess.run( 97 | command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT 98 | ) 99 | .stdout.decode() 100 | .strip() 101 | ) 102 | if output: 103 | print( 104 | "Error - failed copyright header check in:\n " 105 | + output.replace("\n", "\n "), 106 | file=sys.stderr, 107 | ) 108 | print("Template(s):") 109 | comment_prefixes = { 110 | {".cpp": "//"}.get(Path(f).suffix, "#") for f in output.split("\n") 111 | } 112 | for prefix in comment_prefixes: 113 | print( 114 | ( 115 | f"{prefix} Copyright (c) {datetime.datetime.now().year}" 116 | " Graphcore Ltd. All rights reserved." 117 | ), 118 | file=sys.stderr, 119 | ) 120 | sys.exit(1) 121 | 122 | 123 | @cli() 124 | def doc() -> None: 125 | """generate API documentation""" 126 | subprocess.call(["rm", "-r", "docs/generated", "docs/_build"]) 127 | run( 128 | [ 129 | "make", 130 | "-C", 131 | "docs", 132 | "html", 133 | ] 134 | ) 135 | 136 | 137 | @cli( 138 | "-s", 139 | "--skip", 140 | nargs="*", 141 | default=[], 142 | choices=["tests", "lint", "format", "copyright"], 143 | help="commands to skip", 144 | ) 145 | def ci(skip: List[str] = []) -> None: 146 | """run all continuous integration tests & checks""" 147 | if "tests" not in skip: 148 | tests(filter=None) 149 | if "lint" not in skip: 150 | lint() 151 | if "format" not in skip: 152 | format(check=True) 153 | if "copyright" not in skip: 154 | copyright() 155 | if "doc" not in skip: 156 | doc() 157 | 158 | 159 | # Script 160 | 161 | 162 | def _main() -> None: 163 | # Build an argparse command line by finding globals in the current module 164 | # that are marked via the @cli() decorator. Each one becomes a subcommand 165 | # running that function, usage "$ ./dev fn_name ...args". 166 | parser = argparse.ArgumentParser(description=__doc__) 167 | parser.set_defaults(command=ci) 168 | 169 | subs = parser.add_subparsers() 170 | for key, value in globals().items(): 171 | if hasattr(value, "cli_args"): 172 | sub = subs.add_parser(key.replace("_", "-"), help=value.__doc__) 173 | for args, kwargs in value.cli_args: 174 | sub.add_argument(*args, **kwargs) 175 | sub.set_defaults(command=value) 176 | 177 | cli_args = vars(parser.parse_args()) 178 | command = cli_args.pop("command") 179 | command(**cli_args) 180 | 181 | 182 | if __name__ == "__main__": 183 | _main() 184 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Read the Docs Inc. All rights reserved. 2 | 3 | # Minimal makefile for Sphinx documentation 4 | # 5 | 6 | # You can set these variables from the command line, and also 7 | # from the environment for the first two. 8 | SPHINXOPTS ?= 9 | SPHINXBUILD ?= sphinx-build 10 | SOURCEDIR = . 11 | BUILDDIR = _build 12 | 13 | # Put it first so that "make" without argument is like "make help". 14 | help: 15 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 16 | 17 | .PHONY: help Makefile 18 | 19 | # Catch-all target: route all unknown targets to Sphinx using the new 20 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
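# For example, "make html" builds the documentation into _build/html (this is what "./dev doc" runs and what CI publishes).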
21 | %: Makefile 22 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 23 | -------------------------------------------------------------------------------- /docs/_static/animation.html: -------------------------------------------------------------------------------- 1 | <!-- HTML video embed giving a broad overview of unit scaling (included from docs/index.rst via ".. raw:: html"); the markup itself was not preserved in this extract --> 2 | 3 | 4 | 5 |
6 | -------------------------------------------------------------------------------- /docs/_static/scales.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/unit-scaling/215ba6395230c375156aa139dcd4db950ad19439/docs/_static/scales.png -------------------------------------------------------------------------------- /docs/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | .. 2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 3 | # Copyright (c) 2007-2023 by the Sphinx team. All rights reserved. 4 | 5 | {{ fullname | escape | underline}} 6 | 7 | .. currentmodule:: {{ module }} 8 | 9 | .. autoclass:: {{ objname }} 10 | :members: 11 | :inherited-members: Module 12 | :exclude-members: extra_repr, forward 13 | -------------------------------------------------------------------------------- /docs/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | .. 2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 3 | # Copyright (c) 2007-2023 by the Sphinx team. All rights reserved. 4 | 5 | {{ fullname | escape | underline}} 6 | 7 | .. automodule:: {{ fullname }} 8 | 9 | {% block aliases %} 10 | {% if fullname == 'unit_scaling.constraints' %} 11 | .. rubric:: {{ _('Type Aliases') }} 12 | 13 | .. autosummary:: 14 | :toctree: 15 | 16 | BinaryConstraint 17 | TernaryConstraint 18 | VariadicConstraint 19 | {% endif %} 20 | {% endblock %} 21 | 22 | {% block attributes %} 23 | {% if attributes %} 24 | .. rubric:: {{ _('Module Attributes') }} 25 | 26 | .. autosummary:: 27 | :toctree: 28 | {% for item in attributes %} 29 | {{ item }} 30 | {%- endfor %} 31 | {% endif %} 32 | {% endblock %} 33 | 34 | {% block functions %} 35 | .. rubric:: {{ _('Functions') }} 36 | 37 | .. autosummary:: 38 | :toctree: 39 | {% for item in functions %} 40 | {{ item }} 41 | {%- endfor %} 42 | {% endblock %} 43 | 44 | {% block classes %} 45 | {% if classes %} 46 | .. rubric:: {{ _('Classes') }} 47 | 48 | .. autosummary:: 49 | :toctree: 50 | :template: custom-class-template.rst 51 | {% for item in classes %} 52 | {{ item }} 53 | {%- endfor %} 54 | {% endif %} 55 | {% endblock %} 56 | 57 | {% block exceptions %} 58 | {% if exceptions %} 59 | .. rubric:: {{ _('Exceptions') }} 60 | 61 | .. autosummary:: 62 | :toctree: 63 | {% for item in exceptions %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | 69 | {% block modules %} 70 | {% if modules %} 71 | .. rubric:: Modules 72 | 73 | .. autosummary:: 74 | :toctree: 75 | :template: custom-module-template.rst 76 | :recursive: 77 | {% for item in modules %} 78 | {% if "test" not in item and "docs" not in item %} 79 | {{ item }} 80 | {% endif %} 81 | {%- endfor %} 82 | {% endif %} 83 | {% endblock %} 84 | -------------------------------------------------------------------------------- /docs/api_reference.rst: -------------------------------------------------------------------------------- 1 | API reference 2 | ============= 3 | 4 | :code:`unit-scaling` is implemented using thin wrappers around existing :code:`torch.nn` 5 | classes and functions. Documentation also inherits from the standard PyTorch docs, with 6 | modifications for scaling. Note that some docs may no longer be relevant but are 7 | nevertheless inherited. 
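For example, a unit-scaled module is intended to act as a drop-in replacement for
its :code:`torch.nn` counterpart. A minimal sketch, mirroring
:code:`examples/scale_analysis.py` (the printed scales will vary from run to run):

.. code-block::

    import torch
    import unit_scaling as uu
    from unit_scaling.utils import analyse_module

    batch_size, hidden_size, out_size = 2**8, 2**10, 2**10
    input = torch.randn(batch_size, hidden_size).requires_grad_()
    backward = torch.randn(batch_size, out_size)

    # uu.Linear takes the same (in_features, out_features) arguments as nn.Linear,
    # but inserts scaling factors so activations, gradients and weights start
    # close to unit RMS.
    annotated_code = analyse_module(uu.Linear(hidden_size, out_size), input, backward)
    print(annotated_code)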
8 | 9 | The API is built to mirror :code:`torch.nn` as closely as possible, such that PyTorch 10 | classes and functions can easily be swapped-out for their unit-scaled equivalents. 11 | 12 | For PyTorch code which uses the following imports: 13 | 14 | .. code-block:: 15 | 16 | from torch import nn 17 | from torch.nn import functional as F 18 | 19 | Unit scaling can be applied by first adding: 20 | 21 | .. code-block:: 22 | 23 | import unit_scaling as uu 24 | from unit_scaling import functional as U 25 | 26 | and then replacing the letters :code:`nn` with :code:`uu` and 27 | :code:`F` with :code:`U`, for those classes/functions to be unit-scaled 28 | (assuming they are supported). 29 | 30 | Click below for the full documentation: 31 | 32 | .. autosummary:: 33 | :toctree: generated 34 | :template: custom-module-template.rst 35 | :recursive: 36 | 37 | unit_scaling 38 | unit_scaling.analysis 39 | unit_scaling.constraints 40 | unit_scaling.formats 41 | unit_scaling.functional 42 | unit_scaling.optim 43 | unit_scaling.scale 44 | unit_scaling.transforms 45 | unit_scaling.transforms.utils 46 | unit_scaling.utils 47 | unit_scaling.core.functional 48 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # Copyright (c) 2022 Read the Docs Inc. All rights reserved. 3 | 4 | # Configuration file for the Sphinx documentation builder. 5 | # 6 | # For the full list of built-in configuration values, see the documentation: 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 8 | 9 | # Setup based on https://example-sphinx-basic.readthedocs.io 10 | import os 11 | import sys 12 | from unit_scaling._version import __version__ 13 | 14 | sys.path.insert(0, os.path.abspath("..")) 15 | 16 | # -- Project information ----------------------------------------------------- 17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 18 | 19 | project = "unit-scaling" 20 | copyright = "(c) 2023 Graphcore Ltd. All rights reserved" 21 | author = "Charlie Blake, Douglas Orr" 22 | version = __version__ 23 | 24 | # -- General configuration --------------------------------------------------- 25 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 26 | 27 | extensions = [ 28 | "sphinx.ext.duration", 29 | "sphinx.ext.doctest", 30 | "sphinx.ext.autodoc", 31 | "sphinx.ext.autosummary", 32 | "sphinx.ext.intersphinx", 33 | "sphinx.ext.napoleon", # support for google-style docstrings 34 | "myst_parser", # support for including markdown files in .rst files (e.g. readme) 35 | "sphinx.ext.viewcode", # adds source code to docs 36 | "sphinx.ext.autosectionlabel", # links to sections in the same document 37 | "sphinx.ext.mathjax", # equations 38 | ] 39 | 40 | autosummary_generate = True 41 | autosummary_imported_members = True 42 | autosummary_ignore_module_all = False 43 | napoleon_google_docstring = True 44 | napoleon_numpy_docstring = False 45 | 46 | numfig_format = { 47 | "section": "Section {number}. {name}", 48 | "figure": "Fig. 
%s", 49 | "table": "Table %s", 50 | "code-block": "Listing %s", 51 | } 52 | 53 | intersphinx_mapping = { 54 | "rtd": ("https://docs.readthedocs.io/en/stable/", None), 55 | "python": ("https://docs.python.org/3/", None), 56 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None), 57 | "pytorch": ("https://pytorch.org/docs/stable/", None), 58 | } 59 | intersphinx_disabled_domains = ["std"] 60 | 61 | autodoc_type_aliases = { 62 | "BinaryConstraint": "unit_scaling.constraints.BinaryConstraint", 63 | "TernaryConstraint": "unit_scaling.constraints.TernaryConstraint", 64 | "VariadicConstraint": "unit_scaling.constraints.VariadicConstraint", 65 | } # make docgen output name of alias rather than definition. 66 | 67 | templates_path = ["_templates"] 68 | 69 | # -- Options for EPUB output 70 | epub_show_urls = "footnote" 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This pattern also affects html_static_path and html_extra_path. 75 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 76 | 77 | html_favicon = "_static/scales.png" 78 | 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # 84 | html_theme = "sphinx_rtd_theme" 85 | 86 | # Add any paths that contain custom static files (such as style sheets) here, 87 | # relative to this directory. They are copied after the builtin static files, 88 | # so a file named "default.css" will overwrite the builtin "default.css". 89 | html_static_path = ["_static"] 90 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Unit Scaling 2 | ============ 3 | 4 | Welcome to the :code:`unit-scaling` library. This library is designed to facilitate 5 | the use of the *unit scaling* and *u-µP* methods, as outlined in the papers 6 | `Unit Scaling: Out-of-the-Box Low-Precision Training (ICML, 2023) 7 | `_ and 8 | `u-μP: The Unit-Scaled Maximal Update Parametrization 9 | `_ 10 | 11 | For a demonstration of the library, see `u-μP using the unit_scaling library 12 | `_ — a notebook showing the definition and training of a u-µP language model, comparing against Standard Parametrization (SP). 13 | 14 | Installation 15 | ------------ 16 | 17 | To install :code:`unit-scaling`, run: 18 | 19 | .. code-block:: 20 | 21 | pip install unit-scaling 22 | 23 | Getting Started 24 | --------------- 25 | 26 | We recommend that new users get started with :numref:`User Guide`. 27 | 28 | A reference outlining our API can be found at :numref:`API reference`. 29 | 30 | The following video gives a broad overview of the workings of unit scaling. 31 | 32 | .. raw:: html 33 | :file: _static/animation.html 34 | 35 | .. Note:: The library is currently in its *beta* release. 36 | Some features have yet to be implemented and occasional bugs may be present. 37 | We're keen to help users with any problems they encounter. 38 | 39 | `The following slides `_ also give an overview of u-µP. 40 | 41 | Development 42 | ----------- 43 | 44 | For those who wish to develop on the :code:`unit-scaling` codebase, clone or fork our 45 | `GitHub repo `_ and follow the 46 | instructions in our :doc:`developer guide `. 47 | 48 | .. 
toctree:: 49 | :caption: Contents 50 | :numbered: 51 | :maxdepth: 3 52 | 53 | User guide 54 | Developer guide 55 | Limitations 56 | API reference 57 | -------------------------------------------------------------------------------- /docs/limitations.rst: -------------------------------------------------------------------------------- 1 | Limitations 2 | =========== 3 | 4 | :code:`unit-scaling` is a new library and (despite our best efforts!) we can't guarantee 5 | it will be bug-free or feature-complete. We're keen to assist anyone who wants to use 6 | the library, and help them work through any issues. 7 | 8 | Known limitations of the library include: 9 | 10 | 1. **Op coverage:** we've currently focussed on adding common transformer operations — other ops may be missing (though we can add most requested ops without difficulty). 11 | 2. **Using transforms with torch.compile:** currently our transforms (for example :code:`unit_scale`, :code:`simulate_fp8`) can't be used directly with :code:`torch.compile`. We provide a special compilation function to get around this: :code:`unit_scaling.transforms.compile` (see docs for more details), though this only works with :code:`unit_scale` and not :code:`simulate_fp8`. 12 | 3. **Distributed training:** although we suspect distributed training will still work reasonably well with the current library, we haven't tested this. 13 | 14 | This list is not exhaustive and we encourage you to get in touch if you have 15 | feature-requests not listed here. 16 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | :: Copyright (c) 2022 Read the Docs Inc. All rights reserved. 2 | @ECHO OFF 3 | 4 | pushd %~dp0 5 | 6 | REM Command file for Sphinx documentation 7 | 8 | if "%SPHINXBUILD%" == "" ( 9 | set SPHINXBUILD=sphinx-build 10 | ) 11 | set SOURCEDIR=. 12 | set BUILDDIR=_build 13 | 14 | %SPHINXBUILD% >NUL 2>NUL 15 | if errorlevel 9009 ( 16 | echo. 17 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 18 | echo.installed, then set the SPHINXBUILD environment variable to point 19 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 20 | echo.may add the Sphinx directory to PATH. 21 | echo. 22 | echo.If you don't have Sphinx installed, grab it from 23 | echo.https://www.sphinx-doc.org/ 24 | exit /b 1 25 | ) 26 | 27 | if "%1" == "" goto help 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/posts/almost_scaled_dot_product_attention.md: -------------------------------------------------------------------------------- 1 | # Almost-scaled dot-product attention 2 | 3 | **This post [has moved](https://graphcore-research.github.io/posts/almost_scaled/)**. 4 | 5 | Note that the approach and equations described in this post are legacy and do not reflect the current implementation of u-μP. Please see the code for a definitive reference. 
6 | -------------------------------------------------------------------------------- /docs/u-muP_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/unit-scaling/215ba6395230c375156aa139dcd4db950ad19439/docs/u-muP_slides.pdf -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | -------------------------------------------------------------------------------- /examples/scale_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import torch 3 | 4 | import unit_scaling as uu 5 | from unit_scaling.utils import analyse_module 6 | 7 | print("=== Unit-scaled Linear ===\n") 8 | 9 | batch_size = 2**8 10 | hidden_size = 2**10 11 | out_size = 2**10 12 | input = torch.randn(batch_size, hidden_size).requires_grad_() 13 | backward = torch.randn(batch_size, out_size) 14 | 15 | annotated_code = analyse_module(uu.Linear(hidden_size, out_size), input, backward) 16 | print(annotated_code) 17 | 18 | print("=== Unit-scaled MLP ===\n") 19 | 20 | batch_size = 2**8 21 | hidden_size = 2**10 22 | input = torch.randn(batch_size, hidden_size).requires_grad_() 23 | backward = torch.randn(batch_size, hidden_size) 24 | 25 | annotated_code = analyse_module(uu.MLP(hidden_size), input, backward) 26 | print(annotated_code) 27 | 28 | print("=== Unit-scaled MHSA ===\n") 29 | 30 | batch_size = 2**8 31 | seq_len = 2**6 32 | hidden_size = 2**6 33 | heads = 4 34 | dropout_p = 0.1 35 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_() 36 | backward = torch.randn(batch_size, seq_len, hidden_size) 37 | 38 | annotated_code = analyse_module( 39 | uu.MHSA(hidden_size, heads, is_causal=False, dropout_p=dropout_p), input, backward 40 | ) 41 | print(annotated_code) 42 | 43 | print("=== Unit-scaled Transformer Layer ===\n") 44 | 45 | batch_size = 2**8 46 | seq_len = 2**6 47 | hidden_size = 2**6 48 | heads = 4 49 | dropout_p = 0.1 50 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_() 51 | backward = torch.randn(batch_size, seq_len, hidden_size) 52 | 53 | annotated_code = analyse_module( 54 | uu.TransformerLayer( 55 | hidden_size, 56 | heads, 57 | mhsa_tau=0.1, 58 | mlp_tau=1.0, 59 | is_causal=False, 60 | dropout_p=dropout_p, 61 | ), 62 | input, 63 | backward, 64 | ) 65 | print(annotated_code) 66 | 67 | print("=== Unit-scaled Full Transformer Decoder ===\n") 68 | 69 | batch_size = 2**8 70 | seq_len = 2**6 71 | hidden_size = 2**6 72 | vocab_size = 2**12 73 | layers = 2 74 | heads = 4 75 | dropout_p = 0.1 76 | 77 | seq = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len + 1)) 78 | input_idxs = seq[:, :-1] 79 | labels = torch.roll(seq, -1, 1)[:, 1:] 80 | 81 | annotated_code = analyse_module( 82 | uu.TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p), 83 | (input_idxs, labels), 84 | ) 85 | print(annotated_code) 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Configuration inspired by official pypa example: 2 | # https://github.com/pypa/sampleproject/blob/main/pyproject.toml 3 | 4 | [build-system] 5 | requires = ["setuptools>=68.2.2", "setuptools-scm"] 6 | build-backend = "setuptools.build_meta" 
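# Note (descriptive, not from the original file): no version is hard-coded in this
# project; setuptools-scm (configured under [tool.setuptools_scm] below) derives it
# from git tags and writes it to unit_scaling/_version.py, which the dynamic
# "version" field then reads.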
7 | 8 | [project] 9 | name = "unit-scaling" 10 | description = "A library for unit scaling in PyTorch, based on the paper 'u-muP: The Unit-Scaled Maximal Update Parametrization.'" 11 | readme = "README.md" 12 | authors = [ 13 | { name = "Charlie Blake", email = "charlieb@graphcore.ai" }, 14 | { name = "Douglas Orr", email = "douglaso@graphcore.ai" }, 15 | ] 16 | requires-python = ">=3.9" 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | "Programming Language :: Python :: 3.13", 27 | "Programming Language :: Python :: 3.14", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | ] 30 | dependencies = [ 31 | "datasets", 32 | "docstring-parser", 33 | "einops", 34 | "numpy<2.0.0", 35 | "seaborn", 36 | "tabulate", 37 | "torch>=2.2", 38 | ] 39 | dynamic = ["version"] 40 | 41 | [project.urls] 42 | "Homepage" = "https://github.com/graphcore-research/unit-scaling/#readme" 43 | "Bug Reports" = "https://github.com/graphcore-research/unit-scaling/issues" 44 | "Source" = "https://github.com/graphcore-research/unit-scaling/" 45 | 46 | [project.optional-dependencies] 47 | dev = ["check-manifest"] 48 | test = ["pytest"] 49 | 50 | [tool.setuptools] 51 | packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"] 52 | 53 | [tool.setuptools.dynamic] 54 | version = {attr = "unit_scaling._version.__version__"} 55 | 56 | [tool.setuptools_scm] 57 | version_file = "unit_scaling/_version.py" 58 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Look in pytorch-cpu first, then pypi second 2 | --index-url https://download.pytorch.org/whl/cpu 3 | --extra-index-url=https://pypi.org/simple 4 | 5 | # Same as pyproject.toml, but with versions locked-in 6 | datasets==3.1.0 7 | docstring-parser==0.16 8 | einops==0.8.0 9 | numpy==1.26.4 10 | seaborn==0.13.2 11 | tabulate==0.9.0 12 | torch==2.5.1+cpu 13 | 14 | # Additional dev requirements 15 | black==24.10.0 16 | flake8==7.1.1 17 | isort==5.13.2 18 | mypy==1.13.0 19 | myst-parser==4.0.0 20 | pandas-stubs==2.2.3.241009 21 | pytest==8.3.3 22 | pytest-cov==6.0.0 23 | setuptools==70.0.0 24 | sphinx==8.1.3 25 | sphinx-rtd-theme==3.0.1 26 | transformers==4.46.1 27 | triton==3.1.0 28 | types-Pygments==2.18.0.20240506 29 | types-tabulate==0.9.0.20240106 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [mypy] 2 | pretty = true 3 | show_error_codes = true 4 | strict = true 5 | check_untyped_defs = true 6 | 7 | # As torch.fx doesn't explicitly export many of its useful modules. 
8 | [mypy-torch.fx] 9 | implicit_reexport = True 10 | 11 | [flake8] 12 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html 13 | max-line-length = 88 14 | extend-ignore = E203,E731 15 | 16 | [isort] 17 | profile = black 18 | 19 | [tool:pytest] 20 | addopts = --no-cov-on-fail 21 | 22 | [coverage:report] 23 | # fail_under = 100 24 | skip_covered = true 25 | show_missing = true 26 | exclude_lines = 27 | pragma: no cover 28 | raise NotImplementedError 29 | assert False 30 | -------------------------------------------------------------------------------- /unit_scaling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | """Unit-scaled versions of common torch.nn modules.""" 4 | 5 | # This all has to be done manually to keep mypy happy. 6 | # Removing the `--no-implicit-reexport` option ought to fix this, but doesn't appear to. 7 | 8 | from . import core, functional, optim, parameter 9 | from ._modules import ( 10 | GELU, 11 | MHSA, 12 | MLP, 13 | Conv1d, 14 | CrossEntropyLoss, 15 | DepthModuleList, 16 | DepthSequential, 17 | Dropout, 18 | Embedding, 19 | LayerNorm, 20 | Linear, 21 | LinearReadout, 22 | RMSNorm, 23 | SiLU, 24 | Softmax, 25 | TransformerDecoder, 26 | TransformerLayer, 27 | ) 28 | from ._version import __version__ 29 | from .analysis import visualiser 30 | from .core.functional import transformer_residual_scaling_rule 31 | from .parameter import MupType, Parameter 32 | 33 | __all__ = [ 34 | # Modules 35 | "Conv1d", 36 | "CrossEntropyLoss", 37 | "DepthModuleList", 38 | "DepthSequential", 39 | "Dropout", 40 | "Embedding", 41 | "GELU", 42 | "LayerNorm", 43 | "Linear", 44 | "LinearReadout", 45 | "MHSA", 46 | "MLP", 47 | "MupType", 48 | "RMSNorm", 49 | "SiLU", 50 | "Softmax", 51 | "TransformerDecoder", 52 | "TransformerLayer", 53 | # Modules 54 | "core", 55 | "functional", 56 | "optim", 57 | "parameter", 58 | # Functions 59 | "Parameter", 60 | "transformer_residual_scaling_rule", 61 | "visualiser", 62 | "__version__", 63 | ] 64 | -------------------------------------------------------------------------------- /unit_scaling/_internal_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import inspect 4 | import sys 5 | from typing import List 6 | 7 | 8 | def generate__all__(module_name: str, include_imports: bool = False) -> List[str]: 9 | """Generates the contents of __all__ by extracting every public function/class/etc. 10 | except those imported from other modules. Necessary for Sphinx docs.""" 11 | module = sys.modules[module_name] 12 | all = [] 13 | for name, member in inspect.getmembers(module): 14 | # Skip members imported from other modules and private members 15 | is_local = inspect.getmodule(member) == module 16 | if (include_imports or is_local) and not name.startswith("_"): 17 | all.append(name) 18 | return all 19 | -------------------------------------------------------------------------------- /unit_scaling/constraints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
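# Example (illustrative, not part of the original module): a unit-scaled op proposes
# separate scaling factors for its output and its input gradient(s); a constraint
# such as `gmean` replaces them with a single shared factor, e.g.
#
#     apply_constraint("gmean", 4.0, 1.0) == (2.0, 2.0)
#
# so that the forward and backward passes end up using the same scale.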
2 | 3 | """Common scale-constraints used in unit-scaled operations.""" 4 | 5 | from __future__ import annotations # required for docs to alias type annotations 6 | 7 | import sys 8 | from math import pow, prod 9 | from typing import Optional, Tuple 10 | 11 | from ._internal_utils import generate__all__ 12 | 13 | 14 | def gmean(*scales: float) -> float: 15 | """Computes the geometric mean of the provided scales. Recommended for unit scaling. 16 | 17 | Args: 18 | scales: (*float): the group of constrained scales 19 | 20 | Returns: 21 | float: the geometric mean. 22 | """ 23 | return pow(prod(scales), (1 / len(scales))) 24 | 25 | 26 | def hmean(*scales: float) -> float: 27 | """Computes the harmonic mean of the provided scales. Used in Xavier/Glorot scaling. 28 | 29 | Args: 30 | scales: (*float): the group of constrained scales 31 | 32 | Returns: 33 | float: the harmonic mean. 34 | """ 35 | return 1 / (sum(1 / s for s in scales) / len(scales)) 36 | 37 | 38 | def amean(*scales: float) -> float: 39 | """Computes the arithmetic mean of the provided scales. 40 | 41 | Args: 42 | scales: (*float): the group of constrained scales 43 | 44 | Returns: 45 | float: the arithmetic mean. 46 | """ 47 | return sum(scales) / len(scales) 48 | 49 | 50 | def to_output_scale(output_scale: float, *grad_input_scale: float) -> float: 51 | """Assumes an output scale is provided and any number of grad input scales: 52 | `(output_scale, *grad_input_scales)`. Selects only `output_scale` as the chosen 53 | scaling factor. 54 | 55 | Args: 56 | output_scale (float): the scale of the op's output 57 | grad_input_scales (*float): the scales of the op's input gradients 58 | 59 | Returns: 60 | float: equal to `output_scale` 61 | """ 62 | return output_scale 63 | 64 | 65 | def to_grad_input_scale(output_scale: float, grad_input_scale: float) -> float: 66 | """Assumes two provided scales: `(output_scale, grad_input_scale)`. Selects only 67 | `grad_input_scale` as the chosen scaling factor. 68 | 69 | Args: 70 | output_scale (float): the scale of the op's output 71 | grad_input_scale (float): the scale of the op's input gradient 72 | 73 | Returns: 74 | float: equal to `grad_input_scale` 75 | """ 76 | return grad_input_scale 77 | 78 | 79 | def to_left_grad_scale( 80 | output_scale: float, left_grad_scale: float, right_grad_scale: float 81 | ) -> float: 82 | """Assumes three provided scales: 83 | `(output_scale, left_grad_scale, right_grad_scale)`. Selects only `left_grad_scale` 84 | as the chosen scaling factor. 85 | 86 | Args: 87 | output_scale (float): the scale of the op's output 88 | left_grad_scale (float): the scale of the op's left input gradient 89 | right_grad_scale (float): the scale of the op's right input gradient 90 | 91 | Returns: 92 | float: equal to `left_grad_scale` 93 | """ 94 | return left_grad_scale 95 | 96 | 97 | def to_right_grad_scale( 98 | output_scale: float, left_grad_scale: float, right_grad_scale: float 99 | ) -> float: 100 | """Assumes three provided scales: 101 | `(output_scale, left_grad_scale, right_grad_scale)`. Selects only `right_grad_scale` 102 | as the chosen scaling factor. 
103 | 104 | Args: 105 | output_scale (float): the scale of the op's output 106 | left_grad_scale (float): the scale of the op's left input gradient 107 | right_grad_scale (float): the scale of the op's right input gradient 108 | 109 | Returns: 110 | float: equal to `right_grad_scale` 111 | """ 112 | return right_grad_scale 113 | 114 | 115 | def apply_constraint( 116 | constraint_name: Optional[str], *scales: float 117 | ) -> Tuple[float, ...]: 118 | """Retrieves the constraint function corresponding to `constraint_name` and applies 119 | it to the group of scales. This name must be that of one of the functions defined in 120 | this module. 121 | 122 | Args: 123 | constraint_name (Optional[str]): The name of the constraint function to be used. 124 | 125 | Raises: 126 | ValueError: if `constraint_name` is not that of a function in this module. 127 | 128 | Returns: 129 | Tuple[float, ...]: the scales after the constraint has been applied. 130 | """ 131 | if constraint_name is None or constraint_name == "": 132 | return scales 133 | constraint = getattr(sys.modules[__name__], constraint_name, None) 134 | if constraint is None: 135 | raise ValueError( 136 | f"Constraint: {constraint_name} is not a valid constraint (see" 137 | " unit_scaling.constraints for available options)." 138 | ) 139 | scale = constraint(*scales) 140 | return tuple(scale for _ in scales) 141 | 142 | 143 | __all__ = generate__all__(__name__) 144 | -------------------------------------------------------------------------------- /unit_scaling/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | """Core components for advanced library usage.""" 4 | 5 | from . import functional 6 | 7 | __all__ = ["functional"] 8 | -------------------------------------------------------------------------------- /unit_scaling/core/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | """Core functionality for implementing `unit_scaling.functional`.""" 4 | 5 | import math 6 | from typing import Any, Callable, Optional, Tuple 7 | 8 | from torch import Tensor 9 | 10 | from .._internal_utils import generate__all__ 11 | from ..constraints import apply_constraint 12 | from ..docs import binary_constraint_docstring, format_docstring 13 | from ..scale import scale_bwd, scale_fwd 14 | 15 | 16 | @format_docstring(binary_constraint_docstring) 17 | def scale_elementwise( 18 | f: Callable[..., Tensor], 19 | output_scale: float, 20 | grad_input_scale: float, 21 | constraint: Optional[str] = "to_output_scale", 22 | ) -> Callable[..., Tensor]: 23 | """Transforms an element-wise function into a scaled version. 24 | 25 | Args: 26 | f (Callable[..., Tensor]): the element-wise function to be scaled. Should take 27 | as its first input a `Tensor`, followed by `*args, **kwargs`. 
28 | output_scale (float): the scale to be applied to the output 29 | grad_input_scale (float): the scale to be applied to the grad of the input 30 | {0} 31 | 32 | Returns: 33 | Callable[..., Tensor]: the scaled function 34 | """ 35 | output_scale, grad_input_scale = apply_constraint( 36 | constraint, output_scale, grad_input_scale 37 | ) 38 | 39 | def scaled_f(input: Tensor, *args: Any, **kwargs: Any) -> Tensor: 40 | input = scale_bwd(input, grad_input_scale) 41 | output = f(input, *args, **kwargs) 42 | return scale_fwd(output, output_scale) 43 | 44 | return scaled_f 45 | 46 | 47 | def logarithmic_interpolation(alpha: float, lower: float, upper: float) -> float: 48 | """Interpolate between lower and upper with logarithmic spacing (constant ratio). 49 | 50 | For example:: 51 | 52 | logarithmic_interpolation(alpha=0.0, lower=1/1000, upper=1/10) == 1/1000 53 | logarithmic_interpolation(alpha=0.5, lower=1/1000, upper=1/10) == 1/100 54 | logarithmic_interpolation(alpha=1.0, lower=1/1000, upper=1/10) == 1/10 55 | 56 | Args: 57 | alpha (float): interpolation weight (0=lower, 1=upper) 58 | lower (float): lower limit (alpha=0), must be > 0 59 | upper (float): upper limit (alpha=1), must be > 0 60 | 61 | Returns: 62 | float: interpolated value 63 | """ 64 | return math.exp(alpha * math.log(upper) + (1 - alpha) * math.log(lower)) 65 | 66 | 67 | def rms( 68 | x: Tensor, 69 | dims: Optional[Tuple[int, ...]] = None, 70 | keepdim: bool = False, 71 | eps: float = 0.0, 72 | ) -> Tensor: 73 | """Compute the RMS :math:`\\sqrt{\\mathrm{mean}(x^2) + \\epsilon}` of a tensor.""" 74 | mean = x.float().pow(2).mean(dims, keepdim=keepdim) 75 | if eps: 76 | mean = mean + eps 77 | return mean.sqrt().to(x.dtype) 78 | 79 | 80 | ResidualScalingFn = Callable[[int, int], float] 81 | 82 | 83 | def transformer_residual_scaling_rule( 84 | residual_mult: float = 1.0, residual_attn_ratio: float = 1.0 85 | ) -> ResidualScalingFn: 86 | """Compute the residual tau ratios for the default transformer rule. 87 | 88 | For a transformer stack that starts with embedding, then alternates 89 | between attention and MLP layers, this rule ensures: 90 | 91 | - Every attention layer contributes the same scale. 92 | - Every MLP layer contributes the same scale. 93 | - The ratio of the average (variance) contribution of all attention 94 | and all MLP layers to the embedding layer is `residual_mult`. 95 | - The ratio of Attn to MLP contributions is `residual_attn_ratio`. 96 | 97 | If both hyperparameters are set to 1.0, the total contribution of 98 | embedding, attention and MLP layers are all equal. 99 | 100 | This scheme is described in Appendix G of the u-μP paper, 101 | 102 | Args: 103 | residual_mult (float, optional): contribution of residual layers 104 | (relative to an initial/embedding layer). 105 | residual_attn_ratio (float, optional): contribution of attn 106 | layers relative to FFN layers. 107 | 108 | Returns: 109 | :code:`fn(index, layers) -> tau` : a function for calculating tau 110 | at a given depth. 
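
    For example (an illustrative sketch of the call pattern)::

        rule = transformer_residual_scaling_rule(residual_mult=1.0)
        # taus for an embedding + [attn, mlp, attn, mlp] stack of four layers
        taus = [rule(index, layers=4) for index in range(4)]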
111 | """ 112 | alpha_mlp = residual_mult * (2 / (1 + residual_attn_ratio**2)) ** 0.5 113 | alpha_attn = residual_attn_ratio * alpha_mlp 114 | 115 | def _tau(index: int, layers: int) -> float: 116 | n_attn = (index + 1) // 2 117 | n_mlp = index // 2 118 | tau = (alpha_attn if (index % 2) == 0 else alpha_mlp) / ( 119 | layers / 2 + n_attn * alpha_attn**2 + n_mlp * alpha_mlp**2 120 | ) ** 0.5 121 | return tau # type:ignore[no-any-return] 122 | 123 | return _tau 124 | 125 | 126 | __all__ = generate__all__(__name__) 127 | -------------------------------------------------------------------------------- /unit_scaling/docs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import inspect 4 | from functools import wraps 5 | from itertools import zip_longest 6 | from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar 7 | 8 | from docstring_parser.google import ( 9 | DEFAULT_SECTIONS, 10 | GoogleParser, 11 | Section, 12 | SectionType, 13 | compose, 14 | ) 15 | 16 | T = TypeVar("T") 17 | 18 | 19 | def _validate( 20 | f: Callable[..., T], unsupported_args: Iterable[str] = {} 21 | ) -> Callable[..., T]: 22 | """Wraps the supplied function in a check to ensure its arguments aren't in the 23 | unsupported args list. Unsupported args are by nature optional (they have 24 | a default value). It is assumed this default is valid, but all other values are 25 | invalid.""" 26 | 27 | argspec = inspect.getfullargspec(f) 28 | 29 | # argspec.defaults is a tuple of default arguments. These may begin at an offset 30 | # relative to rgspec.args due to args without a default. To zip these properly the 31 | # lists are reversed, zipped, and un-reversed, with missing values filled with `...` 32 | rev_args = reversed(argspec.args) 33 | rev_defaults = reversed(argspec.defaults) if argspec.defaults else [] 34 | rev_arg_default_pairs = list(zip_longest(rev_args, rev_defaults, fillvalue=...)) 35 | default_kwargs = dict(reversed(rev_arg_default_pairs)) 36 | 37 | for arg in unsupported_args: 38 | if arg not in default_kwargs: 39 | raise ValueError(f"unsupported arg '{arg}' is not valid.") 40 | if default_kwargs[arg] is ...: 41 | raise ValueError(f"unsupported arg '{arg}' has no default value") 42 | 43 | @wraps(f) 44 | def _validate_args_supported(*args: Any, **kwargs: Any) -> T: 45 | arg_values = dict(zip(argspec.args, args)) 46 | full_kwargs = {**arg_values, **kwargs} 47 | for arg_name, arg_value in full_kwargs.items(): 48 | if arg_name in unsupported_args: 49 | arg_default_value = default_kwargs[arg_name] 50 | if arg_value != arg_default_value: 51 | raise ValueError( 52 | f"Support for the '{arg_name}' argument has not been" 53 | " implemented for the unit-scaling library." 54 | " Please remove it or replace it with its default value." 
55 | ) 56 | return f(*args, **kwargs) 57 | 58 | return _validate_args_supported 59 | 60 | 61 | def _get_docstring_from_target( 62 | source: T, 63 | target: Any, 64 | short_description: Optional[str] = None, 65 | add_args: Optional[List[str]] = None, 66 | unsupported_args: Iterable[str] = {}, 67 | ) -> T: 68 | """Takes the docstring from `target`, modifies it, and applies it to `source`.""" 69 | 70 | # Make the parser aware of the Shape and Examples sections (standard in torch docs) 71 | parser_sections = DEFAULT_SECTIONS + [ 72 | Section("Shape", "shape", SectionType.SINGULAR), 73 | Section("Examples:", "examples", SectionType.SINGULAR), 74 | ] 75 | parser = GoogleParser(sections=parser_sections) 76 | docstring = parser.parse(target.__doc__) 77 | docstring.short_description = short_description 78 | if docstring.long_description: 79 | docstring.long_description += "\n" # fixes "Args:" section merging 80 | 81 | for param in docstring.params: 82 | if param.arg_name in unsupported_args and param.description is not None: 83 | param.description = ( 84 | "**[not supported by unit-scaling]** " + param.description 85 | ) 86 | 87 | if add_args: 88 | for arg_str in add_args: 89 | # Parse the additional args strings and add them to the docstring object 90 | param_meta = parser._build_meta(arg_str, "Args") 91 | docstring.meta.append(param_meta) 92 | 93 | source.__doc__ = compose(docstring) # docstring object to actual string 94 | return source 95 | 96 | 97 | def inherit_docstring( 98 | short_description: Optional[str] = None, 99 | add_args: Optional[List[str]] = None, 100 | unsupported_args: Iterable[str] = {}, 101 | ) -> Callable[[Type[T]], Type[T]]: 102 | """Returns a decorator which causes the wrapped class to inherit its parent 103 | docstring, with the specified modifications applied. 104 | 105 | Args: 106 | short_description (Optional[str], optional): Replaces the top one-line 107 | description in the parent docstring with the one supplied. Defaults to None. 108 | add_args (Optional[List[str]], optional): Appends the supplied argument strings 109 | to the list of arguments. Defaults to None. 110 | unsupported_args (Iterable[str]): A list of arguments which are not supported. 111 | Documentation is updated and runtime checks added to enforce this. 112 | 113 | Returns: 114 | Callable[[Type], Type]: The decorator used to wrap the child class. 115 | """ 116 | 117 | def decorator(cls: Type[T]) -> Type[T]: 118 | parent = cls.mro()[1] 119 | source = _get_docstring_from_target( 120 | source=cls, 121 | target=parent, 122 | short_description=short_description, 123 | add_args=add_args, 124 | unsupported_args=unsupported_args, 125 | ) 126 | source.__init__ = _validate(source.__init__, unsupported_args) # type: ignore 127 | return source 128 | 129 | return decorator 130 | 131 | 132 | def docstring_from( 133 | target: Callable[..., T], 134 | short_description: Optional[str] = None, 135 | add_args: Optional[List[str]] = None, 136 | unsupported_args: Iterable[str] = {}, 137 | ) -> Callable[[Callable[..., T]], Callable[..., T]]: 138 | """Returns a decorator which causes the wrapped object to take the docstring from 139 | the target object, with the specified modifications applied. 140 | 141 | Args: 142 | target (Any): The object to take the docstring from. 143 | short_description (Optional[str], optional): Replaces the top one-line 144 | description in the parent docstring with the one supplied. Defaults to None. 
145 | add_args (Optional[List[str]], optional): Appends the supplied argument strings 146 | to the list of arguments. Defaults to None. 147 | unsupported_args (Iterable[str]): A list of arguments which are not supported. 148 | Documentation is updated and runtime checks added to enforce this. 149 | 150 | Returns: 151 | Callable[[Callable], Callable]: The decorator used to wrap the child object. 152 | """ 153 | 154 | def decorator(source: Callable[..., T]) -> Callable[..., T]: 155 | source = _get_docstring_from_target( 156 | source=source, 157 | target=target, 158 | short_description=short_description, 159 | add_args=add_args, 160 | unsupported_args=unsupported_args, 161 | ) 162 | return _validate(source, unsupported_args) 163 | 164 | return decorator 165 | 166 | 167 | def format_docstring(*args: str) -> Callable[[Callable[..., T]], Callable[..., T]]: 168 | """Returns a decorator that applies `cls.__doc__.format(*args)` to the target class. 169 | 170 | Args: 171 | args: (*str): The arguments to be passed to the docstrings `.format()` method. 172 | 173 | Returns: 174 | Callable[[Type], Type]: A decorator to format the docstring. 175 | """ 176 | 177 | def f(cls: T) -> T: 178 | if isinstance(cls.__doc__, str): 179 | cls.__doc__ = cls.__doc__.format(*args) 180 | return cls 181 | 182 | return f 183 | 184 | 185 | binary_constraint_docstring = ( 186 | "constraint (Optional[str], optional): The name of the constraint function to be" 187 | " applied to the outputs & input gradient. In this case, the constraint name must" 188 | " be one of:" 189 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale', 'to_grad_input_scale']" 190 | " (see `unit_scaling.constraints` for details on these constraint functions)." 191 | " Defaults to `gmean`." 192 | ) 193 | 194 | ternary_constraint_docstring = ( 195 | "constraint (Optional[str], optional): The name of the constraint function to be" 196 | " applied to the outputs & input gradients. In this case, the constraint name must" 197 | " be one of:" 198 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale', 'to_left_grad_scale'," 199 | " to_right_grad_scale]" 200 | " (see `unit_scaling.constraints` for details on these constraint functions)." 201 | " Defaults to `gmean`." 202 | ) 203 | 204 | variadic_constraint_docstring = ( 205 | "constraint (Optional[str], optional): The name of the constraint function to be" 206 | " applied to the outputs & input gradients. In this case, the constraint name must" 207 | " be one of:" 208 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale']" 209 | " (see `unit_scaling.constraints` for details on these constraint functions)." 210 | " Defaults to `gmean`." 211 | ) 212 | 213 | 214 | def mult_docstring(name: str = "mult") -> str: 215 | return ( 216 | f"{name} (float, optional): a multiplier to be applied to change the shape" 217 | " of a nonlinear function. Typically, high multipliers (> 1) correspond to a" 218 | " 'sharper' (low temperature) function, while low multipliers (< 1) correspond" 219 | " to a 'flatter' (high temperature) function." 220 | ) 221 | -------------------------------------------------------------------------------- /unit_scaling/formats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
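# Example (illustrative, not part of the original module): an FP8 E4M3-like format
# can be modelled as FPFormat(exponent_bits=4, mantissa_bits=3). Calling
# `.quantise(x)` rounds a tensor onto that format's grid (stochastically by
# default), while `.quantise_fwd(x)` / `.quantise_bwd(x)` simulate the quantisation
# in only the forward or backward pass respectively.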
2 | 3 | """Classes for simulating (non-standard) number formats.""" 4 | 5 | from dataclasses import dataclass 6 | from typing import Tuple, cast 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from ._internal_utils import generate__all__ 12 | 13 | Shape = Tuple[int, ...] 14 | 15 | 16 | @dataclass 17 | class FPFormat: 18 | """Generic representation of a floating-point number format.""" 19 | 20 | exponent_bits: int 21 | mantissa_bits: int 22 | rounding: str = "stochastic" # "stochastic|nearest" 23 | srbits: int = 0 # Number of bits for stochastic rounding, zero => use all bits 24 | 25 | def __post_init__(self) -> None: 26 | assert self.exponent_bits >= 2, "FPFormat requires at least 2 exponent bits" 27 | assert ( 28 | self.srbits == 0 or self.rounding == "stochastic" 29 | ), "Nonzero srbits for non-stochastic rounding" 30 | if self.srbits == 0 and self.rounding == "stochastic": 31 | self.srbits = 23 - self.mantissa_bits 32 | 33 | @property 34 | def bits(self) -> int: 35 | """The number of bits used by the format.""" 36 | return 1 + self.exponent_bits + self.mantissa_bits 37 | 38 | def __str__(self) -> str: # pragma: no cover 39 | return ( 40 | f"E{self.exponent_bits}M{self.mantissa_bits}-" 41 | + dict(stochastic="SR", nearest="RN")[self.rounding] 42 | ) 43 | 44 | @property 45 | def max_absolute_value(self) -> float: 46 | """The maximum absolute value representable by the format.""" 47 | max_exponent = 2 ** (self.exponent_bits - 1) - 1 48 | return cast(float, 2**max_exponent * (2 - 2**-self.mantissa_bits)) 49 | 50 | @property 51 | def min_absolute_normal(self) -> float: 52 | """The minimum absolute normal value representable by the format.""" 53 | min_exponent = 1 - 2 ** (self.exponent_bits - 1) 54 | return cast(float, 2**min_exponent) 55 | 56 | @property 57 | def min_absolute_subnormal(self) -> float: 58 | """The minimum absolute subnormal value representable by the format.""" 59 | return self.min_absolute_normal * 2.0**-self.mantissa_bits 60 | 61 | def quantise(self, x: Tensor) -> Tensor: 62 | """Non-differentiably quantise the given tensor in this format.""" 63 | absmax = self.max_absolute_value 64 | downscale = 2.0 ** (127 - 2 ** (self.exponent_bits - 1)) 65 | mask = torch.tensor(2 ** (23 - self.mantissa_bits) - 1, device=x.device) 66 | if self.rounding == "stochastic": 67 | srbitsbar = 23 - self.mantissa_bits - self.srbits 68 | offset = ( 69 | torch.randint( 70 | 0, 2**self.srbits, x.shape, dtype=torch.int32, device=x.device 71 | ) 72 | << srbitsbar 73 | ) 74 | # Correct for bias. We can do this only for srbits < 23-mantissa_bits, 75 | # but it is only likely to matter when srbits is small. 
76 | if srbitsbar > 0: 77 | offset += 1 << (srbitsbar - 1) 78 | 79 | elif self.rounding == "nearest": 80 | offset = mask // 2 81 | else: # pragma: no cover 82 | raise ValueError( 83 | f'Unexpected FPFormat(rounding="{self.rounding}"),' 84 | ' expected "stochastic" or "nearest"' 85 | ) 86 | q = x.to(torch.float32) 87 | q = torch.clip(x, -absmax, absmax) 88 | q /= downscale 89 | q = ((q.view(torch.int32) + offset) & ~mask).view(torch.float32) 90 | q *= downscale 91 | return q.to(x.dtype) 92 | 93 | def quantise_fwd(self, x: Tensor) -> Tensor: 94 | """Quantise the given tensor in the forward pass only.""" 95 | 96 | class QuantiseForward(torch.autograd.Function): 97 | @staticmethod 98 | def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor: 99 | return self.quantise(x) 100 | 101 | @staticmethod 102 | def backward( # type:ignore[override] 103 | ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor 104 | ) -> Tensor: 105 | return grad_y 106 | 107 | return QuantiseForward.apply(x) # type: ignore 108 | 109 | def quantise_bwd(self, x: Tensor) -> Tensor: 110 | """Quantise the given tensor in the backward pass only.""" 111 | 112 | class QuantiseBackward(torch.autograd.Function): 113 | @staticmethod 114 | def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor: 115 | return x 116 | 117 | @staticmethod 118 | def backward( # type:ignore[override] 119 | ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor 120 | ) -> Tensor: 121 | return self.quantise(grad_y) 122 | 123 | return QuantiseBackward.apply(x) # type: ignore 124 | 125 | 126 | def format_to_tuple(format: FPFormat) -> Tuple[int, int]: 127 | """Convert the format into a tuple of `(exponent_bits, mantissa_bits)`""" 128 | return (format.exponent_bits, format.mantissa_bits) 129 | 130 | 131 | def tuple_to_format(t: Tuple[int, int]) -> FPFormat: 132 | """Given a tuple of `(exponent_bits, mantissa_bits)` returns the corresponding 133 | :class:`FPFormat`""" 134 | return FPFormat(*t) 135 | 136 | 137 | __all__ = generate__all__(__name__) 138 | -------------------------------------------------------------------------------- /unit_scaling/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | """Optimizer wrappers that apply scaling rules for u-muP. 4 | 5 | Provides :class:`Adam`, :class:`AdamW`, :class:`SGD` as out-of-the-box 6 | optimizers. 7 | 8 | Alternatively, :func:`scaled_parameters` provides finer control by 9 | transforming a parameter group for any downstream optimizer, given a 10 | function that defines the LR scaling rules. 
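
For example (a minimal sketch; ``model`` stands for any u-muP model built with
this library, and ``uu`` for ``import unit_scaling as uu``)::

    opt = uu.optim.AdamW(model.parameters(), lr=1.0)

    # or, for finer control with a stock optimizer:
    opt = torch.optim.AdamW(
        uu.optim.scaled_parameters(
            model.parameters(), uu.optim.lr_scale_func_adam, lr=1.0
        )
    )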
11 | """ 12 | 13 | # mypy: disable-error-code="no-any-return" 14 | 15 | from typing import Any, Callable, Optional, Union 16 | 17 | import torch 18 | from torch import Tensor 19 | from torch.optim.optimizer import ParamsT 20 | 21 | from ._internal_utils import generate__all__ 22 | from .docs import inherit_docstring 23 | from .parameter import ParameterData, has_parameter_data 24 | 25 | 26 | def lr_scale_for_depth(param: ParameterData) -> float: 27 | """Calculate the LR scaling factor for depth only.""" 28 | if param.mup_scaling_depth is None: 29 | return 1 30 | return param.mup_scaling_depth**-0.5 31 | 32 | 33 | def _get_fan_in(param: ParameterData) -> int: 34 | # Note: the "fan_in" of an embedding layer is the hidden (output) dimension 35 | if len(param.shape) == 1: 36 | return param.shape[0] 37 | if len(param.shape) == 2: 38 | return param.shape[1] 39 | if len(param.shape) == 3: 40 | return param.shape[1] * param.shape[2] 41 | raise ValueError( 42 | f"Cannot get fan_in of `ndim >= 4` param, shape={tuple(param.shape)}" 43 | ) 44 | 45 | 46 | def lr_scale_func_sgd( 47 | readout_constraint: Optional[str], 48 | ) -> Callable[[ParameterData], float]: 49 | """Calculate the LR scaling factor for :class:`torch.optim.SGD`.""" 50 | 51 | if readout_constraint is None: 52 | # If there is no readout constraint we will have unit-scaled gradients and hence 53 | # unit-scaled weight updates. In this case the scaling rules are the same as 54 | # for Adam, which naturally has unit-scaled weight updates. 55 | return lr_scale_func_adam 56 | elif readout_constraint == "to_output_scale": 57 | 58 | def lr_scale_func_sgd_inner(param: ParameterData) -> float: 59 | scale = lr_scale_for_depth(param) 60 | 61 | if param.mup_type in ("bias", "norm"): 62 | return scale * param.shape[0] 63 | if param.mup_type == "weight": 64 | return scale * _get_fan_in(param) ** 0.5 65 | if param.mup_type == "output": 66 | return scale 67 | assert False, f"Unexpected mup_type {param.mup_type}" 68 | 69 | return lr_scale_func_sgd_inner 70 | else: 71 | assert False, f"Unhandled readout constraint: {readout_constraint}" 72 | 73 | 74 | def lr_scale_func_adam(param: ParameterData) -> float: 75 | """Calculate the LR scaling factor for :class:`torch.optim.Adam` 76 | and :class:`torch.optim.AdamW`. 77 | """ 78 | scale = lr_scale_for_depth(param) 79 | if param.mup_type in ("bias", "norm"): 80 | return scale 81 | if param.mup_type == "weight": 82 | return scale * _get_fan_in(param) ** -0.5 83 | if param.mup_type == "output": 84 | return scale 85 | assert False, f"Unexpected mup_type {param.mup_type}" 86 | 87 | 88 | def scaled_parameters( 89 | params: ParamsT, 90 | lr_scale_func: Callable[[ParameterData], float], 91 | lr: Union[None, float, Tensor] = None, 92 | weight_decay: float = 0, 93 | independent_weight_decay: bool = True, 94 | allow_non_unit_scaling_params: bool = False, 95 | ) -> ParamsT: 96 | """Create optimizer-appropriate **lr-scaled** parameter groups. 97 | 98 | This method creates param_groups that apply the relevant scaling factors for u-muP 99 | models. For example:: 100 | 101 | torch.optim.Adam(uu.optim.scaled_parameters( 102 | model.parameters(), uu.optim.adam_lr_scale_func, lr=1.0 103 | )) 104 | 105 | Args: 106 | params (ParamsT): an iterable of parameters of parameter groups, as passed to 107 | a torch optimizer. 108 | lr_scale_func (Callable): gets the optimizer-appropriate learning rate scale, 109 | based on a parameter tagged with `mup_type` and `mup_scaling_depth`. For 110 | example, :func:`lr_scale_func_sgd`. 
111 | lr (float, optional): global learning rate (overridden by groups). 112 | weight_decay (float, optional): weight decay value (overridden by groups). 113 | independent_weight_decay (bool, optional): enable lr-independent weight decay, 114 | which performs an update per-step that does not depend on lr. 115 | allow_non_unit_scaling_params (bool, optional): by default, this method fails 116 | if passed any regular non-unit-scaled params; set to `True` to disable this 117 | check. 118 | 119 | Returns: 120 | ParamsT: for passing on to the optimizer. 121 | """ 122 | 123 | result = [] 124 | for entry in params: 125 | group = dict(params=[entry]) if isinstance(entry, Tensor) else entry.copy() 126 | group.setdefault("lr", lr) # type: ignore[arg-type] 127 | group.setdefault("weight_decay", weight_decay) # type: ignore[arg-type] 128 | if group["lr"] is None: 129 | raise ValueError( 130 | "scaled_params() requires lr to be provided," 131 | " unless passing parameter groups which already have an lr" 132 | ) 133 | for param in group["params"]: 134 | # Careful not to overwrite `lr` or `weight_decay` 135 | param_lr = group["lr"] 136 | if has_parameter_data(param): # type: ignore[arg-type] 137 | if isinstance(param_lr, Tensor): 138 | param_lr = param_lr.clone() 139 | param_lr *= lr_scale_func(param) # type: ignore[operator] 140 | elif not allow_non_unit_scaling_params: 141 | raise ValueError( 142 | "Non-unit-scaling parameter (no mup_type)," 143 | f" shape {tuple(param.shape)}" 144 | ) 145 | param_weight_decay = group["weight_decay"] 146 | if independent_weight_decay: 147 | # Note: only independent of peak LR, not of schedule 148 | param_weight_decay /= float(param_lr) # type: ignore 149 | 150 | result.append( 151 | dict( 152 | params=[param], 153 | lr=param_lr, 154 | weight_decay=param_weight_decay, 155 | **{ 156 | k: v 157 | for k, v in group.items() 158 | if k not in ("params", "lr", "weight_decay") 159 | }, 160 | ) 161 | ) 162 | return result 163 | 164 | 165 | @inherit_docstring( 166 | short_description="An **lr-scaled** version of :class:`torch.optim.SGD` for u-muP." 167 | "`readout_constraint` should match the `constraint` arg used in `LinearReadout`." 168 | ) 169 | class SGD(torch.optim.SGD): 170 | 171 | def __init__( 172 | self, 173 | params: ParamsT, 174 | lr: Union[float, Tensor] = 1e-3, 175 | *args: Any, 176 | weight_decay: float = 0, 177 | independent_weight_decay: bool = True, 178 | allow_non_unit_scaling_params: bool = False, 179 | readout_constraint: Optional[str] = None, 180 | **kwargs: Any, 181 | ) -> None: 182 | params = scaled_parameters( 183 | params, 184 | lr_scale_func_sgd(readout_constraint), 185 | lr=lr, 186 | weight_decay=weight_decay, 187 | independent_weight_decay=independent_weight_decay, 188 | allow_non_unit_scaling_params=allow_non_unit_scaling_params, 189 | ) 190 | # No need to forward {lr, weight_decay}, as each group has these specified 191 | super().__init__(params, *args, **kwargs) 192 | 193 | 194 | @inherit_docstring( 195 | short_description="An **lr-scaled** version of :class:`torch.optim.Adam` for u-muP." 
196 | ) 197 | class Adam(torch.optim.Adam): 198 | def __init__( 199 | self, 200 | params: ParamsT, 201 | lr: Union[float, Tensor] = 1e-3, 202 | *args: Any, 203 | weight_decay: float = 0, 204 | independent_weight_decay: bool = True, 205 | allow_non_unit_scaling_params: bool = False, 206 | **kwargs: Any, 207 | ) -> None: 208 | params = scaled_parameters( 209 | params, 210 | lr_scale_func_adam, 211 | lr=lr, 212 | weight_decay=weight_decay, 213 | independent_weight_decay=independent_weight_decay, 214 | allow_non_unit_scaling_params=allow_non_unit_scaling_params, 215 | ) 216 | # No need to forward {lr, weight_decay}, as each group has these specified 217 | super().__init__(params, *args, **kwargs) 218 | 219 | 220 | @inherit_docstring( 221 | short_description=( 222 | "An **lr-scaled** version of :class:`torch.optim.AdamW` for u-muP." 223 | ) 224 | ) 225 | class AdamW(torch.optim.AdamW): 226 | def __init__( 227 | self, 228 | params: ParamsT, 229 | lr: Union[float, Tensor] = 1e-3, 230 | *args: Any, 231 | weight_decay: float = 0, 232 | independent_weight_decay: bool = True, 233 | allow_non_unit_scaling_params: bool = False, 234 | **kwargs: Any, 235 | ) -> None: 236 | params = scaled_parameters( 237 | params, 238 | lr_scale_func_adam, 239 | lr=lr, 240 | weight_decay=weight_decay, 241 | independent_weight_decay=independent_weight_decay, 242 | allow_non_unit_scaling_params=allow_non_unit_scaling_params, 243 | ) 244 | # No need to forward {lr, weight_decay}, as each group has these specified 245 | super().__init__(params, *args, **kwargs) 246 | 247 | 248 | __all__ = generate__all__(__name__) 249 | -------------------------------------------------------------------------------- /unit_scaling/parameter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | """Extends :class:`torch.nn.Parameter` with attributes for u-μP.""" 4 | 5 | # mypy: disable-error-code="attr-defined, method-assign, no-untyped-call" 6 | 7 | from collections import OrderedDict 8 | from typing import Any, Dict, Literal, Optional, Protocol, TypeGuard 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | MupType = Literal["weight", "bias", "norm", "output"] 14 | 15 | 16 | class ParameterData(Protocol): 17 | """Extra fields for :class:`torch.nn.Parameter`, tagging u-μP metadata. 18 | 19 | Objects supporting this protocol should implicitly also support 20 | :class:`torch.nn.Parameter`. 
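
    Use :func:`has_parameter_data` (below) to check at runtime whether a given
    :class:`torch.nn.Parameter` actually carries these fields.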
21 | """ 22 | 23 | mup_type: MupType 24 | mup_scaling_depth: Optional[int] 25 | shape: torch.Size # repeated from nn.Parameter, for convenience 26 | 27 | 28 | def has_parameter_data(parameter: nn.Parameter) -> TypeGuard[ParameterData]: 29 | """Check that the parameter supports the :class:`ParameterData` protocol.""" 30 | return ( 31 | getattr(parameter, "mup_type", None) in MupType.__args__ 32 | and hasattr(parameter, "mup_scaling_depth") 33 | and isinstance(parameter.mup_scaling_depth, (type(None), int)) 34 | ) 35 | 36 | 37 | def _parameter_deepcopy(self: nn.Parameter, memo: Dict[int, Any]) -> nn.Parameter: 38 | result: nn.Parameter = nn.Parameter.__deepcopy__(self, memo) 39 | result.mup_type = self.mup_type 40 | result.mup_scaling_depth = self.mup_scaling_depth 41 | return result 42 | 43 | 44 | def _rebuild_parameter_with_state(*args: Any, **kwargs: Any) -> nn.Parameter: 45 | p: nn.Parameter = torch._utils._rebuild_parameter_with_state(*args, **kwargs) 46 | p.__deepcopy__ = _parameter_deepcopy.__get__(p) 47 | p.__reduce_ex__ = _parameter_reduce_ex.__get__(p) 48 | return p 49 | 50 | 51 | def _parameter_reduce_ex(self: nn.Parameter, protocol: int) -> Any: 52 | # Based on `torch.nn.Parameter.__reduce_ex__`, filtering out the 53 | # dynamic methods __deepcopy__ and __reduce_ex__, as these 54 | # don't unpickle 55 | state = { 56 | k: v 57 | for k, v in torch._utils._get_obj_state(self).items() 58 | if k not in ["__deepcopy__", "__reduce_ex__"] 59 | } 60 | return ( 61 | _rebuild_parameter_with_state, 62 | (self.data, self.requires_grad, OrderedDict(), state), 63 | ) 64 | 65 | 66 | def Parameter( 67 | data: Tensor, mup_type: MupType, mup_scaling_depth: Optional[int] = None 68 | ) -> nn.Parameter: 69 | """Construct a u-μP parameter object, an annotated :class:`torch.nn.Parameter`. 70 | 71 | The returned parameter also supports the :class:`ParameterData` protocol: 72 | 73 | p = uu.Parameter(torch.zeros(10), mup_type="weight") 74 | assert p.mup_type == "weight" 75 | assert p.mup_scaling_depth is None 76 | """ 77 | p = nn.Parameter(data) 78 | p.mup_type = mup_type 79 | p.mup_scaling_depth = mup_scaling_depth 80 | p.__deepcopy__ = _parameter_deepcopy.__get__(p) 81 | p.__reduce_ex__ = _parameter_reduce_ex.__get__(p) 82 | # Note: cannot override __repr__ as it's __class__.__repr__ 83 | return p 84 | -------------------------------------------------------------------------------- /unit_scaling/scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
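# Example (illustrative, not part of the original module): composing the two public
# helpers below gives a tensor with independent forward and backward scales, e.g.
#
#     y = scale_bwd(scale_fwd(x, 2.0), 0.5)
#
# multiplies activations by 2.0 on the way forward and scales the gradient reaching
# `x` by 0.5 on the way back.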
2 | 3 | """Operations to enable different scaling factors in the forward and backward passes.""" 4 | 5 | from __future__ import annotations # required for docs to alias type annotations 6 | 7 | from typing import Tuple 8 | 9 | import torch 10 | from torch import Tensor, fx 11 | 12 | from ._internal_utils import generate__all__ 13 | 14 | 15 | class _ScaledGrad(torch.autograd.Function): # pragma: no cover 16 | """Enables a custom backward method which has a different scale to forward.""" 17 | 18 | @staticmethod 19 | def forward( 20 | ctx: torch.autograd.function.FunctionCtx, 21 | X: Tensor, 22 | fwd_scale: float, 23 | bwd_scale: float, 24 | ) -> Tensor: 25 | # Special cases required for torch.fx tracing 26 | if isinstance(bwd_scale, fx.proxy.Proxy): 27 | ctx.save_for_backward(bwd_scale) # type: ignore 28 | elif isinstance(X, fx.proxy.Proxy): 29 | ctx.save_for_backward(torch.tensor(bwd_scale)) 30 | else: 31 | ctx.save_for_backward(torch.tensor(bwd_scale, dtype=X.dtype)) 32 | return fwd_scale * X 33 | 34 | @staticmethod 35 | def backward( # type:ignore[override] 36 | ctx: torch.autograd.function.FunctionCtx, grad_Y: Tensor 37 | ) -> Tuple[Tensor, None, None]: 38 | (bwd_scale,) = ctx.saved_tensors # type: ignore 39 | return bwd_scale * grad_Y, None, None 40 | 41 | 42 | def _scale( 43 | t: Tensor, fwd_scale: float = 1.0, bwd_scale: float = 1.0 44 | ) -> Tensor: # pragma: no cover 45 | """Given a tensor, applies a separate scale in the forward and backward pass.""" 46 | return _ScaledGrad.apply(t, fwd_scale, bwd_scale) # type: ignore 47 | 48 | 49 | def scale_fwd(input: Tensor, scale: float) -> Tensor: 50 | """Applies a scalar multiplication to a tensor in only the forward pass. 51 | 52 | Args: 53 | input (Tensor): the tensor to be scaled. 54 | scale (float): the scale factor applied to the tensor in the forward pass. 55 | 56 | Returns: 57 | Tensor: scaled in the forward pass, but with its original grad. 58 | """ 59 | return _scale(input, fwd_scale=scale) 60 | 61 | 62 | def scale_bwd(input: Tensor, scale: float) -> Tensor: 63 | """Applies a scalar multiplication to a tensor in only the backward pass. 64 | 65 | Args: 66 | input (Tensor): the tensor to be scaled. 67 | scale (float): the scale factor applied to the tensor in the backward pass. 68 | 69 | Returns: 70 | Tensor: unchanged in the forward pass, but with a scaled grad. 71 | """ 72 | return _scale(input, bwd_scale=scale) 73 | 74 | 75 | __all__ = generate__all__(__name__) 76 | -------------------------------------------------------------------------------- /unit_scaling/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | -------------------------------------------------------------------------------- /unit_scaling/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import random 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | 9 | 10 | @pytest.fixture(scope="function", autouse=True) 11 | def fix_seed() -> None: 12 | """For each test function, reset all random seeds.""" 13 | random.seed(1472) 14 | np.random.seed(1472) 15 | torch.manual_seed(1472) 16 | -------------------------------------------------------------------------------- /unit_scaling/tests/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. 
All rights reserved. 2 | -------------------------------------------------------------------------------- /unit_scaling/tests/core/test_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | from torch import ones, randn, tensor 5 | from torch.testing import assert_close 6 | 7 | from ...core.functional import rms, scale_elementwise, transformer_residual_scaling_rule 8 | from ..helper import unit_backward 9 | 10 | 11 | def test_scale_elementwise_no_constraint() -> None: 12 | input = randn(2**10, requires_grad=True) 13 | f = lambda x: x 14 | scaled_f = scale_elementwise( 15 | f, output_scale=2.5, grad_input_scale=0.5, constraint=None 16 | ) 17 | output = scaled_f(input) 18 | unit_backward(output) 19 | 20 | assert output.std().detach() == pytest.approx(2.5, rel=0.1) 21 | assert input.grad.std().detach() == pytest.approx(0.5, rel=0.1) # type: ignore 22 | 23 | 24 | def test_scale_elementwise_for_output() -> None: 25 | input = randn(2**10, requires_grad=True) 26 | f = lambda x: x 27 | scaled_f = scale_elementwise( 28 | f, output_scale=2.5, grad_input_scale=0.5, constraint="to_output_scale" 29 | ) 30 | output = scaled_f(input) 31 | unit_backward(output) 32 | 33 | assert output.std().detach() == pytest.approx(2.5, rel=0.1) 34 | assert input.grad.std().detach() == pytest.approx(2.5, rel=0.1) # type: ignore 35 | 36 | 37 | def test_scale_elementwise_for_grad_input() -> None: 38 | input = randn(2**10, requires_grad=True) 39 | f = lambda x: x 40 | scaled_f = scale_elementwise( 41 | f, output_scale=2.5, grad_input_scale=0.5, constraint="to_grad_input_scale" 42 | ) 43 | output = scaled_f(input) 44 | unit_backward(output) 45 | 46 | assert output.std().detach() == pytest.approx(0.5, rel=0.1) 47 | assert input.grad.std().detach() == pytest.approx(0.5, rel=0.1) # type: ignore 48 | 49 | 50 | def test_rms() -> None: 51 | output = rms(-4 + 3 * randn(2**12)) 52 | assert output.item() == pytest.approx(5, rel=0.1) 53 | 54 | x = tensor([[2, -2, -2, -2], [0, 2, 0, 0]]).float() 55 | output = rms(x, dims=(1,)) 56 | assert_close(output, tensor([2.0, 1.0])) 57 | 58 | output = rms(tensor([0.0, 0.0, 0.0]), eps=1 / 16) 59 | assert output.item() == pytest.approx(1 / 4, rel=0.1) 60 | 61 | 62 | @pytest.mark.parametrize( 63 | ["residual_mult", "residual_attn_ratio", "layers"], 64 | [ 65 | (1.0, 1.0, 4), 66 | (0.5, 1.0, 4), 67 | (1.0, 2.0, 4), 68 | (0.5, 1 / 3, 6), 69 | ], 70 | ) 71 | def test_transformer_residual_scaling_rule( 72 | residual_mult: float, residual_attn_ratio: float, layers: int 73 | ) -> None: 74 | scaling_rule = transformer_residual_scaling_rule( 75 | residual_mult=residual_mult, 76 | residual_attn_ratio=residual_attn_ratio, 77 | ) 78 | scales = ones(layers + 1) 79 | for n in range(layers): 80 | tau = scaling_rule(n, layers) 81 | scales[n + 1] *= tau 82 | scales[: n + 2] /= (1 + tau**2) ** 0.5 83 | embedding_scale = scales[0] 84 | attn_scales = scales[1:][::2] 85 | mlp_scales = scales[2:][::2] 86 | 87 | # Basic properties 88 | assert_close(scales.pow(2).sum(), tensor(1.0)) 89 | 90 | s_embedding = embedding_scale 91 | s_attn = attn_scales.pow(2).sum().sqrt() 92 | s_mlp = mlp_scales.pow(2).sum().sqrt() 93 | s_attn_mlp_average = ((s_attn**2 + s_mlp**2) / 2).sqrt() 94 | 95 | assert_close(s_attn_mlp_average / s_embedding, tensor(residual_mult)) 96 | assert_close(s_attn / s_mlp, tensor(residual_attn_ratio)) 97 | 98 | # Per-layer scales are equal 99 | assert_close(attn_scales, 
attn_scales[:1].broadcast_to(attn_scales.shape)) 100 | assert_close(mlp_scales, mlp_scales[:1].broadcast_to(mlp_scales.shape)) 101 | -------------------------------------------------------------------------------- /unit_scaling/tests/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from typing import Optional 4 | 5 | import pytest 6 | import torch 7 | from torch import Tensor, randn 8 | 9 | from ..core.functional import rms 10 | 11 | 12 | def unit_backward(tensor: Tensor) -> Tensor: 13 | """Applies the `backward()` method with a unit normal tensor as input. 14 | 15 | Args: 16 | tensor (Tensor): tensor to have `backward()` applied. 17 | 18 | Returns: 19 | Tensor: the unit normal gradient tensor fed into `backward()`. 20 | """ 21 | gradient = randn(*tensor.shape) 22 | tensor.backward(gradient) # type: ignore 23 | return gradient 24 | 25 | 26 | def assert_scale( 27 | *tensors: Optional[Tensor], target: float, abs: float = 0.1, stat: str = "std" 28 | ) -> None: 29 | for t in tensors: 30 | assert t is not None 31 | t = t.detach() 32 | approx_target = pytest.approx(target, abs=abs) 33 | stat_value = dict(rms=rms, std=torch.std)[stat](t) # type:ignore[operator] 34 | assert ( 35 | stat_value == approx_target 36 | ), f"{stat}={stat_value:.3f}, shape={list(t.shape)}" 37 | 38 | 39 | def assert_not_scale( 40 | *tensors: Optional[Tensor], target: float, abs: float = 0.1, stat: str = "std" 41 | ) -> None: 42 | for t in tensors: 43 | assert t is not None 44 | t = t.detach() 45 | approx_target = pytest.approx(target, abs=abs) 46 | stat_value = dict(rms=t.pow(2).mean().sqrt(), std=t.std())[stat] 47 | assert ( 48 | stat_value != approx_target 49 | ), f"{stat}={stat_value:.3f}, shape={list(t.shape)}" 50 | 51 | 52 | def assert_unit_scaled( 53 | *tensors: Optional[Tensor], abs: float = 0.1, stat: str = "std" 54 | ) -> None: 55 | return assert_scale(*tensors, target=1.0, abs=abs, stat=stat) 56 | 57 | 58 | def assert_not_unit_scaled( 59 | *tensors: Optional[Tensor], abs: float = 0.1, stat: str = "std" 60 | ) -> None: 61 | return assert_not_scale(*tensors, target=1.0, abs=abs, stat=stat) 62 | 63 | 64 | def assert_zeros(*tensors: Optional[Tensor]) -> None: 65 | for t in tensors: 66 | assert t is not None 67 | t = t.detach() 68 | assert torch.all(t == 0) 69 | 70 | 71 | def assert_non_zeros(*tensors: Optional[Tensor]) -> None: 72 | for t in tensors: 73 | assert t is not None 74 | t = t.detach() 75 | assert torch.any(t != 0) 76 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from typing import Tuple 4 | 5 | import torch.nn.functional as F 6 | from torch import Size, Tensor, nn, randn 7 | from transformers import AutoTokenizer # type: ignore[import-untyped] 8 | 9 | from ..analysis import _create_batch, _example_seqs, example_batch, plot, visualiser 10 | from ..transforms import track_scales 11 | 12 | 13 | def test_example_seqs() -> None: 14 | batch_size, min_seq_len = 3, 1024 15 | seqs = _example_seqs(batch_size, min_seq_len) 16 | assert len(seqs) == batch_size, len(seqs) 17 | for s in seqs: 18 | assert isinstance(s, str) 19 | assert not s.isspace() 20 | assert len(s) >= min_seq_len 21 | 22 | 23 | def test_create_batch() -> None: 24 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped") 25 | batch_size, seq_len = 3, 256 26 | seqs = _example_seqs(batch_size, min_seq_len=seq_len * 4) 27 | input_idxs, attn_mask, labels = _create_batch(tokenizer, seqs, seq_len) 28 | 29 | assert isinstance(input_idxs, Tensor) 30 | assert isinstance(attn_mask, Tensor) 31 | assert isinstance(labels, Tensor) 32 | assert input_idxs.shape == Size([batch_size, seq_len]) 33 | assert attn_mask.shape == Size([batch_size, seq_len]) 34 | assert labels.shape == Size([batch_size, seq_len]) 35 | 36 | 37 | def test_example_batch() -> None: 38 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped") 39 | batch_size, seq_len = 3, 256 40 | input_idxs, attn_mask, labels = example_batch(tokenizer, batch_size, seq_len) 41 | 42 | assert isinstance(input_idxs, Tensor) 43 | assert isinstance(attn_mask, Tensor) 44 | assert isinstance(labels, Tensor) 45 | assert input_idxs.shape == Size([batch_size, seq_len]) 46 | assert attn_mask.shape == Size([batch_size, seq_len]) 47 | assert labels.shape == Size([batch_size, seq_len]) 48 | 49 | 50 | def test_plot() -> None: 51 | class Model(nn.Module): 52 | def __init__(self, dim: int) -> None: 53 | super().__init__() 54 | self.dim = dim 55 | self.linear = nn.Linear(dim, dim // 2) 56 | 57 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover 58 | y = F.relu(x) 59 | z = self.linear(y) 60 | return z.sum() # type: ignore[no-any-return] 61 | 62 | b, dim = 2**4, 2**8 63 | input = randn(b, dim) 64 | model = Model(dim) 65 | model = track_scales(model) 66 | loss = model(input) 67 | loss.backward() 68 | 69 | graph = model.scales_graph() 70 | axes = plot(graph, "demo", xmin=2**-20, xmax=2**10) 71 | assert axes 72 | 73 | 74 | def test_visualiser() -> None: 75 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped") 76 | 77 | class Model(nn.Module): 78 | def __init__(self, n_embed: int, dim: int) -> None: 79 | super().__init__() 80 | self.embedding = nn.Embedding(n_embed, dim) 81 | self.linear = nn.Linear(dim, n_embed) 82 | 83 | def forward( 84 | self, inputs: Tensor, labels: Tensor 85 | ) -> Tuple[Tensor, Tensor]: # pragma: no cover 86 | x = self.embedding(inputs) 87 | x = self.linear(x) 88 | loss = F.cross_entropy(x.view(-1, x.size(-1)), labels.view(-1)) 89 | return x, loss 90 | 91 | axes = visualiser( 92 | model=Model(n_embed=tokenizer.vocab_size, dim=128), 93 | tokenizer=tokenizer, 94 | batch_size=16, 95 | seq_len=256, 96 | ) 97 | assert axes 98 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_constraints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | import pytest 4 | 5 | from ..constraints import ( 6 | amean, 7 | apply_constraint, 8 | gmean, 9 | hmean, 10 | to_grad_input_scale, 11 | to_output_scale, 12 | ) 13 | 14 | 15 | def test_gmean() -> None: 16 | assert gmean(1.0) == 1.0 17 | assert gmean(8.0, 2.0) == 4.0 18 | assert gmean(12.75, 55.0, 0.001) == pytest.approx(0.8884322) 19 | 20 | 21 | def test_hmean() -> None: 22 | assert hmean(1.0) == 1.0 23 | assert hmean(8.0, 2.0) == 3.2 24 | assert hmean(12.75, 55.0, 0.001) == pytest.approx(0.00299971) 25 | 26 | 27 | def test_amean() -> None: 28 | assert amean(1.0) == 1.0 29 | assert amean(8.0, 2.0) == 5.0 30 | assert amean(12.75, 55.0, 0.001) == pytest.approx(22.583667) 31 | 32 | 33 | def test_to_output_scale() -> None: 34 | assert to_output_scale(2, 3) == 2 35 | assert to_output_scale(2, 3, 4) == 2 36 | 37 | 38 | def test_to_grad_input_scale() -> None: 39 | assert to_grad_input_scale(2, 3) == 3 40 | 41 | 42 | def test_apply_constraint() -> None: 43 | assert apply_constraint("gmean", 1.0, 4.0, 2.0) == (2.0, 2.0, 2.0) 44 | assert apply_constraint("hmean", 8.0, 2.0) == (3.2, 3.2) 45 | with pytest.raises(ValueError): 46 | apply_constraint("invalid", 8.0, 2.0) 47 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_docs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import pytest 4 | 5 | from ..docs import _validate 6 | 7 | 8 | def f(a, b: int, c="3", d: float = 4.0) -> str: # type: ignore 9 | return f"{a} {b} {c} {d}" 10 | 11 | 12 | def test_validate_no_args() -> None: 13 | def g() -> int: 14 | return 0 15 | 16 | valid_g = _validate(g) 17 | assert valid_g() == 0 18 | 19 | 20 | def test_validate_positional_args() -> None: 21 | # Works with no unsupported args 22 | valid_f = _validate(f) 23 | assert valid_f(None, 2) == "None 2 3 4.0" 24 | assert valid_f(None, 2, "3", 4.5) == "None 2 3 4.5" 25 | 26 | # Works with some unsupported args 27 | valid_f = _validate(f, unsupported_args=["c", "d"]) 28 | 29 | # Works if unsupported args are not present or equal default 30 | assert valid_f(None, 2) == "None 2 3 4.0" 31 | assert valid_f(None, 2, "3", 4.0) == "None 2 3 4.0" 32 | 33 | # Doesn't work if non-default unsupported args provided 34 | with pytest.raises(ValueError) as e: 35 | valid_f(None, 2, "3.4") 36 | assert "argument has not been implemented" in str(e.value) 37 | with pytest.raises(ValueError) as e: 38 | valid_f(None, 2, "3", 4.5) 39 | assert "argument has not been implemented" in str(e.value) 40 | 41 | 42 | def test_validate_positional_kwargs() -> None: 43 | # Works with no unsupported args 44 | valid_f = _validate(f) 45 | assert valid_f(None, 2) == "None 2 3 4.0" 46 | assert valid_f(None, 2, c="3", d=4.5) == "None 2 3 4.5" 47 | 48 | # Works with some unsupported args 49 | valid_f = _validate(f, unsupported_args=["c", "d"]) 50 | 51 | # Works if unsupported args are not present or equal default 52 | assert valid_f(None, 2) == "None 2 3 4.0" 53 | assert valid_f(None, 2, c="3") 54 | 55 | # Doesn't work if non-default unsupported args provided 56 | with pytest.raises(ValueError) as e: 57 | valid_f(None, 2, c="3.4") 58 | assert "argument has not been implemented" in str(e.value) 59 | with pytest.raises(ValueError) as e: 60 | valid_f(None, 2, d=4.5) 61 | assert "argument has not been implemented" in str(e.value) 62 | 63 | 64 | def test_validate_invalid_arg() -> None: 65 | with pytest.raises(ValueError) as e: 66 | _validate(f, 
unsupported_args=["z"]) 67 | assert "is not valid" in str(e.value) 68 | 69 | 70 | def test_validate_no_default_arg() -> None: 71 | with pytest.raises(ValueError) as e: 72 | _validate(f, unsupported_args=["b"]) 73 | assert "has no default value" in str(e.value) 74 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_formats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import collections 4 | 5 | import torch 6 | 7 | from ..formats import FPFormat 8 | 9 | 10 | def test_fp_format() -> None: 11 | fmt = FPFormat(2, 1) 12 | assert fmt.bits == 4 13 | assert str(fmt) == "E2M1-SR" 14 | assert fmt.max_absolute_value == 3 15 | assert fmt.min_absolute_normal == 0.5 16 | assert fmt.min_absolute_subnormal == 0.25 17 | assert set(fmt.quantise_fwd(torch.linspace(-4, 4, steps=100)).tolist()) == { 18 | sx for x in [0, 0.25, 0.5, 0.75, 1, 1.5, 2, 3] for sx in [x, -x] 19 | } 20 | assert set( 21 | FPFormat(3, 0).quantise_fwd(torch.linspace(-10, 10, steps=1000)).abs().tolist() 22 | ) == {0, 0.125, 0.25, 0.5, 1, 2, 4, 8} 23 | 24 | 25 | def test_fp_format_rounding() -> None: 26 | n = 10000 27 | x = -1.35 28 | 29 | y_nearest = FPFormat(2, 1, rounding="nearest").quantise(torch.full((n,), x)) 30 | assert collections.Counter(y_nearest.tolist()) == {-1.5: n} 31 | 32 | for srbits in (0, 13): 33 | srformat = FPFormat(2, 1, rounding="stochastic", srbits=srbits) 34 | y_stochastic = srformat.quantise(torch.full((n,), x)) 35 | count = collections.Counter(y_stochastic.tolist()) 36 | assert count.keys() == {-1.5, -1.0} 37 | expected_ratio = (1.35 - 1.0) / 0.5 38 | nearest_ratio = count[-1.5] / sum(count.values()) 39 | std_x3 = 3 * (expected_ratio * (1 - expected_ratio) / n) ** 0.5 40 | assert expected_ratio - std_x3 < nearest_ratio < expected_ratio + std_x3 41 | 42 | 43 | def test_fp_format_bwd() -> None: 44 | fmt = FPFormat(2, 1) 45 | x = torch.randn(100, requires_grad=True) 46 | y = fmt.quantise_bwd(x * 1) 47 | y.backward(torch.linspace(-4, 4, steps=100)) # type: ignore[no-untyped-call] 48 | assert x.grad is not None 49 | assert set(x.grad.tolist()) == { 50 | sx for x in [0, 0.25, 0.5, 0.75, 1, 1.5, 2, 3] for sx in [x, -x] 51 | } 52 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
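# Tests for unit_scaling._modules: unit-scaled layers should start from unit-scaled parameters and produce approximately unit-scaled outputs and gradients.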
2 | 3 | import pytest 4 | import torch 5 | from torch import randint, randn 6 | 7 | from .._modules import ( 8 | GELU, 9 | MHSA, 10 | MLP, 11 | Conv1d, 12 | CrossEntropyLoss, 13 | DepthModuleList, 14 | DepthSequential, 15 | Dropout, 16 | Embedding, 17 | LayerNorm, 18 | Linear, 19 | LinearReadout, 20 | RMSNorm, 21 | SiLU, 22 | Softmax, 23 | TransformerDecoder, 24 | TransformerLayer, 25 | ) 26 | from ..optim import SGD 27 | from ..parameter import has_parameter_data 28 | from .helper import ( 29 | assert_non_zeros, 30 | assert_not_unit_scaled, 31 | assert_unit_scaled, 32 | assert_zeros, 33 | unit_backward, 34 | ) 35 | 36 | 37 | def test_gelu() -> None: 38 | input = randn(2**10) 39 | model = GELU() 40 | output = model(input) 41 | 42 | assert float(output.std()) == pytest.approx(1, abs=0.1) 43 | 44 | 45 | def test_silu() -> None: 46 | input = randn(2**10) 47 | model = SiLU() 48 | output = model(input) 49 | 50 | assert float(output.std()) == pytest.approx(1, abs=0.1) 51 | 52 | 53 | def test_softmax() -> None: 54 | input = randn(2**14) 55 | model = Softmax(dim=0) 56 | output = model(input) 57 | 58 | # The approximation is quite rough at mult=1 59 | assert 0.1 < float(output.std()) < 10 60 | 61 | 62 | def test_dropout() -> None: 63 | input = randn(2**12, requires_grad=True) 64 | model = Dropout() 65 | output = model(input) 66 | 67 | assert float(output.std()) == pytest.approx(1, abs=0.1) 68 | 69 | with pytest.raises(ValueError): 70 | Dropout(0.5, inplace=True) 71 | 72 | 73 | def test_linear() -> None: 74 | input = randn(2**8, 2**10, requires_grad=True) 75 | model = Linear(2**10, 2**12, bias=True) 76 | output = model(input) 77 | 78 | assert_unit_scaled(model.weight) 79 | assert_zeros(model.bias) 80 | assert output.shape == torch.Size([2**8, 2**12]) 81 | 82 | unit_backward(output) 83 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step() 84 | 85 | assert float(output.std()) == pytest.approx(1, abs=0.1) 86 | 87 | assert_not_unit_scaled(model.weight) 88 | assert_non_zeros(model.bias) 89 | 90 | 91 | def test_conv1d() -> None: 92 | batch_size = 2**6 93 | d_in = 2**6 * 3 94 | d_out = 2**6 * 5 95 | kernel_size = 11 96 | seq_len = 2**6 * 7 97 | input = randn(batch_size, d_in, seq_len, requires_grad=True) 98 | model = Conv1d(d_in, d_out, kernel_size, bias=True) 99 | output = model(input) 100 | 101 | assert_unit_scaled(model.weight) 102 | assert_zeros(model.bias) 103 | 104 | unit_backward(output) 105 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step() 106 | 107 | assert float(output.std()) == pytest.approx(1, abs=0.1) 108 | 109 | assert_not_unit_scaled(model.weight) 110 | assert_non_zeros(model.bias) 111 | 112 | 113 | def test_linear_readout() -> None: 114 | input = randn(2**8, 2**10, requires_grad=True) 115 | model = LinearReadout(2**10, 2**12) 116 | output = model(input) 117 | 118 | assert model.weight.mup_type == "output" # type:ignore[attr-defined] 119 | assert_unit_scaled(model.weight) 120 | assert output.shape == torch.Size([2**8, 2**12]) 121 | assert float(output.std()) == pytest.approx(2**-5, rel=0.1) 122 | 123 | unit_backward(output) 124 | SGD(model.parameters(), lr=1).step() 125 | assert_not_unit_scaled(model.weight) 126 | 127 | 128 | def test_layer_norm() -> None: 129 | input = randn(2**8, 2**10, requires_grad=True) 130 | model = LayerNorm(2**10, elementwise_affine=True) 131 | output = model(input) 132 | 133 | assert output.shape == torch.Size([2**8, 2**10]) 134 | 135 | unit_backward(output) 136 | SGD(model.parameters(), lr=1).step() 137 | 138 | 
assert_unit_scaled(output, input.grad, model.weight.grad, model.bias.grad) 139 | 140 | 141 | def test_rms_norm() -> None: 142 | input = randn(2**8, 2**10, requires_grad=True) 143 | model = RMSNorm(2**10, elementwise_affine=True) 144 | output = model(input) 145 | 146 | assert output.shape == torch.Size([2**8, 2**10]) 147 | assert model.weight is not None 148 | 149 | unit_backward(output) 150 | SGD(model.parameters(), lr=1).step() 151 | 152 | assert_unit_scaled(output, input.grad, model.weight.grad) 153 | 154 | 155 | def test_embedding() -> None: 156 | batch_sz, seq_len, embedding_dim, num_embeddings = 2**4, 2**5, 2**6, 2**12 157 | input_idxs = randint(low=0, high=2**12, size=(batch_sz, seq_len)) 158 | model = Embedding(num_embeddings, embedding_dim) 159 | output = model(input_idxs) 160 | 161 | assert output.shape == torch.Size([batch_sz, seq_len, embedding_dim]) 162 | 163 | unit_backward(output) 164 | 165 | assert_unit_scaled(model.weight.grad) 166 | 167 | with pytest.raises(ValueError): 168 | Embedding(num_embeddings, embedding_dim, scale_grad_by_freq=True) 169 | with pytest.raises(ValueError): 170 | Embedding(num_embeddings, embedding_dim, sparse=True) 171 | 172 | 173 | def test_cross_entropy_loss() -> None: 174 | num_tokens, vocab_sz = 2**12, 2**8 175 | input = randn(num_tokens, vocab_sz, requires_grad=True) 176 | labels = randint(low=0, high=vocab_sz, size=(num_tokens,)) 177 | model = CrossEntropyLoss() 178 | loss = model(input, labels) 179 | loss.backward() 180 | 181 | assert_unit_scaled(input.grad) 182 | 183 | with pytest.raises(ValueError): 184 | CrossEntropyLoss(weight=randn(vocab_sz)) 185 | with pytest.raises(ValueError): 186 | CrossEntropyLoss(label_smoothing=0.5) 187 | 188 | 189 | def test_mlp() -> None: 190 | input = randn(2**8, 2**10, requires_grad=True) 191 | model = MLP(2**10) 192 | output = model(input) 193 | 194 | assert_unit_scaled( 195 | model.linear_1.weight, model.linear_gate.weight, model.linear_2.weight 196 | ) 197 | assert output.shape == torch.Size([2**8, 2**10]) 198 | 199 | unit_backward(output) 200 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step() 201 | 202 | assert float(output.std()) == pytest.approx(1, abs=0.2) 203 | 204 | assert_unit_scaled( 205 | model.linear_1.weight.grad, 206 | model.linear_gate.weight.grad, 207 | model.linear_2.weight.grad, 208 | ) 209 | 210 | assert_not_unit_scaled( 211 | model.linear_1.weight, model.linear_gate.weight, model.linear_2.weight 212 | ) 213 | 214 | 215 | def test_mhsa() -> None: 216 | batch_sz, seq_len, hidden_dim = 2**8, 2**6, 2**6 217 | input = randn(batch_sz, seq_len, hidden_dim, requires_grad=True) 218 | model = MHSA(hidden_dim, heads=8, is_causal=False, dropout_p=0.1) 219 | output = model(input) 220 | 221 | assert_unit_scaled(model.linear_qkv.weight, model.linear_o.weight) 222 | assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim]) 223 | 224 | unit_backward(output) 225 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step() 226 | 227 | assert float(output.std()) == pytest.approx(1, abs=0.5) 228 | 229 | assert_not_unit_scaled(model.linear_qkv.weight, model.linear_o.weight) 230 | 231 | 232 | def test_transformer_layer() -> None: 233 | batch_sz, seq_len, hidden_dim, heads = 2**8, 2**6, 2**6, 8 234 | input = randn(batch_sz, seq_len, hidden_dim, requires_grad=True) 235 | model = TransformerLayer( 236 | hidden_dim, 237 | heads=heads, 238 | is_causal=False, 239 | dropout_p=0.1, 240 | mhsa_tau=0.1, 241 | mlp_tau=1.0, 242 | ) 243 | output = model(input) 244 | 245 | 
assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim]) 246 | 247 | unit_backward(output) 248 | SGD(model.parameters(), lr=1).step() 249 | 250 | assert float(output.std()) == pytest.approx(1, abs=0.1) 251 | 252 | 253 | def test_depth_module_list() -> None: 254 | layers = DepthModuleList([Linear(10, 10) for _ in range(5)]) 255 | assert len(layers) == 5 256 | for layer in layers: 257 | assert layer.weight.mup_scaling_depth == 5 258 | 259 | with pytest.raises(ValueError): 260 | DepthModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) 261 | 262 | 263 | def test_depth_sequential() -> None: 264 | model = DepthSequential(*(Linear(2**6, 2**6) for _ in range(7))) 265 | for param in model.parameters(): 266 | assert has_parameter_data(param) 267 | assert param.mup_scaling_depth == 7 268 | 269 | input = randn(2**4, 2**6, requires_grad=True) 270 | output = model(input) 271 | unit_backward(output) 272 | assert_unit_scaled(output, input.grad) 273 | 274 | with pytest.raises(ValueError): 275 | DepthSequential(*[torch.nn.Linear(2**6, 2**6) for _ in range(7)]) 276 | 277 | 278 | def test_transformer_decoder() -> None: 279 | batch_size = 2**8 280 | seq_len = 2**6 281 | hidden_size = 2**6 282 | vocab_size = 2**12 283 | layers = 2 284 | heads = 4 285 | 286 | input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) 287 | model = TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p=0.1) 288 | loss = model.loss(input_ids) 289 | 290 | expected_loss = torch.tensor(vocab_size).log() 291 | assert expected_loss / 2 < loss.item() < expected_loss * 2 292 | 293 | loss.backward() # type:ignore[no-untyped-call] 294 | SGD(model.parameters(), lr=1).step() 295 | 296 | for name, p in model.named_parameters(): 297 | threshold = 5.0 298 | assert p.grad is not None 299 | assert 1 / threshold <= p.grad.std().detach() <= threshold, name 300 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 
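# Tests for unit_scaling.optim: SGD/Adam/AdamW wrappers and the per-parameter learning-rate / weight-decay scaling applied by scaled_parameters.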
2 | 3 | from typing import Any, Dict, List, Optional, Type, cast 4 | 5 | import pytest 6 | import torch 7 | from torch import nn, tensor, zeros 8 | from torch.testing import assert_close 9 | 10 | import unit_scaling as uu 11 | import unit_scaling.functional as U 12 | 13 | from ..optim import ( 14 | SGD, 15 | Adam, 16 | AdamW, 17 | lr_scale_func_adam, 18 | lr_scale_func_sgd, 19 | scaled_parameters, 20 | ) 21 | 22 | 23 | @pytest.mark.parametrize("opt_type", (Adam, AdamW, SGD)) 24 | def test_optim_optimizers(opt_type: Type[torch.optim.Optimizer]) -> None: 25 | torch.manual_seed(100) 26 | inputs = torch.randn(10, 16) 27 | outputs = torch.randn(10, 25) 28 | model = uu.Linear(16, 25) 29 | opt = opt_type( 30 | model.parameters(), lr=0.01, weight_decay=1e-6 # type:ignore[call-arg] 31 | ) 32 | opt.zero_grad() 33 | loss = U.mse_loss(model(inputs), outputs) 34 | loss.backward() # type:ignore[no-untyped-call] 35 | opt.step() 36 | assert U.mse_loss(model(inputs), outputs) < loss 37 | 38 | 39 | @pytest.mark.parametrize("opt", ["adam", "sgd"]) 40 | @pytest.mark.parametrize("readout_constraint", [None, "to_output_scale"]) 41 | def test_scaled_parameters(opt: str, readout_constraint: Optional[str]) -> None: 42 | model = nn.Sequential( 43 | uu.Embedding(2**8, 2**4), 44 | uu.DepthSequential(*(uu.Linear(2**4, 2**4, bias=True) for _ in range(3))), 45 | uu.LinearReadout( 46 | 2**4, 47 | 2**10, 48 | bias=False, 49 | weight_mup_type="output", 50 | constraint=readout_constraint, 51 | ), 52 | ) 53 | 54 | base_lr = 0.1 55 | base_wd = 0.001 56 | param_groups = scaled_parameters( 57 | model.parameters(), 58 | dict(sgd=lr_scale_func_sgd(readout_constraint), adam=lr_scale_func_adam)[opt], 59 | lr=base_lr, 60 | weight_decay=base_wd, 61 | ) 62 | 63 | # Match parameters based on shapes, as their names have disappeared 64 | sqrt_d = 3**0.5 65 | shape_to_expected_lr = { 66 | (2**8, 2**4): base_lr / 2**2, # embedding.weight 67 | (2**4, 2**4): base_lr / 2**2 / sqrt_d, # stack.linear.weight 68 | (2**4,): base_lr / sqrt_d, # stack.linear.bias 69 | (2**10, 2**4): base_lr, # linear.weight (output) 70 | } 71 | if opt == "sgd" and readout_constraint == "to_output_scale": 72 | shape_to_expected_lr[(2**8, 2**4)] *= 2**4 73 | shape_to_expected_lr[(2**4, 2**4)] *= 2**4 74 | shape_to_expected_lr[(2**4,)] *= 2**4 75 | 76 | for shape, expected_lr in shape_to_expected_lr.items(): 77 | for g in param_groups: 78 | assert isinstance(g, dict) 79 | (param,) = g["params"] 80 | if param.shape == shape: 81 | assert g["lr"] == pytest.approx( 82 | expected_lr, rel=1e-3 83 | ), f"bad LR for param.shape={shape}" 84 | assert g["weight_decay"] == pytest.approx( 85 | base_wd / expected_lr, rel=1e-3 86 | ), f"bad WD for param.shape={shape}" 87 | 88 | 89 | def test_scaled_parameters_with_existing_groups() -> None: 90 | original_params = cast( 91 | List[Dict[str, Any]], 92 | [ 93 | # Two parameters in this group, sharing a tensor LR, which must be cloned 94 | dict( 95 | params=[ 96 | uu.Parameter(zeros(1, 2**4), mup_type="weight"), 97 | uu.Parameter(zeros(2, 2**6), mup_type="output"), 98 | ], 99 | lr=torch.tensor(0.3), 100 | weight_decay=0.05, 101 | ), 102 | # One parameter in this group, with no explicit LR or WD 103 | dict( 104 | params=[ 105 | uu.Parameter( 106 | zeros(3, 2**6), mup_type="weight", mup_scaling_depth=5 107 | ), 108 | ], 109 | ), 110 | ], 111 | ) 112 | 113 | g0, g1, g2 = scaled_parameters( 114 | original_params, lr_scale_func_adam, lr=torch.tensor(0.02) 115 | ) 116 | 117 | assert isinstance(g0, dict) 118 | assert g0["params"][0].shape == (1, 
2**4) 119 | assert_close(g0["lr"], tensor(0.3 / 2**2)) # also checks it's still a Tensor 120 | assert_close(g0["weight_decay"], 0.05 * 2**2 / 0.3) 121 | 122 | assert isinstance(g1, dict) 123 | assert g1["params"][0].shape == (2, 2**6) 124 | assert_close(g1["lr"], tensor(0.3)) 125 | assert_close(g1["weight_decay"], 0.05 / 0.3) 126 | 127 | assert isinstance(g2, dict) 128 | assert g2["params"][0].shape == (3, 2**6) 129 | assert_close(g2["lr"], tensor(0.02 / 2**3 / 5**0.5)) 130 | assert g2["weight_decay"] == 0 131 | 132 | # ### Check error conditions ### 133 | 134 | # No lr, missing for group 1 135 | with pytest.raises(ValueError): 136 | params = scaled_parameters(original_params, lr_scale_func_adam) 137 | # No need for an lr when all groups have it explicitly 138 | params = scaled_parameters(original_params[:1], lr_scale_func_adam) 139 | assert len(params) == 2 # type:ignore[arg-type] 140 | 141 | # Non-unit-scaling parameters 142 | with pytest.raises(ValueError): 143 | params = scaled_parameters( 144 | original_params + [dict(params=[nn.Parameter(zeros(4, 2**4))])], 145 | lr_scale_func_adam, 146 | lr=0.1, 147 | ) 148 | # Allow non-unit-scaling parameters 149 | params = scaled_parameters( 150 | original_params + [dict(params=[nn.Parameter(zeros(4, 2**4))])], 151 | lr_scale_func_adam, 152 | lr=0.1, 153 | allow_non_unit_scaling_params=True, 154 | ) 155 | assert len(params) == 4 # type:ignore[arg-type] 156 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_parameter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | 3 | import copy 4 | import pickle 5 | 6 | import torch 7 | from torch import Tensor, empty, full, tensor 8 | from torch.testing import assert_close 9 | 10 | from ..parameter import Parameter, has_parameter_data 11 | 12 | 13 | def test_parameter() -> None: 14 | param = Parameter(torch.zeros(10), "weight") 15 | assert has_parameter_data(param) 16 | assert param.mup_type == "weight" 17 | assert param.mup_scaling_depth is None 18 | 19 | param.mup_scaling_depth = 8 20 | 21 | param_copy = copy.deepcopy(param) 22 | assert has_parameter_data(param_copy) # type:ignore[arg-type] 23 | assert param_copy.mup_type == "weight" 24 | assert param_copy.mup_scaling_depth == 8 25 | 26 | param_pickle = pickle.loads(pickle.dumps(param)) 27 | assert has_parameter_data(param_pickle) 28 | assert param_pickle.mup_type == "weight" 29 | assert param_pickle.mup_scaling_depth == 8 30 | 31 | param_pickle_copy = copy.deepcopy(param_pickle) 32 | assert has_parameter_data(param_pickle_copy) # type:ignore[arg-type] 33 | assert param_pickle_copy.mup_type == "weight" 34 | assert param_pickle_copy.mup_scaling_depth == 8 35 | 36 | 37 | def test_parameter_compile() -> None: 38 | parameter = Parameter(empty(3), mup_type="norm") 39 | 40 | def update_parameter(mult: Tensor) -> Tensor: 41 | parameter.data.mul_(mult) 42 | return parameter 43 | 44 | parameter.data.fill_(1) 45 | assert_close(update_parameter(tensor(123.0)), full((3,), 123.0)) 46 | 47 | parameter.data.fill_(1) 48 | update_parameter = torch.compile(fullgraph=True)(update_parameter) 49 | assert_close(update_parameter(tensor(0.5)), full((3,), 0.5)) 50 | assert_close(update_parameter(tensor(8.0)), full((3,), 4.0)) 51 | assert_close(update_parameter(parameter), full((3,), 16.0)) 52 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_scale.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import torch 4 | from torch import randn 5 | 6 | from ..scale import scale_bwd, scale_fwd 7 | from .helper import unit_backward 8 | 9 | 10 | def test_scale_fwd() -> None: 11 | x = randn(2**10, requires_grad=True) 12 | x_scaled = scale_fwd(x, 3.5) 13 | grad_in = unit_backward(x_scaled) 14 | 15 | assert torch.equal(x_scaled, x * 3.5) 16 | assert torch.equal(x.grad, grad_in) # type: ignore 17 | 18 | 19 | def test_scale_bwd() -> None: 20 | x = randn(2**10, requires_grad=True) 21 | x_scaled = scale_bwd(x, 3.5) 22 | grad_in = unit_backward(x_scaled) 23 | 24 | assert torch.equal(x_scaled, x) 25 | assert torch.equal(x.grad, grad_in * 3.5) # type: ignore 26 | -------------------------------------------------------------------------------- /unit_scaling/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import re 4 | 5 | import torch 6 | 7 | from .._modules import MHSA, MLP 8 | from ..utils import analyse_module 9 | 10 | 11 | def remove_scales(code: str) -> str: 12 | """Takes a code string containing scale annotations such as `(-> 1.0, <- 0.819)` and 13 | replaces them all with `(-> _, <- _)`.""" 14 | return re.sub(r"\(-> \d+(\.\d+)?, <- \d+(\.\d+)?\)", "(-> _, <- _)", code) 15 | 16 | 17 | def test_analyse_mlp() -> None: 18 | batch_size = 2**10 19 | hidden_size = 2**10 20 | input = torch.randn(batch_size, hidden_size).requires_grad_() 21 | backward = torch.randn(batch_size, hidden_size) 22 | 23 | annotated_code = analyse_module( 24 | MLP(hidden_size), input, backward, syntax_highlight=False 25 | ) 26 | print(annotated_code) 27 | 28 | expected_code = """ 29 | def forward(self, input : Tensor) -> Tensor: 30 | input_1 = input; (-> 1.0, <- 1.44) 31 | linear_1_weight = self.linear_1.weight; (-> 1.0, <- 0.503) 32 | linear = U.linear(input_1, linear_1_weight, None, None); (-> 1.0, <- 0.502) 33 | linear_gate_weight = self.linear_gate.weight; (-> 1.0, <- 0.519) 34 | linear_1 = U.linear(input_1, linear_gate_weight, None, None); (-> 1.0, <- 0.518) 35 | silu_glu = U.silu_glu(linear, linear_1); (-> 1.0, <- 0.5) 36 | linear_2_weight = self.linear_2.weight; (-> 1.0, <- 1.0) 37 | linear_2 = U.linear(silu_glu, linear_2_weight, None, None); (-> 1.0, <- 1.0) 38 | return linear_2 39 | """.strip() # noqa: E501 40 | 41 | assert remove_scales(annotated_code) == remove_scales(expected_code) 42 | 43 | 44 | def test_analyse_mhsa() -> None: 45 | batch_size = 2**8 46 | seq_len = 2**6 47 | hidden_size = 2**6 48 | heads = 4 49 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_() 50 | backward = torch.randn(batch_size, seq_len, hidden_size) 51 | 52 | annotated_code = analyse_module( 53 | MHSA(hidden_size, heads, is_causal=False, dropout_p=0.1), 54 | input, 55 | backward, 56 | syntax_highlight=False, 57 | ) 58 | print(annotated_code) 59 | 60 | expected_code = """ 61 | def forward(self, input : Tensor) -> Tensor: 62 | input_1 = input; (-> 1.0, <- 1.13) 63 | linear_qkv_weight = self.linear_qkv.weight; (-> 1.01, <- 0.662) 64 | linear = U.linear(input_1, linear_qkv_weight, None, 'to_output_scale'); (-> 1.01, <- 0.633) 65 | rearrange = einops_einops_rearrange(linear, 'b s (z h d) -> z b h s d', h = 4, z = 3); (-> 1.01, <- 0.633) 66 | getitem = rearrange[0]; (-> 1.0, <- 0.344) 67 | getitem_1 = rearrange[1]; (-> 1.0, <- 0.257) 68 | getitem_2 = 
rearrange[2]; (-> 1.02, <- 1.01) 69 | scaled_dot_product_attention = U.scaled_dot_product_attention(getitem, getitem_1, getitem_2, dropout_p = 0.1, is_causal = False, mult = 1.0); (-> 1.04, <- 1.0) 70 | rearrange_1 = einops_einops_rearrange(scaled_dot_product_attention, 'b h s d -> b s (h d)'); (-> 1.04, <- 1.0) 71 | linear_o_weight = self.linear_o.weight; (-> 1.0, <- 1.03) 72 | linear_1 = U.linear(rearrange_1, linear_o_weight, None, 'to_output_scale'); (-> 1.06, <- 1.0) 73 | return linear_1 74 | """.strip() # noqa: E501 75 | 76 | assert remove_scales(annotated_code) == remove_scales(expected_code) 77 | -------------------------------------------------------------------------------- /unit_scaling/tests/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | -------------------------------------------------------------------------------- /unit_scaling/tests/transforms/test_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import Tensor, nn 8 | 9 | from ...transforms import compile, unit_scale 10 | 11 | 12 | def test_compile() -> None: 13 | class Module(nn.Module): # pragma: no cover 14 | def __init__( 15 | self, 16 | hidden_size: int, 17 | ) -> None: 18 | super().__init__() 19 | self.layer_norm = nn.LayerNorm(hidden_size) 20 | self.l1 = nn.Linear(hidden_size, 4 * hidden_size) 21 | self.l2 = nn.Linear(4 * hidden_size, hidden_size) 22 | 23 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: 24 | input = self.layer_norm(input) 25 | input = self.l1(input) 26 | input = F.gelu(input) 27 | input = self.l2(input) 28 | input = F.dropout(input, 0.2) 29 | return input, input.sum() 30 | 31 | mod = Module(2**6) 32 | x = torch.randn(2**3, 2**6) 33 | 34 | compile(mod)(x) 35 | compile(unit_scale(mod))(x) 36 | -------------------------------------------------------------------------------- /unit_scaling/tests/transforms/test_simulate_format.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import logging 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from pytest import LogCaptureFixture 8 | from torch import Tensor, nn, randn 9 | 10 | from ... import _modules as uu 11 | from ... 
import functional as U 12 | from ...transforms import simulate_fp8 13 | 14 | 15 | def test_simulate_fp8_linear(caplog: LogCaptureFixture) -> None: 16 | caplog.set_level(logging.INFO) 17 | 18 | class Model(nn.Module): 19 | def __init__(self, d_in: int, d_out: int) -> None: 20 | super().__init__() 21 | self.linear = nn.Linear(d_in, d_out) 22 | 23 | def forward(self, t: Tensor) -> Tensor: 24 | return self.linear(t).sum() # type: ignore[no-any-return] 25 | 26 | input = randn(2**8, 2**9, requires_grad=True) 27 | model = Model(2**9, 2**10) 28 | output = model(input) 29 | output.backward() 30 | 31 | fp8_input = input.clone().detach().requires_grad_() 32 | fp8_model = simulate_fp8(model) 33 | fp8_output = fp8_model(fp8_input) 34 | fp8_output.backward() 35 | 36 | assert not torch.all(fp8_output == output) 37 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore 38 | assert not torch.all( 39 | fp8_model.linear.weight.grad == model.linear.weight.grad # type: ignore 40 | ) 41 | assert "quantising function" in caplog.text 42 | 43 | 44 | def test_simulate_fp8_unit_scaled_linear() -> None: 45 | class Model(nn.Module): 46 | def __init__(self, d_in: int, d_out: int) -> None: 47 | super().__init__() 48 | self.linear = uu.Linear(d_in, d_out) 49 | 50 | def forward(self, t: Tensor) -> Tensor: 51 | return self.linear(t).sum() # type: ignore[no-any-return] 52 | 53 | input = randn(2**8, 2**9, requires_grad=True) 54 | model = Model(2**9, 2**10) 55 | output = model(input) 56 | output.backward() 57 | 58 | fp8_input = input.clone().detach().requires_grad_() 59 | fp8_model = simulate_fp8(model) 60 | fp8_output = fp8_model(fp8_input) 61 | fp8_output.backward() 62 | 63 | assert not torch.all(fp8_output == output) 64 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore 65 | assert not torch.all( 66 | fp8_model.linear.weight.grad == model.linear.weight.grad # type: ignore 67 | ) 68 | 69 | 70 | def test_simulate_fp8_attention() -> None: 71 | class Model(nn.Module): 72 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 73 | return F.scaled_dot_product_attention(q, k, v).sum() 74 | 75 | inputs = list(randn(2**8, 2**8, requires_grad=True) for _ in range(3)) 76 | model = Model() 77 | output = model(*inputs) 78 | output.backward() 79 | 80 | fp8_inputs = list(t.clone().detach().requires_grad_() for t in inputs) 81 | fp8_model = simulate_fp8(model) 82 | 83 | fp8_output = fp8_model(*fp8_inputs) 84 | fp8_output.backward() 85 | 86 | assert not torch.all(fp8_output == output) 87 | for fp8_input, input in zip(fp8_inputs, inputs): 88 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore 89 | 90 | 91 | def test_simulate_fp8_unit_scaled_attention() -> None: 92 | class Model(nn.Module): 93 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 94 | return U.scaled_dot_product_attention(q, k, v).sum() 95 | 96 | inputs = list(randn(2**8, 2**8, requires_grad=True) for _ in range(3)) 97 | model = Model() 98 | output = model(*inputs) 99 | output.backward() 100 | 101 | fp8_inputs = list(t.clone().detach().requires_grad_() for t in inputs) 102 | fp8_model = simulate_fp8(model) 103 | 104 | fp8_output = fp8_model(*fp8_inputs) 105 | fp8_output.backward() 106 | 107 | assert not torch.all(fp8_output == output) 108 | for fp8_input, input in zip(fp8_inputs, inputs): 109 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore 110 | -------------------------------------------------------------------------------- /unit_scaling/tests/transforms/test_track_scales.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import operator 4 | from math import pi, sqrt 5 | from typing import Any, Callable, Dict, Set, Tuple, Union 6 | 7 | import pytest 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import Tensor, nn, randint, randn, randn_like 11 | from torch.fx.graph import Graph 12 | from torch.fx.node import Node 13 | 14 | from ...transforms import ( 15 | prune_non_float_tensors, 16 | prune_same_scale_tensors, 17 | prune_selected_nodes, 18 | track_scales, 19 | ) 20 | 21 | 22 | def get_target_or_node_name(node: Node) -> Union[str, Callable[..., Any]]: 23 | return node.meta["clean_name"] if isinstance(node.target, str) else node.target 24 | 25 | 26 | def get_targets(graph: Graph) -> Set[Union[str, Callable]]: # type: ignore[type-arg] 27 | return set(get_target_or_node_name(node) for node in graph.nodes) 28 | 29 | 30 | def get_target_map( 31 | graph: Graph, 32 | ) -> Dict[Union[str, Callable], Dict[str, Any]]: # type: ignore[type-arg] 33 | return {get_target_or_node_name(node): node.meta for node in graph.nodes} 34 | 35 | 36 | def test_track_scales() -> None: 37 | class Model(nn.Module): 38 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover 39 | x = F.relu(x) 40 | y = torch.ones_like(x, dtype=x.dtype) 41 | z = x + y 42 | return z.sum() 43 | 44 | model = Model() 45 | model = track_scales(model) 46 | assert len(model.scales_graph().nodes) == 0 47 | 48 | input = randn(2**4, 2**10) 49 | loss = model(input) 50 | 51 | graph = model.scales_graph() 52 | 53 | assert all("outputs_float_tensor" in n.meta for n in graph.nodes) 54 | meta_map = get_target_map(graph) 55 | 56 | assert "metrics" in meta_map["x"] 57 | assert meta_map["x"]["metrics"].bwd is None 58 | assert meta_map["x"]["metrics"].fwd.mean_abs == pytest.approx( 59 | sqrt(2 / pi), abs=0.01 60 | ) 61 | assert meta_map["x"]["metrics"].fwd.abs_mean == pytest.approx(0, abs=0.01) 62 | assert meta_map["x"]["metrics"].fwd.std == pytest.approx(1, abs=0.01) 63 | assert meta_map["x"]["metrics"].fwd.numel == 2**14 64 | 65 | assert "metrics" in meta_map[F.relu] 66 | assert meta_map[F.relu]["metrics"].bwd is None 67 | assert ( 68 | meta_map[F.relu]["metrics"].fwd.mean_abs 69 | == meta_map[F.relu]["metrics"].fwd.abs_mean 70 | == pytest.approx(sqrt(1 / (2 * pi)), abs=0.01) 71 | ) 72 | assert meta_map[F.relu]["metrics"].fwd.std == pytest.approx( 73 | sqrt((1 - 1 / pi) / 2), abs=0.01 74 | ) 75 | assert meta_map[F.relu]["metrics"].fwd.numel == 2**14 76 | 77 | assert "metrics" in meta_map[torch.ones_like] 78 | assert meta_map[torch.ones_like]["metrics"].bwd is None 79 | assert meta_map[torch.ones_like]["metrics"].fwd.mean_abs == 1.0 80 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_mean == 1.0 81 | assert meta_map[torch.ones_like]["metrics"].fwd.std == 0.0 82 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_max == 1.0 83 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_min == 1.0 84 | assert meta_map[torch.ones_like]["metrics"].fwd.numel == 2**14 85 | 86 | assert "metrics" in meta_map[operator.add] 87 | assert "metrics" in meta_map["sum_1"] 88 | assert meta_map["sum_1"]["metrics"].fwd.numel == 1 89 | 90 | loss.backward() 91 | graph = model.scales_graph() 92 | 93 | meta_map = get_target_map(graph) 94 | 95 | assert meta_map["sum_1"]["metrics"].bwd is not None 96 | assert meta_map["sum_1"]["metrics"].bwd.numel == 1 97 | 98 | assert meta_map[operator.add]["metrics"].bwd is not None 99 | 
assert meta_map[operator.add]["metrics"].bwd.mean_abs == 1.0 100 | assert meta_map[operator.add]["metrics"].bwd.abs_mean == 1.0 101 | assert meta_map[operator.add]["metrics"].bwd.std == 0.0 102 | assert meta_map[operator.add]["metrics"].bwd.abs_max == 1.0 103 | assert meta_map[operator.add]["metrics"].bwd.abs_min == 1.0 104 | assert meta_map[operator.add]["metrics"].bwd.numel == 2**14 105 | 106 | assert meta_map["x"]["metrics"].bwd is not None 107 | assert meta_map["x"]["metrics"].bwd.std == pytest.approx( 108 | 0.5, abs=0.01 # same as fwd pass except 0s are now 1s 109 | ) 110 | 111 | 112 | def test_prune_non_float_tensors() -> None: 113 | class Model(nn.Module): 114 | def __init__(self, emb_size: int, dim: int) -> None: 115 | super().__init__() 116 | self.emb = nn.Embedding(emb_size, dim) 117 | self.linear = nn.Linear(dim, dim) 118 | 119 | def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover 120 | x = self.emb(idxs) 121 | scores = F.softmax(self.linear(x), dim=-1) 122 | top_idx = torch.argmax(scores, dim=-1) 123 | top_idx = torch.unsqueeze(top_idx, -1) 124 | top_score_x = torch.gather(x, -1, top_idx) 125 | top_score_x -= x.mean() 126 | return top_score_x, top_idx 127 | 128 | idxs = randint(0, 2**10, (2**3, 2**5)) 129 | model = Model(2**10, 2**6) 130 | model = track_scales(model) 131 | model(idxs) 132 | 133 | graph = model.scales_graph() 134 | expected_targets = { 135 | "idxs", 136 | "self_modules_emb_parameters_weight", 137 | F.embedding, 138 | F.linear, 139 | "self_modules_linear_parameters_weight", 140 | "self_modules_linear_parameters_bias", 141 | F.softmax, 142 | torch.argmax, 143 | torch.unsqueeze, 144 | torch.gather, 145 | operator.isub, 146 | "mean", 147 | "output", 148 | } 149 | graph_targets = get_targets(graph) 150 | assert graph_targets == expected_targets 151 | 152 | graph = prune_non_float_tensors(graph) 153 | graph_targets = get_targets(graph) 154 | expected_targets -= {"idxs", torch.argmax, torch.unsqueeze} 155 | assert graph_targets == expected_targets 156 | 157 | 158 | def test_prune_same_scale_tensors() -> None: 159 | class Model(nn.Module): 160 | def __init__(self, emb_size: int, dim: int) -> None: 161 | super().__init__() 162 | self.emb = nn.Embedding(emb_size, dim) 163 | self.linear = nn.Linear(dim, dim // 2) 164 | 165 | def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover 166 | # idxs has 0 args -> shouldn't be pruned 167 | x = self.emb(idxs) # emb has 1 float arg (weights) -> depends on tol 168 | _x = x.flatten(start_dim=0, end_dim=-1) # 1 float, same scale -> prune 169 | x = _x.view(x.shape) # 1 float arg, same scale -> prune 170 | y = self.linear(x) # scale changes -> shouldn't be pruned 171 | scores = F.softmax(y, dim=-1) # scale changes -> shouldn't be pruned 172 | top_idx = torch.argmax(scores, dim=-1) # not float -> shouldn't be pruned 173 | top_idx = torch.unsqueeze(top_idx, -1) # not float -> shouldn't be pruned 174 | top_score_x = torch.gather(x, -1, top_idx) # small change -> depends on tol 175 | top_score_x += randn_like(top_score_x) # 2 floats, same scale -> no prune 176 | return top_score_x, top_idx 177 | 178 | idxs = randint(0, 2**10, (2**3, 2**5)) 179 | model = Model(2**10, 2**6) 180 | model = track_scales(model) 181 | model(idxs) 182 | 183 | graph = model.scales_graph() 184 | 185 | # Version-dependent, see https://github.com/graphcore-research/unit-scaling/pull/52 186 | var_lhs_flatten = "x" 187 | var_lhs_view = "x_1" 188 | expected_targets = { 189 | "idxs", 190 | "self_modules_emb_parameters_weight", 
191 | F.embedding, 192 | var_lhs_flatten, 193 | var_lhs_view, 194 | "self_modules_linear_parameters_weight", 195 | "self_modules_linear_parameters_bias", 196 | F.linear, 197 | F.softmax, 198 | torch.argmax, 199 | torch.unsqueeze, 200 | torch.gather, 201 | randn_like, 202 | operator.iadd, 203 | "output", 204 | } 205 | graph_targets = get_targets(graph) 206 | assert graph_targets == expected_targets 207 | 208 | graph = prune_same_scale_tensors(graph) 209 | graph_targets = get_targets(graph) 210 | expected_targets -= {var_lhs_flatten, var_lhs_view} 211 | assert graph_targets == expected_targets 212 | 213 | graph = prune_same_scale_tensors(graph, rtol=2**-4) 214 | graph_targets = get_targets(graph) 215 | expected_targets -= {torch.gather, F.embedding} 216 | assert graph_targets == expected_targets 217 | 218 | 219 | def test_prune_same_scale_tensors_with_grad() -> None: 220 | class Model(nn.Module): 221 | def forward(self, a: Tensor) -> Tensor: # pragma: no cover 222 | b = a / 1.0 # same scale fwd & bwd 223 | c = b * 1.0 # same scale fwd, as b sums grads -> different scale bwd 224 | d = F.relu(c) # different scale fwd & bwd 225 | e = b - d # different scale fwd, same bwd 226 | f = e.sum() # different scale fwd & bwd 227 | return f 228 | 229 | input = randn(2**6, 2**8) 230 | model = Model() 231 | model = track_scales(model) 232 | loss = model(input) 233 | 234 | graph = model.scales_graph() 235 | expected_targets = { 236 | "a", 237 | operator.truediv, 238 | operator.mul, 239 | F.relu, 240 | operator.sub, 241 | "f", 242 | "output", 243 | } 244 | graph_targets = get_targets(graph) 245 | assert graph_targets == expected_targets 246 | 247 | graph = prune_same_scale_tensors(graph) 248 | graph_targets = get_targets(graph) 249 | expected_targets -= {operator.truediv, operator.mul} 250 | assert graph_targets == expected_targets 251 | 252 | # The mul still has the same scale before & after in the fwd pass, but the same is 253 | # not true for its grads. It should no longer be pruned after `loss.backward`. 254 | loss.backward() 255 | graph = model.scales_graph() 256 | graph = prune_same_scale_tensors(graph) 257 | graph_targets = get_targets(graph) 258 | expected_targets.add(operator.mul) 259 | assert graph_targets == expected_targets 260 | 261 | 262 | def test_prune_selected_nodes() -> None: 263 | class Model(nn.Module): 264 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover 265 | x = x + 1 266 | x = F.relu(x) 267 | x = torch.abs(x) 268 | return x.sum() 269 | 270 | input = randn(2**6, 2**8) 271 | model = Model() 272 | model = track_scales(model) 273 | model(input) 274 | 275 | graph = model.scales_graph() 276 | expected_targets = { 277 | "x", 278 | operator.add, 279 | F.relu, 280 | torch.abs, 281 | "sum_1", 282 | "output", 283 | } 284 | graph_targets = get_targets(graph) 285 | assert graph_targets == expected_targets 286 | 287 | graph = prune_selected_nodes(graph, targets=[torch.abs, F.relu]) 288 | graph_targets = get_targets(graph) 289 | expected_targets -= {torch.abs, F.relu} 290 | assert graph_targets == expected_targets 291 | -------------------------------------------------------------------------------- /unit_scaling/tests/transforms/test_unit_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
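# Tests for the unit_scale transform: weight/function scaling, function replacement, residual-add handling and composition with simulate_fp8.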
2 | 3 | import logging 4 | import math 5 | import re 6 | from typing import Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from pytest import LogCaptureFixture 11 | from torch import Tensor, nn, randn 12 | 13 | from ...transforms import simulate_fp8, unit_scale 14 | from ..helper import assert_unit_scaled 15 | 16 | 17 | def test_unit_scale(caplog: LogCaptureFixture) -> None: 18 | caplog.set_level(logging.INFO) 19 | 20 | def custom_gelu(x: Tensor) -> Tensor: 21 | inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) 22 | return 0.5 * x * (1.0 + torch.tanh(inner)) 23 | 24 | class MLPLayer(nn.Module): # pragma: no cover 25 | def __init__( 26 | self, 27 | hidden_size: int, 28 | ) -> None: 29 | super().__init__() 30 | self.layer_norm = nn.LayerNorm(hidden_size) 31 | self.l1 = nn.Linear(hidden_size, 4 * hidden_size) 32 | self.l2 = nn.Linear(4 * hidden_size, hidden_size) 33 | 34 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: 35 | input = self.layer_norm(input) 36 | input = self.l1(input) 37 | input = custom_gelu(input) 38 | input = self.l2(input) 39 | input = F.dropout(input, 0.2) 40 | return input, input.sum() 41 | 42 | input = randn(2**6, 2**10, requires_grad=True) 43 | model = nn.Sequential(MLPLayer(2**10)) 44 | model = unit_scale(model, replace={custom_gelu: F.gelu}) 45 | output, loss = model(input) 46 | loss.backward() 47 | 48 | assert_unit_scaled( 49 | output, 50 | input.grad, 51 | model[0].layer_norm.weight.grad, 52 | model[0].l1.weight.grad, 53 | model[0].l2.weight.grad, 54 | abs=0.2, 55 | ) 56 | 57 | expected_logs = [ 58 | "unit scaling weight", 59 | "setting bias to zero", 60 | "unit scaling function", 61 | "replacing function", 62 | "unconstraining node", 63 | ] 64 | for log_msg in expected_logs: 65 | assert log_msg in caplog.text 66 | 67 | 68 | def test_unit_scale_residual_add(caplog: LogCaptureFixture) -> None: 69 | caplog.set_level(logging.INFO) 70 | 71 | class MLPLayer(nn.Module): 72 | def __init__( 73 | self, 74 | hidden_size: int, 75 | ) -> None: 76 | super().__init__() 77 | self.l1 = nn.Linear(hidden_size, hidden_size) 78 | self.l2 = nn.Linear(hidden_size, hidden_size) 79 | 80 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover 81 | skip = input 82 | input = input + 1 83 | input = self.l1(input) 84 | input = input + skip 85 | skip = input 86 | input += 1 87 | input = self.l2(input) 88 | input += skip 89 | return input, input.sum() 90 | 91 | input = randn(2**6, 2**10, requires_grad=True) 92 | model = MLPLayer(2**10) 93 | us_model = unit_scale(model) 94 | output, loss = us_model(input) 95 | loss.backward() 96 | 97 | expected_logs = [ 98 | r"unit scaling function: (input_2)\n", 99 | r"unit scaling function: (input_4)\n", 100 | r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)", 101 | r"unit scaling function: (add_1|input_6) \(residual-add, tau=0\.5\)", 102 | ] 103 | 104 | for log_msg in expected_logs: 105 | assert re.search(log_msg, caplog.text) 106 | 107 | 108 | def test_fp8_unit_scaling(caplog: LogCaptureFixture) -> None: 109 | caplog.set_level(logging.INFO) 110 | 111 | class Model(nn.Module): 112 | def __init__(self, d_in: int, d_out: int) -> None: 113 | super().__init__() 114 | self.linear = nn.Linear(d_in, d_out) 115 | 116 | def forward(self, t: Tensor) -> Tensor: # pragma: no cover 117 | return self.linear(t) # type: ignore[no-any-return] 118 | 119 | input = randn(2**8, 2**9) 120 | model = Model(2**9, 2**10) 121 | model = simulate_fp8(model) 122 | model = unit_scale(model) 123 | 
model(input) 124 | 125 | expected_logs = [ 126 | "moving unit scaling backend to precede quantisation backend", 127 | "running unit scaling backend", 128 | "running quantisation backend", 129 | ] 130 | for log_msg in expected_logs: 131 | assert log_msg in caplog.text 132 | -------------------------------------------------------------------------------- /unit_scaling/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | """Useful torch dynamo transforms of modules for the sake of numerics and unit 4 | scaling.""" 5 | 6 | # This all has to be done manually to keep mypy happy. 7 | # Removing the `--no-implicit-reexport` option ought to fix this, but doesn't appear to. 8 | 9 | from ._compile import compile 10 | from ._simulate_format import simulate_format, simulate_fp8 11 | from ._track_scales import ( 12 | Metrics, 13 | prune_non_float_tensors, 14 | prune_same_scale_tensors, 15 | prune_selected_nodes, 16 | track_scales, 17 | ) 18 | from ._unit_scale import unit_scale 19 | 20 | __all__ = [ 21 | "Metrics", 22 | "compile", 23 | "prune_non_float_tensors", 24 | "prune_same_scale_tensors", 25 | "prune_selected_nodes", 26 | "simulate_format", 27 | "simulate_fp8", 28 | "track_scales", 29 | "unit_scale", 30 | ] 31 | -------------------------------------------------------------------------------- /unit_scaling/transforms/_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from typing import TypeVar 4 | 5 | from torch import _TorchCompileInductorWrapper, nn 6 | 7 | from .utils import apply_transform 8 | 9 | M = TypeVar("M", bound=nn.Module) 10 | 11 | 12 | def compile(module: M) -> M: 13 | """A transform that applies torch.compile to a module. 14 | 15 | Note that this is slightly different to calling :code:`torch.compile(module)`. 16 | 17 | The current version of :func:`torch.compile` doesn't allow for nested transforms, so 18 | the following is not supported: 19 | 20 | .. code-block:: python 21 | 22 | import torch 23 | 24 | from unit_scaling.transforms import unit_scale 25 | 26 | module = torch.compile(unit_scale(module)) 27 | 28 | :mod:`unit_scaling.transforms` addresses this by introducing a range of composable 29 | transforms. This works by moving the call to 30 | :func:`torch._dynamo.optimize` within the :code:`forward()` method of the module 31 | and only executing it on the first call to the module, or if a new transform 32 | is applied, the optimised call being cached thereafter. 33 | 34 | The :func:`unit_scaling.transforms.compile` function is one such composable 35 | transform. This means that the following can be written: 36 | 37 | .. code-block:: python 38 | 39 | from unit_scaling.transforms import compile, unit_scale 40 | 41 | module = compile(unit_scale(module)) 42 | 43 | This will successfully combine the two transforms in a single module. Note that 44 | the call to compile must still come last, as its underlying backend returns a 45 | standard :class:`torch.nn.Module` rather than a :class:`torch.fx.GraphModule`. 46 | 47 | Currently :func:`unit_scaling.transforms.compile` does not support the ops needed 48 | for the :func:`unit_scaling.transforms.simulate_fp8` transform, though this may 49 | change in future PyTorch releases. 50 | 51 | Modules implemented manually with unit-scaled layers (i.e. 
without the global 52 | :code:`unit_scale(module)` transform) can still use :func:`torch.compile` in the 53 | standard way. 54 | 55 | Args: 56 | module (M): the module to be compiled. 57 | 58 | Returns: 59 | M: the compiled module. 60 | """ 61 | return apply_transform( # type: ignore[no-any-return] 62 | module, _TorchCompileInductorWrapper("default", None, None) # type: ignore 63 | ) 64 | -------------------------------------------------------------------------------- /unit_scaling/transforms/_simulate_format.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import logging 3 | from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar 4 | 5 | import torch.nn.functional as F 6 | from torch import Tensor, nn 7 | from torch.fx.graph import Graph 8 | from torch.fx.graph_module import GraphModule 9 | from torch.fx.node import Node 10 | 11 | from .. import functional as U 12 | from .._internal_utils import generate__all__ 13 | from ..formats import FPFormat, format_to_tuple, tuple_to_format 14 | from .utils import Backend, apply_transform, replace_node_with_function 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | # These functions currently have to be defined explicitly to make PyTorch happy 20 | # Creating temporary wrapped functions doesn't work... 21 | def _quantised_linear( 22 | input: Tensor, 23 | weight: Tensor, 24 | bias: Optional[Tensor], 25 | fwd_format_tuple: Tuple[int, int], 26 | bwd_format_tuple: Tuple[int, int], 27 | ) -> Tensor: 28 | fwd_format = tuple_to_format(fwd_format_tuple) 29 | bwd_format = tuple_to_format(bwd_format_tuple) 30 | input = fwd_format.quantise_fwd(input) 31 | weight = fwd_format.quantise_fwd(weight) 32 | output = F.linear(input, weight, bias) 33 | return bwd_format.quantise_bwd(output) 34 | 35 | 36 | def _quantised_u_linear( 37 | input: Tensor, 38 | weight: Tensor, 39 | bias: Optional[Tensor], 40 | fwd_format_tuple: Tuple[int, int], 41 | bwd_format_tuple: Tuple[int, int], 42 | constraint: Optional[str] = "to_output_scale", 43 | ) -> Tensor: 44 | fwd_format = tuple_to_format(fwd_format_tuple) 45 | bwd_format = tuple_to_format(bwd_format_tuple) 46 | input, weight = (fwd_format.quantise_fwd(t) for t in (input, weight)) 47 | output = U.linear(input, weight, bias, constraint) 48 | return bwd_format.quantise_bwd(output) 49 | 50 | 51 | def _quantised_scaled_dot_product_attention( 52 | query: Tensor, 53 | key: Tensor, 54 | value: Tensor, 55 | fwd_format_tuple: Tuple[int, int], 56 | bwd_format_tuple: Tuple[int, int], 57 | **kwargs: Any, 58 | ) -> Tensor: 59 | fwd_format = tuple_to_format(fwd_format_tuple) 60 | bwd_format = tuple_to_format(bwd_format_tuple) 61 | query, key, value = (fwd_format.quantise_fwd(t) for t in (query, key, value)) 62 | output = F.scaled_dot_product_attention(query, key, value, **kwargs) 63 | return bwd_format.quantise_bwd(output) 64 | 65 | 66 | def _quantised_u_scaled_dot_product_attention( 67 | query: Tensor, 68 | key: Tensor, 69 | value: Tensor, 70 | fwd_format_tuple: Tuple[int, int], 71 | bwd_format_tuple: Tuple[int, int], 72 | **kwargs: Any, 73 | ) -> Tensor: 74 | fwd_format = tuple_to_format(fwd_format_tuple) 75 | bwd_format = tuple_to_format(bwd_format_tuple) 76 | query, key, value = (fwd_format.quantise_fwd(t) for t in (query, key, value)) 77 | output = U.scaled_dot_product_attention(query, key, value, **kwargs) 78 | return bwd_format.quantise_bwd(output) 79 | 80 | 81 | _replacement_map: Dict[Callable[..., Any], 
Callable[..., Any]] = { 82 | F.linear: _quantised_linear, 83 | U.linear: _quantised_u_linear, 84 | F.scaled_dot_product_attention: _quantised_scaled_dot_product_attention, 85 | U.scaled_dot_product_attention: _quantised_u_scaled_dot_product_attention, 86 | } 87 | 88 | 89 | def _replace_with_quantised( 90 | graph: Graph, 91 | node: Node, 92 | fwd_format: FPFormat, 93 | bwd_format: FPFormat, 94 | ) -> None: 95 | # Ideally we'd pass the formats as kwargs, but it currently causes a torch fx bug. 96 | # This workaround will suffice for now... 97 | args = [*node.args] 98 | if len(node.args) == 2: # pragma: no cover 99 | args.append(None) 100 | # Breaks when I pass in FPFormat objects, so convert to tuple and back 101 | args = ( 102 | args[:3] + [format_to_tuple(fwd_format), format_to_tuple(bwd_format)] + args[3:] 103 | ) 104 | 105 | assert callable(node.target) 106 | quantised_fn = _replacement_map[node.target] 107 | logger.info("quantising function: %s", node) 108 | replace_node_with_function(graph, node, quantised_fn, args=tuple(args)) 109 | 110 | 111 | def _quantisation_backend(fwd_format: FPFormat, bwd_format: FPFormat) -> Backend: 112 | def backend_fn(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule: 113 | logger.info("running quantisation backend") 114 | graph = gm.graph 115 | for node in graph.nodes: 116 | if node.op == "call_function" and node.target in _replacement_map: 117 | _replace_with_quantised(graph, node, fwd_format, bwd_format) 118 | graph.lint() # type: ignore[no-untyped-call] 119 | return GraphModule(gm, graph) 120 | 121 | return backend_fn 122 | 123 | 124 | M = TypeVar("M", bound=nn.Module) 125 | 126 | 127 | def simulate_format(module: M, fwd_format: FPFormat, bwd_format: FPFormat) -> M: 128 | """**[Experimental]** Given a module, uses TorchDynamo to return a new module which 129 | simulates the effect of using the supplied formats for matmuls. 130 | 131 | Specifically, before each :func:`torch.nn.functional.linear` and 132 | :func:`torch.nn.functional.scaled_dot_product_attention` call, a quantisation op 133 | is inserted which simulates the effect of using the supplied `fwd_format`. This op 134 | reduces the range of values to that of the given format, and (stochastically) rounds 135 | values to only those representable by the format. 136 | 137 | The same is true for the backward pass, where an op is inserted to quantise to the 138 | `bwd_format`. Models which use modules that contain these functions internally 139 | (such as :class:`torch.nn.Linear`) will be inspected by TorchDynamo and have the 140 | correct quantisation ops inserted. 141 | 142 | If the equivalent unit-scaled functions from :mod:`unit_scaling.functional` are 143 | used in the module, these too will be quantised. 144 | 145 | Simulation of formats is run in FP32. Users should not expect speedups from using 146 | this method. The purpose is to simulate the numerical effects of running matmuls 147 | in various formats. 148 | 149 | Args: 150 | module (nn.Module): the module to be quantised 151 | fwd_format (FPFormat): the quantisation format to be used in the forward pass 152 | (activations and weights) 153 | bwd_format (FPFormat): the quantisation format to be used in the backward pass 154 | (gradients of activations and weights) 155 | 156 | Returns: 157 | nn.Module: a new module which when used, will run using the simulated formats. 
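
    For example, a minimal sketch (assuming :code:`model` and :code:`inputs` are
    defined by the user; the formats shown mirror those used by
    :func:`simulate_fp8`):

    .. code-block:: python

        from unit_scaling.formats import FPFormat
        from unit_scaling.transforms import simulate_format

        quantised_model = simulate_format(
            model, fwd_format=FPFormat(4, 3), bwd_format=FPFormat(5, 2)
        )
        output = quantised_model(inputs)  # matmul inputs quantised to fwd_format
        output.sum().backward()  # corresponding gradients quantised to bwd_format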
158 | """ 159 | return apply_transform( # type: ignore[no-any-return] 160 | module, _quantisation_backend(fwd_format, bwd_format) 161 | ) 162 | 163 | 164 | def simulate_fp8(module: M) -> M: 165 | """**[Experimental]** Given a module, uses TorchDynamo to return a new module which 166 | simulates the effect of running matmuls in FP8. As is standard in the literature 167 | (Noune et al., 2022; Micikevicius et al., 2022), we use the FP8 E4 format in the 168 | forwards pass, and FP8 E5 in the backward pass. 169 | 170 | Specifically, before each :func:`torch.nn.functional.linear` and 171 | :func:`torch.nn.functional.scaled_dot_product_attention` call, a quantisation op 172 | is inserted which simulates the effect of using FP8. This op 173 | reduces the range of values to that of the format, and (stochastically) rounds 174 | values to only those representable by the format. 175 | 176 | The same is true for the backward pass. 177 | Models which use modules that contain these functions internally 178 | (such as :class:`torch.nn.Linear`) will be inspected by TorchDynamo and have the 179 | correct quantisation ops inserted. 180 | 181 | If the equivalent unit-scaled functions from :mod:`unit_scaling.functional` are 182 | used in the module, these too will be quantised. 183 | 184 | Simulation of formats is run in FP32. Users should not expect speedups from using 185 | this method. The purpose is to simulate the numerical effects of running matmuls 186 | in FP8. 187 | 188 | Args: 189 | module (nn.Module): the module to be quantised 190 | 191 | Returns: 192 | nn.Module: a new module which when used, will run with matmul inputs in FP8. 193 | """ 194 | return simulate_format(module, fwd_format=FPFormat(4, 3), bwd_format=FPFormat(5, 2)) 195 | 196 | 197 | __all__ = generate__all__(__name__) 198 | -------------------------------------------------------------------------------- /unit_scaling/transforms/_unit_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | import logging 4 | from inspect import signature 5 | from operator import getitem 6 | from types import BuiltinFunctionType 7 | from typing import Any, Callable, Dict, List, Set, TypeVar 8 | 9 | import torch 10 | import torch._dynamo 11 | import torch.nn.functional as F 12 | from torch import Tensor, nn 13 | from torch.fx.graph import Graph 14 | from torch.fx.graph_module import GraphModule 15 | from torch.fx.node import Node 16 | 17 | from .. 
import functional as U 18 | from .._internal_utils import generate__all__ 19 | from .utils import Backend, apply_transform, replace_node_with_function 20 | 21 | T = TypeVar("T") 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def _add_dependency_meta(graph: Graph, recalculate: bool = False) -> None: 26 | def recurse(n: Node) -> Set[Node]: 27 | if "dependencies" in n.meta: 28 | return n.meta["dependencies"] # type: ignore[no-any-return] 29 | deps = set(n.all_input_nodes) 30 | for parent in n.all_input_nodes: 31 | deps.update(recurse(parent)) 32 | n.meta["dependencies"] = deps 33 | return deps 34 | 35 | if recalculate: 36 | for n in graph.nodes: 37 | if "dependencies" in n.meta: 38 | del n.meta["dependencies"] 39 | 40 | for n in graph.nodes: 41 | if n.op == "output": 42 | recurse(n) 43 | 44 | 45 | def _is_add(n: Node) -> bool: 46 | return ( 47 | n.op == "call_function" 48 | and isinstance(n.target, BuiltinFunctionType) 49 | and n.target.__name__ in ["add", "iadd"] 50 | ) 51 | 52 | 53 | def _is_self_attention(skip_node: Node, residual_node: Node) -> bool: 54 | # Identify all operations on the residual branch 55 | residual_fns = {residual_node.target} 56 | parents = [*residual_node.all_input_nodes] 57 | while parents: 58 | p = parents.pop() 59 | if p == skip_node: 60 | continue 61 | residual_fns.add(p.target) 62 | parents += p.all_input_nodes 63 | # Check if any of them are self-attention ops (softmax or fused self-attention) 64 | self_attention_fns = { 65 | F.scaled_dot_product_attention, 66 | U.scaled_dot_product_attention, 67 | F.softmax, 68 | U.softmax, 69 | } 70 | return bool(residual_fns.intersection(self_attention_fns)) 71 | 72 | 73 | def _unit_scale_residual( 74 | graph: Graph, 75 | add: Node, 76 | residual_arg_idx: int, 77 | is_self_attention: bool, 78 | ) -> None: 79 | residual, skip = add.args[residual_arg_idx], add.args[1 - residual_arg_idx] 80 | tau = 0.01 if is_self_attention else 0.5 81 | logger.info("unit scaling function: %s (residual-add, tau=%s)", add, tau) 82 | old_start_residuals = [ 83 | u for u in skip.users if u is not add # type: ignore[union-attr] 84 | ] 85 | with graph.inserting_after(skip): # type: ignore[arg-type] 86 | split = graph.call_function( 87 | U.residual_split, 88 | args=(skip, tau), 89 | type_expr=getattr(skip, "type", None), 90 | ) 91 | with graph.inserting_after(split): 92 | new_start_residual = graph.call_function(getitem, args=(split, 0)) 93 | for old_start_residual in old_start_residuals: 94 | old_start_residual.replace_input_with(skip, new_start_residual) # type: ignore 95 | with graph.inserting_after(split): 96 | new_skip = graph.call_function(getitem, args=(split, 1)) 97 | replace_node_with_function( 98 | graph, add, U.residual_add, args=(residual, new_skip, tau) 99 | ) 100 | 101 | 102 | def _unconstrain_node(node: Node) -> None: 103 | if ( 104 | node.op == "call_function" 105 | and callable(node.target) 106 | and not isinstance(node.target, BuiltinFunctionType) 107 | and "constraint" in signature(node.target).parameters 108 | ): 109 | logger.info("unconstraining node: %s", node) 110 | node.kwargs = dict(node.kwargs, constraint=None) 111 | 112 | 113 | def unit_scaling_backend( 114 | replacement_map: Dict[Callable[..., Any], Callable[..., Any]] = dict() 115 | ) -> Backend: 116 | def inner_backend(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule: 117 | logger.info("running unit scaling backend") 118 | graph = gm.graph 119 | # Replace function nodes with those in `replacement_map` or with their 120 | # unit scaled equivalents 
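# (entries in the user-supplied `replacement_map` take priority; any remaining
# torch functions with an entry in `U.torch_map` are then swapped for their
# unit-scaled counterparts)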
121 | for node in graph.nodes: 122 | if node.op == "call_function": 123 | if node.target in replacement_map: 124 | target_fn = replacement_map[node.target] 125 | logger.info( 126 | "replacing function: %s with %s", node, target_fn.__name__ 127 | ) 128 | replace_node_with_function(graph, node, target_fn) 129 | elif node.target in U.torch_map: 130 | target_fn = U.torch_map[node.target] 131 | logger.info("unit scaling function: %s", node) 132 | replace_node_with_function(graph, node, target_fn) 133 | 134 | # Add metadata denoting the dependencies of every node in the graph 135 | _add_dependency_meta(graph) 136 | 137 | # Go through and mark nodes which represent residual-adds 138 | residual_layer_number = 1 139 | for node in graph.nodes: 140 | if _is_add(node): 141 | is_residual_add = False 142 | if len(node.args) == 2: 143 | l, r = node.args 144 | if isinstance(l, Node) and isinstance(r, Node): 145 | l_deps = l.meta.get("dependencies", list()) 146 | r_deps = r.meta.get("dependencies", list()) 147 | if l in r_deps or r in l_deps: 148 | node.meta["residual_add"] = { 149 | "residual_arg_idx": 1 if l in r_deps else 0, 150 | } 151 | residual_layer_number += 1 152 | is_residual_add = True 153 | skip_node, residual_node = (l, r) if l in r_deps else (r, l) 154 | is_sa = _is_self_attention(skip_node, residual_node) 155 | node.meta["residual_add"]["is_self_attention"] = is_sa 156 | # Regular adds are not picked up by the unit scaling sweep above as 157 | # the inbuilt + operation is handled differently when traced. It is 158 | # instead substituted for its unit scaled equivalent here. 159 | if not is_residual_add: 160 | logger.info("unit scaling function: %s", node) 161 | args = (*node.args, None) # None denotes unconstrained 162 | replace_node_with_function(graph, node, U.add, args=args) 163 | 164 | # Replace nodes marked as residual-adds with unit scaled equivalent 165 | for node in graph.nodes: 166 | residual_add = node.meta.get("residual_add", None) 167 | if residual_add is not None: 168 | _unit_scale_residual(graph, node, **residual_add) 169 | 170 | _add_dependency_meta(graph, recalculate=True) 171 | for node in graph.nodes: 172 | if node.target == U.residual_add: 173 | node.meta["has_residual_successor"] = True 174 | dependencies = node.meta.get("dependencies", []) 175 | for d in dependencies: 176 | d.meta["has_residual_successor"] = True 177 | 178 | for node in graph.nodes: 179 | if "has_residual_successor" not in node.meta: 180 | _unconstrain_node(node) 181 | 182 | graph.lint() # type: ignore[no-untyped-call] 183 | return GraphModule(gm, graph) 184 | 185 | return inner_backend 186 | 187 | 188 | def _unit_init_weights(m: nn.Module) -> None: 189 | for name, mod in m.named_modules(): 190 | if isinstance(mod, (nn.Linear, nn.Embedding)): 191 | with torch.no_grad(): 192 | if isinstance(mod.weight, Tensor): 193 | logger.info("unit scaling weight: %s.weight", name) 194 | mod.weight /= mod.weight.std() 195 | 196 | 197 | def _zero_init_biases(m: nn.Module) -> None: 198 | for name, mod in m.named_modules(): 199 | if isinstance(mod, (nn.Linear, nn.Embedding)): 200 | with torch.no_grad(): 201 | if hasattr(mod, "bias") and isinstance(mod.bias, Tensor): 202 | logger.info("setting bias to zero: %s.bias", name) 203 | mod.bias -= mod.bias 204 | 205 | 206 | # Unit scaling should come before quantisation 207 | def _order_backends(backends: List[Backend]) -> None: 208 | unit_scaling_backend_idx = -1 209 | quantisation_backend_idx = float("inf") 210 | for i, b in enumerate(backends): 211 | if "unit_scaling_backend" 
in b.__qualname__: 212 | unit_scaling_backend_idx = i 213 | if "quantisation_backend" in b.__qualname__: 214 | quantisation_backend_idx = i 215 | if unit_scaling_backend_idx > quantisation_backend_idx: 216 | logger.info("moving unit scaling backend to precede quantisation backend") 217 | u = backends.pop(unit_scaling_backend_idx) 218 | backends.insert(quantisation_backend_idx, u) # type: ignore[arg-type] 219 | 220 | 221 | M = TypeVar("M", bound=nn.Module) 222 | 223 | 224 | def unit_scale( 225 | module: M, replace: Dict[Callable[..., Any], Callable[..., Any]] = dict() 226 | ) -> M: 227 | """**[Experimental]** Returns a unit-scaled version of the input model. 228 | 229 | Uses TorchDynamo to trace and transform the user-supplied module. 230 | This transformation identifies all :class:`torch.nn.functional` uses in the input 231 | module, and replaces them with their unit-scaled equivalents, should they exist. 232 | 233 | The tracing procedure automatically recurses into modules 234 | (whether defined in libraries, or by the user), identifying inner calls to any 235 | :class:`torch.nn.functional` operations, to build a graph of fundamental operations. 236 | Unit scaling is then applied as a transformation on top of this graph. 237 | 238 | This transformation proceeds in five stages: 239 | 240 | #. **Replacement of user-defined functions** according to the supplied `replace` 241 | dictionary. 242 | #. **Replacement of all functions with unit-scaled equivalents** defined in 243 | :mod:`unit_scaling.functional`. 244 | #. Identification & **replacement of all add operations that represent 245 | residual-adds**. The identification of residual connections is done via a 246 | dependency analysis on the graph. Residual-adds require special scaling compared 247 | with regular adds (see paper / User Guide for details). 248 | #. **Unconstraining of all operations after the final residual layer**. By default 249 | all unit scaled operations have their scaling factors constrained in the forward 250 | and backward pass to give valid gradients. This is not required in these final 251 | layers (see paper for proof), and hence we can unconstrain the operations to give 252 | better scaling. 253 | #. **Unit-scaling of all weights** and zero-initialisation of all biases. 254 | 255 | Note that by using TorchDynamo, `unit_scale()` is able to trace a much larger set of 256 | modules / operations than with previous PyTorch tracing approaches. This enables 257 | the process of unit scaling to be expressed as a generic graph transform that can be 258 | applied to arbitrary modules. 259 | 260 | Note that the current version of TorchDynamo (or :func:`torch.compile`, which is a 261 | wrapper around TorchDynamo) doesn't support nested transforms, so we implement our 262 | own system here. This makes it easy to nest transforms: 263 | 264 | .. code-block:: python 265 | 266 | from unit_scaling.transforms import compile, simulate_fp8, unit_scale 267 | 268 | module = compile(simulate_fp8(unit_scale(module))) 269 | 270 | However, these transforms are not interoperable with the standard 271 | :func:`torch.compile` interface. 272 | 273 | In some cases users may have a model definition that uses a custom implementation of 274 | a basic operation. In this case, `unit_scale()` can be told explicitly to substitute 275 | the layer for an equivalent, using the `replace` dictionary: 276 | 277 | .. 
code-block:: python
278 |
279 | import unit_scaling.functional as U
280 | from unit_scaling.transforms import unit_scale
281 |
282 | def new_gelu(x):
283 | ...
284 |
285 | class Model(nn.Module):
286 | def forward(self, x):
287 | ...
288 | x = new_gelu(x)
289 | ...
290 |
291 | model = unit_scale(Model(), replace={new_gelu: U.gelu})
292 |
293 | This can also be used to substitute a particular function for a user-defined
294 | unit-scaled function not provided by :mod:`unit_scaling.functional`.
295 |
296 | **Note:** `unit_scale()` is experimental and has not yet been widely tested on a
297 | range of models. The standard approach to unit scaling a model is still to
298 | manually substitute the layers/operations in a model with their unit-scaled
299 | equivalents. Having said this, `unit_scale()` is implemented in a sufficiently
300 | generic way that we anticipate many users will ultimately be able to rely on this
301 | graph transform alone.
302 |
303 | Args:
304 | module (nn.Module): the input module to be unit scaled.
305 | replace (Dict[Callable, Callable], optional): a dictionary where keys represent
306 | functions to be replaced by the corresponding value-functions. Note that
307 | these substitutions take priority over the standard unit scaling
308 | substitutions. Defaults to dict().
309 |
310 | Returns:
311 | nn.Module: the unit scaled module (with an independent copy of parameters).
312 | """
313 | unit_scaled_module = apply_transform(
314 | module,
315 | unit_scaling_backend(replace),
316 | non_recurse_functions=list(replace.keys()),
317 | )
318 | _order_backends(unit_scaled_module.backends)
319 | _unit_init_weights(unit_scaled_module)
320 | _zero_init_biases(unit_scaled_module)
321 | return unit_scaled_module # type: ignore[no-any-return]
322 |
323 |
324 | __all__ = generate__all__(__name__)
325 |
--------------------------------------------------------------------------------
/unit_scaling/transforms/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Utilities for working with transforms."""
4 |
5 | import copy
6 | import functools
7 | from typing import (
8 | Any,
9 | Callable,
10 | Dict,
11 | Iterable,
12 | List,
13 | Optional,
14 | Tuple,
15 | TypeVar,
16 | no_type_check,
17 | )
18 | from unittest.mock import patch
19 |
20 | import torch
21 | import torch._dynamo
22 | from torch import Tensor, nn
23 | from torch.fx.graph import Graph
24 | from torch.fx.graph_module import GraphModule
25 | from torch.fx.node import Node
26 |
27 | from .. import functional as U
28 | from .._internal_utils import generate__all__
29 |
30 | T = TypeVar("T")
31 |
32 | Backend = Callable[[GraphModule, List[Tensor]], Callable[..., Any]]
33 |
34 | _unit_scaled_functions = [getattr(U, f) for f in U.__all__]
35 |
36 |
37 | def torch_nn_modules_to_user_modules(mod: nn.Module) -> None:
38 | """
39 | Convert `torch.nn` module classes to `trivial_subclass` versions.
40 |
41 | By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
42 | :mod:`torch.nn.functional` functions when capturing the FX graph.
43 |
44 | This function makes `torch.nn` modules into user modules.
45 |
46 | To use this with a :class:`torch.nn.Module`, the typical use case is to call
47 | `torch_nn_modules_to_user_modules(module)`; the module is converted in place (the function returns `None`).
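A minimal sketch of typical use (the `nn.Sequential` model is just an illustrative
stand-in for any user module; the printed class name follows the renaming scheme
implemented below):

.. code-block:: python

    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    torch_nn_modules_to_user_modules(model)  # converts submodules in place
    print(type(model[0]).__name__)  # e.g. "trivial_subclass_modules_linear_Linear"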
48 | """ 49 | 50 | for n, submod in mod.named_children(): 51 | torch_nn_modules_to_user_modules(submod) 52 | 53 | # Mirroring the check at https://github.com/pytorch/pytorch/blob/34bce27f0d12bf7226b37dfe365660aad456701a/torch/_dynamo/variables/nn_module.py#L307 # noqa: E501 54 | if submod.__module__.startswith(("torch.nn.", "torch.ao.")): 55 | # Generate a new name, so e.g. torch.nn.modules.sparse.Embedding 56 | # becomes trivial_subclass_modules_sparse_Embedding 57 | modulename = submod.__module__ 58 | modulename = modulename.replace("torch.nn.", "", 1) 59 | modulename = modulename.replace(".", "_") 60 | newtypename = "trivial_subclass_" + modulename + "_" + type(submod).__name__ 61 | 62 | # Create a new type object deriving from type(submod) 63 | newmodtype = type(newtypename, (type(submod),), {}) 64 | 65 | # Initialize and copy state using pickle 66 | newsubmod = newmodtype.__new__(newmodtype) # type: ignore [call-overload] 67 | state = submod.__getstate__() # type: ignore [no-untyped-call] 68 | newsubmod.__setstate__(state) 69 | 70 | # Update module in mod 71 | setattr(mod, n, newsubmod) 72 | 73 | 74 | def patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]: 75 | """By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or 76 | :mod:`torch.nn.functional` functions when capturing the FX graph. 77 | Any function which is wrapped in 78 | :func:`torch._dynamo.optimise` (or :func:`torch.compile`) and is then passed to 79 | this function as `fn` will now automatically recurse into 80 | :mod:`torch.nn` modules or :mod:`torch.nn.functional` functions. 81 | 82 | In practice, to use this with a :class:`torch.nn.Module` the typical use case 83 | is to call `module = torch._dynamo.optimize(backend)(module)`, followed by 84 | `module.forward = patch_to_expand_modules(module.forward)`. 85 | 86 | This should be used in conjunction with :func:`torch_nn_modules_to_user_modules` 87 | 88 | Args: 89 | fn (Callable[..., T]): the function to be patched. 90 | 91 | Returns: 92 | Callable[..., T]: the new version of `fn` with patching applied. 93 | """ 94 | 95 | def _patched_call_function( # type: ignore[no-untyped-def] 96 | self, 97 | tx, 98 | args, 99 | kwargs, 100 | ): # pragma: no cover 101 | # Removing the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501 102 | return super( 103 | torch._dynamo.variables.functions.UserMethodVariable, self 104 | ).call_function(tx, args, kwargs) 105 | 106 | @functools.wraps(fn) 107 | def new_fn(*args: Any, **kwargs: Any) -> Any: 108 | with patch( 109 | "torch._dynamo.variables.functions.UserMethodVariable.call_function", 110 | new=_patched_call_function, 111 | ): 112 | return fn(*args, **kwargs) 113 | 114 | return new_fn 115 | 116 | 117 | def replace_node_with_function( 118 | graph: Graph, 119 | source: Node, 120 | target_fn: Callable[..., Any], 121 | args: Optional[Tuple[Any, ...]] = None, 122 | kwargs: Optional[Dict[Any, Any]] = None, 123 | keep_type_expr: bool = True, 124 | ) -> None: 125 | """Given a source node and its accompanying graph, remove the node and replace it 126 | with a new node that represents calling the target function. 127 | 128 | Args: 129 | graph (Graph): the graph in which the node is present. 130 | source (Node): the node to be replaced. 131 | target_fn (Callable[..., Any]): the function to be contained in the new node. 132 | args (Optional[Tuple[Any, ...]], optional): args of the new node. 133 | Defaults to None. 
134 | kwargs (Optional[Dict[Any, Any]], optional): kwargs of the new node.
135 | Defaults to None.
136 | keep_type_expr (bool, optional): retain the type expression of the removed node.
137 | Defaults to True.
138 | """
139 | if args is None:
140 | args = source.args
141 | if kwargs is None:
142 | kwargs = source.kwargs
143 | type_expr = getattr(source, "type", None) if keep_type_expr else None
144 | with graph.inserting_after(source):
145 | new_node = graph.call_function(target_fn, args, kwargs, type_expr)
146 | source.replace_all_uses_with(new_node)
147 | graph.erase_node(source)
148 |
149 |
150 | def _compose_backends(backends: Iterable[Backend]) -> Backend:
151 | def composite_backend(
152 | gm: GraphModule, example_inputs: List[Tensor]
153 | ) -> Callable[..., Any]:
154 | for b in backends:
155 | new_gm = b(gm, example_inputs)
156 | new_gm._param_name_to_source = getattr( # type: ignore
157 | gm,
158 | "_param_name_to_source",
159 | None,
160 | )
161 | gm = new_gm # type: ignore[assignment]
162 | return gm
163 |
164 | return composite_backend
165 |
166 |
167 | M = TypeVar("M", bound=nn.Module)
168 |
169 |
170 | @no_type_check
171 | def apply_transform(
172 | module: M,
173 | backend: Backend,
174 | non_recurse_functions: List[Callable[..., Any]] = list(),
175 | ) -> M:
176 | """Applies a graph transformation to a module.
177 |
178 | The user-supplied :code:`backend` represents a transformation of a
179 | :class:`torch.fx.graph_module.GraphModule`. :code:`apply_transform()` uses
180 | :func:`torch._dynamo.optimize` to apply this transformation to the module,
181 | returning a new transformed module.
182 |
183 | Note that the current version of TorchDynamo (or :func:`torch.compile`, which is a
184 | wrapper around TorchDynamo) doesn't support nested transforms, so we implement our
185 | own system here. This makes it easy to nest transforms:
186 |
187 | .. code-block:: python
188 |
189 | module = apply_transform(apply_transform(module, backend_1), backend_2)
190 |
191 | However, it should be noted that these transforms are not interoperable with the standard
192 | :func:`torch.compile` interface.
193 |
194 | This nesting system is implemented by moving the call to
195 | :func:`torch._dynamo.optimize` within the :code:`forward()` method of the module
196 | (though it is only executed on the first call to the module, or if a new transform
197 | is applied, the optimised call being cached thereafter). This differs from the
198 | standard approach used with :func:`torch._dynamo.optimize`, but enables this
199 | convenient nesting functionality.
200 |
201 | Args:
202 | module (nn.Module): the module to be transformed.
203 | backend (Backend): the graph transformation to be applied.
204 | non_recurse_functions (List[Callable[..., Any]], optional): functions which
205 | the user does not wish to be recursed into. Defaults to list().
206 |
207 | Returns:
208 | nn.Module: the transformed module.
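A minimal sketch of a custom backend (the printing backend below is purely
illustrative; any callable matching the `Backend` alias defined in this module
can be passed):

.. code-block:: python

    from torch.fx.graph_module import GraphModule

    def printing_backend(gm: GraphModule, example_inputs):
        # inspect the captured graph, then hand it back unchanged
        for node in gm.graph.nodes:
            if node.op == "call_function":
                print(node.target)
        return gm

    module = apply_transform(module, printing_backend)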
209 | """ 210 | module = copy.deepcopy(module) 211 | 212 | torch_nn_modules_to_user_modules(module) 213 | 214 | if not hasattr(module, "backends"): 215 | module.backends = [] 216 | module.backends.append(backend) 217 | 218 | for v in non_recurse_functions: 219 | torch._dynamo.allow_in_graph(v) 220 | 221 | backend = _compose_backends(module.backends) 222 | 223 | def new_forward(*args: Any, **kwargs: Any) -> Any: 224 | if module.rerun_transform: 225 | torch._dynamo.reset() 226 | dynamo_module = torch._dynamo.optimize(backend)(module) 227 | module.dynamo_forward = patch_to_expand_modules(dynamo_module.forward) 228 | module.rerun_transform = False 229 | with patch.object(module, "forward", module.base_forward): 230 | return module.dynamo_forward(*args, **kwargs) 231 | 232 | module.rerun_transform = True 233 | module.base_forward = getattr(module, "base_forward", module.forward) 234 | module.forward = new_forward 235 | return module 236 | 237 | 238 | __all__ = generate__all__(__name__) 239 | -------------------------------------------------------------------------------- /unit_scaling/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | """Utility functions for developing unit-scaled models.""" 4 | 5 | import ast 6 | import math 7 | import re 8 | import typing 9 | from collections import OrderedDict 10 | from dataclasses import dataclass 11 | from types import FunctionType, ModuleType 12 | from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union, cast 13 | 14 | import einops 15 | import torch 16 | from pygments import highlight 17 | from pygments.formatters import TerminalFormatter 18 | from pygments.lexers import PythonLexer 19 | from torch import Tensor, fx, nn 20 | 21 | from . 
import functional 22 | from ._internal_utils import generate__all__ 23 | 24 | 25 | @dataclass 26 | class ScalePair: 27 | """Dataclass containing a pair of scalars, intended to represent the standard 28 | deviation of an arbitrary tensor in the forward and backward passes.""" 29 | 30 | forward: Optional[float] = None 31 | backward: Optional[float] = None 32 | 33 | def __str__(self) -> str: 34 | fwd = f"{self.forward:.3}" if self.forward is not None else "n/a" 35 | bwd = f"{self.backward:.3}" if self.backward is not None else "n/a" 36 | return f"(-> {fwd}, <- {bwd})" 37 | 38 | 39 | ScaleDict = typing.OrderedDict[str, ScalePair] 40 | 41 | 42 | class ScaleTracker(torch.autograd.Function): 43 | """Given a `nn.Tensor`, records its standard deviation in the forward and 44 | backward pass in the supplied `ScalePair`.""" 45 | 46 | @staticmethod 47 | def forward( 48 | ctx: torch.autograd.function.FunctionCtx, 49 | t: Tensor, 50 | scale_tracker: ScalePair, 51 | ) -> Tensor: 52 | scale_tracker.forward = float(t.std()) 53 | ctx.scale_tracker = scale_tracker # type: ignore 54 | return t 55 | 56 | @staticmethod 57 | def backward( # type:ignore[override] 58 | ctx: torch.autograd.function.FunctionCtx, t: Tensor 59 | ) -> Tuple[Tensor, None, None]: 60 | ctx.scale_tracker.backward = float(t.std()) # type: ignore 61 | return t, None, None 62 | 63 | @staticmethod 64 | def track(t: Tensor, scale_tracker: ScalePair) -> Tensor: 65 | # Add typing information to `apply()` method from `torch.autograd.Function` 66 | apply = cast(Callable[[Tensor, ScalePair], Tensor], ScaleTracker.apply) 67 | return apply(t, scale_tracker) 68 | 69 | 70 | class ScaleTrackingInterpreter(fx.Interpreter): 71 | """Wraps an `fx.GraphModule` such than when executed it records the standard 72 | deviation of every intermediate `nn.Tensor` in the forward and backward pass. 73 | 74 | Args: 75 | module (fx.GraphModule): the module to be instrumented. 76 | """ 77 | 78 | def __init__(self, module: fx.GraphModule): 79 | super().__init__(module) 80 | self.scales: typing.OrderedDict[str, ScalePair] = OrderedDict() 81 | 82 | def run_node(self, n: fx.Node) -> Any: 83 | out = super().run_node(n) 84 | if isinstance(out, Tensor) and out.is_floating_point(): 85 | scale_pair = ScalePair() 86 | out = ScaleTracker.track(out, scale_pair) 87 | self.scales[n.name] = scale_pair 88 | return out 89 | 90 | def call_function( 91 | self, target: fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] 92 | ) -> Any: 93 | return super().call_function(target, args, kwargs) 94 | 95 | def placeholder( 96 | self, 97 | target: fx.node.Target, 98 | args: Tuple[fx.node.Argument, ...], 99 | kwargs: Dict[str, Any], 100 | ) -> Any: 101 | """To handle functions being passed as arguments (for example constraints) the 102 | tracer represents them as placeholder nodes. 
This method extracts the original 103 | function from the node, as stored in the `target_to_function` dict.""" 104 | if isinstance(target, str) and target.startswith("function_placeholder__"): 105 | return self.module.graph._tracer_extras["target_to_function"][ 106 | target 107 | ] # pragma: no cover 108 | return super().placeholder(target, args, kwargs) 109 | 110 | 111 | def _record_scales( 112 | fx_graph_module: fx.GraphModule, 113 | inputs: Tuple[Tensor, ...], 114 | backward: Optional[Tensor] = None, 115 | ) -> ScaleDict: 116 | """Given a `torch.fx.GraphModule`, and dummy tensors to feed into the forward and 117 | backward passes, returns a dictionary of the scales (standard deviations) of every 118 | intermediate tensor in the model (forward and backward pass). 119 | 120 | Args: 121 | fx_graph_module (fx.GraphModule): the module to record. 122 | input (Tuple[Tensor, ...]): fed into the forward pass for analysis. 123 | backward (Tensor, optional): fed into the output's `.backward()` method for 124 | analysis. Defaults to `None`, equivalent to calling plain `.backward()`. 125 | 126 | Returns: 127 | ScaleDict: An ordered dictionary with `ScalePair`s for each intermediate tensor. 128 | """ 129 | tracking_module = ScaleTrackingInterpreter(fx_graph_module) 130 | out = tracking_module.run(*inputs) 131 | out.backward(backward) 132 | return tracking_module.scales 133 | 134 | 135 | def _annotate(code: str, scales: ScaleDict, syntax_highlight: bool) -> str: 136 | """Given a string representation of some code and an `ScaleDict` with accompanying 137 | scales, annotates the code to include the scales on the right-hand side.""" 138 | 139 | function_placeholder_regex = r"function_placeholder__(\w+)" 140 | 141 | def is_function_placeholder_line(code_line: str) -> bool: 142 | return bool(re.search(f" = {function_placeholder_regex}$", code_line)) 143 | 144 | def cleanup_function_signature(code_line: str) -> str: 145 | code_line = re.sub(f", {function_placeholder_regex}", "", code_line) 146 | inner_code_line = code_line.split("(", 1)[1] 147 | replacement = re.sub(r"_([a-zA-Z0-9_]+)_", r"\1", inner_code_line) 148 | return code_line.replace(inner_code_line, replacement) 149 | 150 | def annotate_line(code_line: str) -> str: 151 | if code_line.startswith("torch.fx._symbolic_trace.wrap"): 152 | return "" 153 | code_line = code_line.split(";")[0] 154 | if is_function_placeholder_line(code_line): # pragma: no cover 155 | return "" 156 | words = code_line.strip().split(" ") 157 | if words: 158 | if words[0] in scales: 159 | return f"{code_line}; {scales[words[0]]}" 160 | elif words[0] == "def": 161 | parsed = ast.parse(code_line + "\n\t...").body[0] 162 | assert isinstance(parsed, ast.FunctionDef) 163 | arg_names = [arg.arg for arg in parsed.args.args] 164 | scale_strs = [str(scales[a]) for a in arg_names if a in scales] 165 | code_line = cleanup_function_signature(code_line) 166 | if scale_strs: 167 | return f"{code_line} {', '.join(scale_strs)}" # pragma: no cover 168 | else: 169 | return code_line 170 | return code_line 171 | 172 | def remove_empty_lines(code_lines: Iterator[str]) -> Iterator[str]: 173 | return (line for line in code_lines if line.strip()) 174 | 175 | code_lines = map(annotate_line, code.splitlines()) 176 | code = "\n".join(remove_empty_lines(code_lines)).strip() 177 | code = code.replace("unit_scaling_functional_", "U.") 178 | if syntax_highlight: 179 | return highlight(code, PythonLexer(), TerminalFormatter()) # pragma: no cover 180 | return code 181 | 182 | 183 | class 
_DeepTracer(fx.Tracer):
184 | """Version of `torch.fx.Tracer` which recurses into all sub-modules (if specified).
185 |
186 | Args:
187 | recurse_modules (bool): toggles recursive behaviour. Defaults to True.
188 | autowrap_modules (Tuple[ModuleType]): defaults to
189 | `(math, einops, U.functional)`,
190 | Python modules whose functions should be wrapped automatically
191 | without needing to use fx.wrap().
192 | autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
193 | Python functions that should be wrapped automatically without
194 | needing to use fx.wrap().
195 | """
196 |
197 | def __init__(
198 | self,
199 | recurse_modules: bool = True,
200 | autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
201 | autowrap_functions: Tuple[Callable[..., Any], ...] = (),
202 | ) -> None:
203 | super().__init__(
204 | autowrap_modules=autowrap_modules, # type: ignore[arg-type]
205 | autowrap_functions=autowrap_functions,
206 | )
207 | self.recurse_modules = recurse_modules
208 | self.target_to_function: Dict[str, FunctionType] = {}
209 | self.function_to_node: Dict[FunctionType, fx.Node] = {}
210 | # Fixes: `TypeError: __annotations__ must be set to a dict object`
211 | if id(FunctionType) in self._autowrap_function_ids:
212 | self._autowrap_function_ids.remove(id(FunctionType))
213 |
214 | def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
215 | return not self.recurse_modules
216 |
217 | def create_arg(self, a: Any) -> fx.node.Argument:
218 | """Replaces callable arguments with placeholder nodes for tracing."""
219 | if isinstance(a, FunctionType): # pragma: no cover
220 | node = self.function_to_node.get(a)
221 | if node is None:
222 | assert hasattr(
223 | a, "__name__"
224 | ), f"can't create arg for unnamed function: {a}"
225 | name = getattr(a, "__name__")
226 | target = f"function_placeholder__{name}"
227 | node = self.create_node("placeholder", target, (), {}, name)
228 | self.target_to_function[target] = a
229 | self.function_to_node[a] = node
230 | return node
231 | return super().create_arg(a)
232 |
233 | def trace(
234 | self,
235 | root: Union[torch.nn.Module, Callable[..., Any]],
236 | concrete_args: Optional[Dict[str, Any]] = None,
237 | ) -> fx.Graph:
238 | """Adds the `target_to_function` dict to the graph so the interpreter can use it
239 | downstream."""
240 | graph = super().trace(root, concrete_args)
241 | if not hasattr(graph, "_tracer_extras") or graph._tracer_extras is None:
242 | graph._tracer_extras = {}
243 | graph._tracer_extras["target_to_function"] = self.target_to_function
244 | return graph
245 |
246 |
247 | def analyse_module(
248 | module: nn.Module,
249 | inputs: Union[Tensor, Tuple[Tensor, ...]],
250 | backward: Optional[Tensor] = None,
251 | recurse_modules: bool = True,
252 | syntax_highlight: bool = True,
253 | autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
254 | autowrap_functions: Tuple[Callable[..., Any], ...] = (),
255 | ) -> str:
256 | """Given a `nn.Module` and dummy forward and backward tensors, generates code
257 | representing the module annotated with the scales (standard deviation) of each
258 | tensor in both forward and backward passes. Implemented using `torch.fx`.
259 |
260 | Args:
261 | module (nn.Module): the module to analyse.
262 | inputs (Union[Tensor, Tuple[Tensor, ...]]): fed into the forward pass for
263 | analysis.
264 | backward (Tensor, optional): fed into the output's `.backward()` method for
265 | analysis. Defaults to `None`, equivalent to calling plain `.backward()`.
266 | recurse_modules (bool, optional): toggles recursive behaviour. Defaults to True.
267 | syntax_highlight (bool, optional): if True, the returned code is syntax-highlighted. Defaults to True.
268 | autowrap_modules (Tuple[ModuleType]): defaults to
269 | `(math, einops, U.functional)`,
270 | Python modules whose functions should be wrapped automatically
271 | without needing to use fx.wrap().
272 | autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
273 | Python functions that should be wrapped automatically without
274 | needing to use fx.wrap().
275 |
276 | Returns:
277 | str:
278 | a code string representing the operations in the module with scale
279 | annotations for each tensor, reflecting their standard deviations in the
280 | forward and backward passes.
281 |
282 | Examples::
283 |
284 | >>> class MLP(nn.Module):
285 | >>> def __init__(self, d):
286 | >>> super().__init__()
287 | >>> self.fc1 = nn.Linear(d, d * 4)
288 | >>> self.relu = nn.ReLU()
289 | >>> self.fc2 = nn.Linear(d * 4, d)
290 |
291 | >>> def forward(self, x):
292 | >>> x = self.fc1(x)
293 | >>> x = self.relu(x)
294 | >>> x = self.fc2(x)
295 | >>> return x
296 |
297 |
298 | >>> hidden_size = 2**10
299 | >>> x = torch.randn(hidden_size, hidden_size).requires_grad_()
300 | >>> bwd = torch.randn(hidden_size, hidden_size)
301 |
302 | >>> code = analyse_module(MLP(hidden_size), x, bwd)
303 | >>> print(code)
304 | def forward(self, x): (-> 1.0, <- 0.236)
305 | fc1_weight = self.fc1.weight; (-> 0.018, <- 6.54)
306 | fc1_bias = self.fc1.bias; (-> 0.0182, <- 6.51)
307 | linear = torch._C._nn.linear(x, fc1_weight, fc1_bias); (-> 0.578, <- 0.204)
308 | relu = torch.nn.functional.relu(linear, inplace = False); (-> 0.337, <- 0.288)
309 | fc2_weight = self.fc2.weight; (-> 0.00902, <- 13.0)
310 | fc2_bias = self.fc2.bias; (-> 0.00904, <- 31.6)
311 | linear_1 = torch._C._nn.linear(relu, fc2_weight, fc2_bias); (-> 0.235, <- 0.999)
312 | return linear_1
313 | """ # noqa: E501
314 | tracer = _DeepTracer(recurse_modules, autowrap_modules, autowrap_functions)
315 | fx_graph = tracer.trace(module)
316 | fx_graph_module = fx.GraphModule(tracer.root, fx_graph)
317 |
318 | if not isinstance(inputs, tuple):
319 | inputs = (inputs,)
320 | scales = _record_scales(fx_graph_module, inputs, backward)
321 | return _annotate(fx_graph_module.code, scales, syntax_highlight=syntax_highlight)
322 |
323 |
324 | __all__ = generate__all__(__name__)
325 | --------------------------------------------------------------------------------