├── .devcontainer.json
├── .dockerignore
├── .github
│   └── workflows
│       ├── ci.yaml
│       └── publish.yaml
├── .gitignore
├── Dockerfile
├── LICENSE
├── NOTICE.md
├── README.md
├── analysis
│   ├── almost_scaled_dot_product_attention
│   │   ├── .gitignore
│   │   ├── almost_scaled_dot_product_attention.ipynb
│   │   └── demo_transformer.py
│   ├── benchmarking_compiled_unit_scaled_ops.ipynb
│   ├── emb_lr_analysis.ipynb
│   ├── empirical_op_scaling.ipynb
│   └── sgd_scaling.ipynb
├── dev
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── animation.html
│   │   └── scales.png
│   ├── _templates
│   │   ├── custom-class-template.rst
│   │   └── custom-module-template.rst
│   ├── api_reference.rst
│   ├── conf.py
│   ├── index.rst
│   ├── limitations.rst
│   ├── make.bat
│   ├── posts
│   │   └── almost_scaled_dot_product_attention.md
│   ├── u-muP_slides.pdf
│   └── user_guide.rst
├── examples
│   ├── .gitignore
│   ├── demo.ipynb
│   ├── how_to_scale_op.ipynb
│   └── scale_analysis.py
├── pyproject.toml
├── requirements-dev.txt
├── setup.cfg
└── unit_scaling
    ├── __init__.py
    ├── _internal_utils.py
    ├── _modules.py
    ├── analysis.py
    ├── constraints.py
    ├── core
    │   ├── __init__.py
    │   └── functional.py
    ├── docs.py
    ├── formats.py
    ├── functional.py
    ├── optim.py
    ├── parameter.py
    ├── scale.py
    ├── tests
    │   ├── __init__.py
    │   ├── conftest.py
    │   ├── core
    │   │   ├── __init__.py
    │   │   └── test_functional.py
    │   ├── helper.py
    │   ├── test_analysis.py
    │   ├── test_constraints.py
    │   ├── test_docs.py
    │   ├── test_formats.py
    │   ├── test_functional.py
    │   ├── test_modules.py
    │   ├── test_optim.py
    │   ├── test_parameter.py
    │   ├── test_scale.py
    │   ├── test_utils.py
    │   └── transforms
    │       ├── __init__.py
    │       ├── test_compile.py
    │       ├── test_simulate_format.py
    │       ├── test_track_scales.py
    │       └── test_unit_scale.py
    ├── transforms
    │   ├── __init__.py
    │   ├── _compile.py
    │   ├── _simulate_format.py
    │   ├── _track_scales.py
    │   ├── _unit_scale.py
    │   └── utils.py
    └── utils.py
/.devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 |   "build": {
3 |     "dockerfile": "Dockerfile"
4 |   },
5 |   "workspaceFolder": "/home/developer/unit-scaling",
6 |   "customizations": {
7 |     "vscode": {
8 |       "extensions": [
9 |         "ms-python.python",
10 |         "ms-toolsai.jupyter"
11 |       ],
12 |       "settings": {
13 |         "terminal.integrated.defaultProfile.linux": "zsh",
14 |         "terminal.integrated.profiles.linux": { "zsh": { "path": "/bin/zsh" } }
15 |       }
16 |     }
17 |   },
18 |   "mounts": [
19 |     "source=${localEnv:HOME}/.ssh,target=/home/developer/.ssh,type=bind,readonly=true",
20 |     "source=${localEnv:HOME}/.gitconfig,target=/home/developer/.gitconfig,type=bind,readonly=true",
21 |     "source=${localWorkspaceFolder},target=/home/developer/unit-scaling,type=bind"
22 |   ],
23 |   "remoteUser": "developer"
24 | }
25 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | *
2 | !requirements-dev.txt
3 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 |   push: { branches: ["main"] }
5 |   pull_request:
6 |   workflow_dispatch:
7 |
8 | concurrency:
9 |   # Run everything on main, most-recent on PR builds
10 |   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
11 |   cancel-in-progress: true
12 |
13 | jobs:
14 |   ci:
15 |     runs-on: ubuntu-latest
16 |     timeout-minutes: 10
17 |     steps:
18 |       - name: Checkout code
19 |         uses: actions/checkout@v3
20 |
21 |       - name: Build Docker image
22 |         run: docker build -t unit-scaling-dev:latest .
23 |
24 |       - name: Local unit_scaling install
25 |         run: docker run -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest pip install --user -e .
26 |
27 |       - name: Run CI
28 |         run: docker run -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest ./dev ci
29 |
30 |       - name: Publish documentation
31 |         if: ${{github.ref == 'refs/heads/main'}}
32 |         uses: Cecilapp/GitHub-Pages-deploy@v3
33 |         env: { GITHUB_TOKEN: "${{ github.token }}" }
34 |         with:
35 |           build_dir: docs/_build/html
36 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | # Based on template at https://docs.github.com/en/enterprise-cloud@latest/actions/use-cases-and-examples/building-and-testing/building-and-testing-python?learn=continuous_integration#publishing-to-pypi
2 |
3 | name: Publish to PyPI
4 |
5 | on:
6 |   release:
7 |     types: [published]
8 |   workflow_dispatch:
9 |
10 | permissions:
11 |   contents: read
12 |
13 | jobs:
14 |   release-build:
15 |     runs-on: ubuntu-latest
16 |
17 |     steps:
18 |       - uses: actions/checkout@v4
19 |
20 |       - uses: actions/setup-python@v5
21 |         with:
22 |           python-version: "3.x"
23 |
24 |       - name: Build release distributions
25 |         run: |
26 |           python -m pip install build
27 |           python -m build
28 |
29 |       - name: Upload distributions
30 |         uses: actions/upload-artifact@v4
31 |         with:
32 |           name: release-dists
33 |           path: dist/
34 |
35 |   pypi-publish:
36 |     runs-on: ubuntu-latest
37 |
38 |     needs:
39 |       - release-build
40 |
41 |     permissions:
42 |       # IMPORTANT: this permission is mandatory for trusted publishing
43 |       id-token: write
44 |
45 |     # Dedicated environments with protections for publishing are strongly recommended.
46 |     environment:
47 |       name: pypi
48 |       url: https://pypi.org/p/unit-scaling
49 |
50 |     steps:
51 |       - name: Retrieve release distributions
52 |         uses: actions/download-artifact@v4
53 |         with:
54 |           name: release-dists
55 |           path: dist/
56 |
57 |       - name: Publish release distributions to PyPI
58 |         uses: pypa/gh-action-pypi-publish@release/v1
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | .coverage*
3 | .env
4 | .mypy_cache
5 | __pycache__
6 | .pytest_cache
7 | .venv
8 | .venvs
9 | .vscode
10 | unit_scaling/_version.py
11 |
12 | /build
13 | /dist
14 | /local
15 |
16 | /docs/_build
17 | /docs/generated
18 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use PyTorch base image
2 | FROM pytorch/pytorch:latest
3 |
4 | # Install additional dependencies
5 | RUN apt-get update && apt-get install -y \
6 | git \
7 | vim \
8 | sudo \
9 | make \
10 | g++ \
11 | zsh \
12 | && chsh -s /bin/zsh \
13 | && apt-get clean && rm -rf /var/lib/apt/lists/* # cleanup (smaller image)
14 |
15 | # Configure a non-root user with sudo privileges
16 | ARG USERNAME=developer # Change this to preferred username
17 | ARG USER_UID=1001
18 | ARG USER_GID=$USER_UID
19 | RUN groupadd --gid $USER_GID $USERNAME \
20 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
21 | && echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \
22 | && chmod 0440 /etc/sudoers.d/$USERNAME
23 | USER $USERNAME
24 |
25 | # Set working directory
26 | WORKDIR /home/$USERNAME/unit-scaling
27 |
28 | # Puts pip install libs on $PATH & sets correct locale
29 | ENV PATH="$PATH:/home/$USERNAME/.local/bin" \
30 | LC_ALL=C.UTF-8
31 |
32 | # Install Python dependencies
33 | COPY requirements-dev.txt .
34 | RUN pip install --user -r requirements-dev.txt
35 |
36 | # Creates basic .zshrc
37 | RUN sudo cp /etc/zsh/newuser.zshrc.recommended /home/$USERNAME/.zshrc
38 |
39 | CMD ["/bin/zsh"]
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Graphcore
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/NOTICE.md:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 Graphcore Ltd. Licensed under the Apache 2.0 License.
2 |
3 | The included code is released under an Apache 2.0 license (see [LICENSE](LICENSE)).
4 |
5 | Our dependencies are (see [pyproject.toml](pyproject.toml)):
6 |
7 | | Component | About | License |
8 | | --- | --- | --- |
9 | | docstring-parser | Parse Python docstrings | MIT |
10 | | einops | Deep learning operations reinvented (for pytorch, tensorflow, jax and others) | MIT |
11 | | numpy | Array processing library | BSD 3-Clause |
12 |
13 | We also use additional Python dependencies for development/testing (see [requirements-dev.txt](requirements-dev.txt)).
14 |
15 | **This directory includes derived work from the following:**
16 |
17 | ---
18 |
19 | Sphinx: https://github.com/sphinx-doc/sphinx, licensed under:
20 |
21 | > Unless otherwise indicated, all code in the Sphinx project is licenced under the
22 | > two clause BSD licence below.
23 | >
24 | > Copyright (c) 2007-2023 by the Sphinx team (see AUTHORS file).
25 | > All rights reserved.
26 | >
27 | > Redistribution and use in source and binary forms, with or without
28 | > modification, are permitted provided that the following conditions are
29 | > met:
30 | >
31 | > * Redistributions of source code must retain the above copyright
32 | > notice, this list of conditions and the following disclaimer.
33 | >
34 | > * Redistributions in binary form must reproduce the above copyright
35 | > notice, this list of conditions and the following disclaimer in the
36 | > documentation and/or other materials provided with the distribution.
37 | >
38 | > THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
39 | > "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
40 | > LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
41 | > A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
42 | > HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
43 | > SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
44 | > LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
45 | > DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
46 | > THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
47 | > (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
48 | > OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
49 |
50 | this applies to:
51 | - `docs/_templates/custom-class-template.rst` (modified)
52 | - `docs/_templates/custom-module-template.rst` (modified)
53 |
54 | ---
55 |
56 | The Example: Basic Sphinx project for Read the Docs: https://github.com/readthedocs-examples/example-sphinx-basic, licensed under:
57 |
58 | > MIT License
59 | >
60 | > Copyright (c) 2022 Read the Docs Inc
61 | >
62 | > Permission is hereby granted, free of charge, to any person obtaining a copy
63 | > of this software and associated documentation files (the "Software"), to deal
64 | > in the Software without restriction, including without limitation the rights
65 | > to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
66 | > copies of the Software, and to permit persons to whom the Software is
67 | > furnished to do so, subject to the following conditions:
68 | >
69 | > The above copyright notice and this permission notice shall be included in all
70 | > copies or substantial portions of the Software.
71 | >
72 | > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
73 | > IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
74 | > FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
75 | > AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
76 | > LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
77 | > OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
78 | > SOFTWARE.
79 |
80 | this applies to:
81 | - `docs/conf.py` (modified)
82 | - `docs/make.bat`
83 | - `docs/Makefile`
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unit-Scaled Maximal Update Parameterization (u-μP)
2 |
3 | [](https://github.com/graphcore-research/unit-scaling/actions/workflows/ci-public.yaml)
4 | 
5 | [](https://github.com/graphcore-research/unit-scaling/blob/main/LICENSE)
6 | [](https://github.com/graphcore-research/unit-scaling/stargazers)
7 |
8 | A library for unit scaling in PyTorch, based on the paper [u-μP: The Unit-Scaled Maximal Update Parametrization](https://arxiv.org/abs/2407.17465) and previous work [Unit Scaling: Out-of-the-Box Low-Precision Training](https://arxiv.org/abs/2303.11257).
9 |
10 | Documentation can be found at
11 | [https://graphcore-research.github.io/unit-scaling](https://graphcore-research.github.io/unit-scaling) and an example notebook at [examples/demo.ipynb](examples/demo.ipynb).
12 |
13 | **Note:** The library is currently in its _beta_ release.
14 | Some features have yet to be implemented and occasional bugs may be present.
15 | We're keen to help users with any problems they encounter.
16 |
17 | ## Installation
18 |
19 | To install the `unit-scaling` library, run:
20 |
21 | ```sh
22 | pip install unit-scaling
23 | ```
24 | or for a local editable install (i.e. one which uses the files in this repo), run:
25 |
26 | ```sh
27 | pip install -e .
28 | ```
29 |
30 | ## Development
31 |
32 | For development in this repository, we recommend using the provided docker container.
33 | This image can be built and entered interactively using:
34 |
35 | ```sh
36 | docker build -t unit-scaling-dev:latest .
37 | docker run -it --rm --user developer:developer -v $(pwd):/home/developer/unit-scaling unit-scaling-dev:latest
38 | # To use git within the container, add `-v ~/.ssh:/home/developer/.ssh:ro -v ~/.gitconfig:/home/developer/.gitconfig:ro`.
39 | ```
40 |
41 | For VS Code users, this repo also contains a `.devcontainer.json` file, which enables the container to be used as a full-featured IDE (see the [Dev Container docs](https://code.visualstudio.com/docs/devcontainers/containers) for details on how to use this feature).
42 |
43 | Key development functionality is contained within the `./dev` script. This includes running unit tests, linting, formatting, documentation generation and more. Run `./dev --help` for the available options. Running `./dev` without arguments is equivalent to `./dev ci`, which runs all of the available dev checks; these are the same checks run in GitHub CI.
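
For example (subcommand names as defined in the `./dev` script; run from the repository root, typically inside the dev container):

```sh
./dev format           # autoformat sources with black + isort
./dev lint             # static analysis with flake8 + mypy
./dev tests -k linear  # run a filtered subset of the unit tests
./dev doc              # build the Sphinx documentation
./dev                  # full CI suite: tests, lint, format, copyright, doc
```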
44 |
45 | We encourage pull requests from the community. Please reach out to us with any questions about contributing.
46 |
47 | ## What is u-μP?
48 |
49 | u-μP inserts scaling factors into the model to make activations, gradients and weights unit-scaled (RMS ≈ 1) at initialisation, and into optimiser learning rates to keep updates stable as models are scaled in width and depth. This results in hyperparameter transfer from small to large models and easy support for low-precision training.
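
As a minimal sketch of what this looks like in code (a toy example, assuming `Linear` and `gelu` are among the supported ops; see the API reference for the full list), the standard `torch.nn` / `torch.nn.functional` calls are swapped for their `unit_scaling` equivalents:

```python
import torch

import unit_scaling as uu
import unit_scaling.functional as U

layer = uu.Linear(256, 1024)  # drop-in replacement for nn.Linear
x = torch.randn(16, 256)      # unit-scaled input (RMS ≈ 1)
y = U.gelu(layer(x))          # drop-in replacement for F.gelu
print(y.std())                # expected to stay close to 1 at initialisation
```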
50 |
51 | For a quick intro, see [examples/demo.ipynb](examples/demo.ipynb), for more depth see the [paper](https://arxiv.org/abs/2407.17465) and [library documentation](https://graphcore-research.github.io/unit-scaling/).
52 |
53 | ## What is unit scaling?
54 |
55 | For a demonstration of the library and an overview of how it works, see
56 | [Out-of-the-Box FP8 Training](https://github.com/graphcore-research/out-of-the-box-fp8-training/blob/main/out_of_the_box_fp8_training.ipynb)
57 | (a notebook showing how to unit-scale the nanoGPT model).
58 |
59 | For a more in-depth explanation, consult our paper
60 | [Unit Scaling: Out-of-the-Box Low-Precision Training](https://arxiv.org/abs/2303.11257).
61 |
62 | And for a practical introduction to using the library, see our [User Guide](https://graphcore-research.github.io/unit-scaling/user_guide.html).
63 |
64 | ## License
65 |
66 | Copyright (c) 2023 Graphcore Ltd. Licensed under the Apache 2.0 License.
67 |
68 | See [NOTICE.md](NOTICE.md) for further details.
69 |
--------------------------------------------------------------------------------
/analysis/almost_scaled_dot_product_attention/.gitignore:
--------------------------------------------------------------------------------
1 | shakespeare.txt
2 |
--------------------------------------------------------------------------------
/analysis/almost_scaled_dot_product_attention/demo_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | from itertools import islice
4 | import math
5 | from pathlib import Path
6 | from typing import *
7 |
8 | import einops
9 | import torch
10 | from torch import nn, Tensor
11 | import tqdm
12 |
13 |
14 | class Config(dict):
15 | def __init__(self, *args: Any, **kwargs: Any):
16 | super().__init__(*args, **kwargs)
17 | self.__dict__ = self
18 |
19 |
20 | CONFIG = Config(
21 | sequence_length=256,
22 | batch_size=16,
23 | hidden_size=256,
24 | head_size=64,
25 | depth=4,
26 | fully_scaled_attention=False,
27 | lr=2**-10,
28 | steps=5000,
29 | )
30 |
31 |
32 | # https://www.gutenberg.org/cache/epub/100/pg100.txt
33 | DATA = torch.tensor(list(Path("shakespeare.txt").read_bytes()))
34 |
35 |
36 | def batches() -> Iterable[Tensor]:
37 | while True:
38 | offsets = torch.randint(
39 | len(DATA) - CONFIG.sequence_length - 1, (CONFIG.batch_size,)
40 | )
41 | yield torch.stack([DATA[i : i + CONFIG.sequence_length + 1] for i in offsets])
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self):
46 | super().__init__()
47 | self.head_size = CONFIG.head_size
48 | self.n_heads = CONFIG.hidden_size // CONFIG.head_size
49 | self.qkv = nn.Linear(CONFIG.hidden_size, 3 * self.n_heads * self.head_size)
50 | self.proj = nn.Linear(self.n_heads * self.head_size, CONFIG.hidden_size)
51 | # Put the scale in a non-trainable parameter, to avoid recompilation
52 | self.out_scale = nn.Parameter(
53 | torch.tensor(
54 | (CONFIG.sequence_length / math.e) ** 0.5
55 | if CONFIG.fully_scaled_attention
56 | else 1.0
57 | ),
58 | requires_grad=False,
59 | )
60 |
61 | def forward(self, x: Tensor) -> Tensor:
62 | s = x.shape[1]
63 | q, k, v = einops.rearrange(
64 | self.qkv(x), "b s (M n d) -> M b n s d", M=3, n=self.n_heads
65 | )
66 | qk_scale = torch.tensor(self.head_size**-0.5, dtype=x.dtype, device=x.device)
67 | pre_a = torch.einsum("bnsd, bntd -> bnst", q, k) * qk_scale
68 | pre_a = pre_a + torch.triu(
69 | torch.full((s, s), -1e4, device=x.device), diagonal=1
70 | )
71 | a = torch.softmax(pre_a, -1)
72 | out = torch.einsum("bnst, bntd -> bnsd", a, v) * self.out_scale
73 | return self.proj(einops.rearrange(out, "b n s d -> b s (n d)"))
74 |
75 |
76 | class FFN(nn.Module):
77 | def __init__(self):
78 | super().__init__()
79 | self.up = nn.Linear(CONFIG.hidden_size, 4 * CONFIG.hidden_size)
80 | self.down = nn.Linear(self.up.out_features, self.up.in_features)
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | return self.down(torch.nn.functional.gelu(self.up(x)))
84 |
85 |
86 | class PreNormResidual(nn.Module):
87 | def __init__(self, body: nn.Module):
88 | super().__init__()
89 | self.norm = nn.LayerNorm([CONFIG.hidden_size])
90 | self.body = body
91 |
92 | def forward(self, x: Tensor) -> Tensor:
93 | return x + self.body(self.norm(x))
94 |
95 |
96 | class AbsolutePositionalEncoding(nn.Module):
97 | def __init__(self):
98 | super().__init__()
99 | self.weight = nn.Parameter(
100 | torch.randn(CONFIG.sequence_length, CONFIG.hidden_size)
101 | )
102 |
103 | def forward(self, x: Tensor) -> Tensor:
104 | return x + self.weight
105 |
106 |
107 | class Model(nn.Module):
108 | def __init__(self):
109 | super().__init__()
110 | self.model = nn.Sequential(
111 | nn.Embedding(256, CONFIG.hidden_size),
112 | AbsolutePositionalEncoding(),
113 | nn.LayerNorm([CONFIG.hidden_size]),
114 | *(
115 | nn.Sequential(PreNormResidual(Attention()), PreNormResidual(FFN()))
116 | for _ in range(CONFIG.depth)
117 | ),
118 | nn.LayerNorm([CONFIG.hidden_size]),
119 | nn.Linear(CONFIG.hidden_size, 256),
120 | )
121 |
122 | def forward(self, indices: Tensor) -> Tensor:
123 | return nn.functional.cross_entropy(
124 | self.model(indices[:, :-1]).flatten(0, -2), indices[:, 1:].flatten()
125 | )
126 |
127 |
128 | def train() -> Tensor:
129 | model = Model()
130 | opt = torch.optim.Adam(model.parameters(), lr=CONFIG.lr)
131 | losses = []
132 | for batch in tqdm.tqdm(islice(batches(), CONFIG.steps), total=CONFIG.steps):
133 | opt.zero_grad()
134 | loss = model(batch)
135 | loss.backward()
136 | opt.step()
137 | losses.append(float(loss))
138 | return torch.tensor(losses)
139 |
--------------------------------------------------------------------------------
/analysis/emb_lr_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Analysis of the effect of the embedding LR update on the subsequent matmul\n",
8 | "\n",
9 | "I wanted to write this out in a notebook to make sure I understood the way in which the embedding update effects the subsequent matmul.\n",
10 | "\n",
11 | "No revelations unfortunately - it still seems as though our rule can't be justified this way (it is \"unnatural\"!). Under the \"no-alignment\" assumption the standard embedding LR breaks, but unfortunately our fix does nothing to help. Oh well."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import torch\n",
21 | "from torch import randn\n",
22 | "from typing import Iterable"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "def rms(*xs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:\n",
32 | " if len(xs) == 1:\n",
33 | " return xs[0].pow(2).mean().sqrt()\n",
34 | " return tuple(rms(x) for x in xs)"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "## Setup\n",
42 | "\n",
43 | "Toggle `full_alignment` and `umup_lr_rule` to see the effect. mup scaling is used by default."
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "d = 2**11\n",
53 | "full_alignment = True\n",
54 | "umup_lr_rule = False\n",
55 | "\n",
56 | "w_lr = d ** -(1 if full_alignment else 0.5)\n",
57 | "e_lr = d ** -(0.5 if umup_lr_rule else 0)"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "## Model & update\n",
65 | "\n",
66 | "Everything can be described in terms of these three tensors (a single embedding vector, weight matrix and a gradient vector). Note that I assume the gradient is unit-scale, and then just use the adam LR rules but under and SGD-like update (I appreciate this is a bit odd, but it's simple and the maths should work out)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
76 | "text/plain": [
77 | "(tensor(0.9984), tensor(0.0221), tensor(0.9882))"
78 | ]
79 | },
80 | "execution_count": 4,
81 | "metadata": {},
82 | "output_type": "execute_result"
83 | }
84 | ],
85 | "source": [
86 | "e1 = randn(d, 1)\n",
87 | "W1 = randn(d + 1, d) * d**-0.5\n",
88 | "g = randn(d + 1, 1)\n",
89 | "rms(\n",
90 | " e1, W1, g\n",
91 | ") # all \"well-scaled\", except the weight which is 1/sqrt(d) (this isn't unit scaling!)"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "Then we just run:"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 5,
104 | "metadata": {},
105 | "outputs": [
106 | {
107 | "data": {
108 | "text/plain": [
109 | "tensor(0.9953)"
110 | ]
111 | },
112 | "execution_count": 5,
113 | "metadata": {},
114 | "output_type": "execute_result"
115 | }
116 | ],
117 | "source": [
118 | "x1 = W1 @ e1\n",
119 | "rms(x1) # well-scaled"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 6,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "data": {
129 | "text/plain": [
130 | "((tensor(0.9977), tensor(0.0005)), 0.00048828125)"
131 | ]
132 | },
133 | "execution_count": 6,
134 | "metadata": {},
135 | "output_type": "execute_result"
136 | }
137 | ],
138 | "source": [
139 | "u_e = W1.T @ g * e_lr\n",
140 | "u_W = g @ e1.T * w_lr\n",
141 | "(\n",
142 | " rms(u_e, u_W),\n",
143 | " 1 / d,\n",
144 | ") # the weight update is under-scaled (to be expected I think), though as a rank-1 matrix it has a much higher (O(1)) spectral norm! This means its effect doesn't \"go to zero\" in inf. width, though the rms does."
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 7,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "data": {
154 | "text/plain": [
155 | "(tensor(0.9998), tensor(0.0221))"
156 | ]
157 | },
158 | "execution_count": 7,
159 | "metadata": {},
160 | "output_type": "execute_result"
161 | }
162 | ],
163 | "source": [
164 | "e2 = e1 + u_e\n",
165 | "e2_std = e2.std()\n",
166 | "e2 /= e2_std # Why is `/ e2.std()` allowed/justified? Normally we'd have a much smaller weight update (scaled down by small LR constant), and then the original weight would be decayed a bit, keeping this at about rms=1. This re-scaling does something similar, though allows us to see the effect of the weight update scaling more clearly.\n",
167 | "W2 = W1 + u_W\n",
168 | "rms(\n",
169 | " e2, W2\n",
170 | ") # Update is well-scaled. Weight has barely changed from its 1/sqrt(d) starting point"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 8,
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "data": {
180 | "text/plain": [
181 | "tensor(1.7412)"
182 | ]
183 | },
184 | "execution_count": 8,
185 | "metadata": {},
186 | "output_type": "execute_result"
187 | }
188 | ],
189 | "source": [
190 | "x2 = W2 @ e2\n",
191 | "rms(x2) # ~well-scaled. Certainly doesn't scale with a significant power of d"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "## Analysis\n",
199 | "\n",
200 | "Now we break this down into its constituent terms.\n",
201 | "\n",
202 | "First checking that they combine to the original"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 9,
208 | "metadata": {},
209 | "outputs": [
210 | {
211 | "data": {
212 | "text/plain": [
213 | "True"
214 | ]
215 | },
216 | "execution_count": 9,
217 | "metadata": {},
218 | "output_type": "execute_result"
219 | }
220 | ],
221 | "source": [
222 | "torch.allclose(x2, (W1 + u_W) @ (e1 + u_e * e_lr) / e2_std, atol=1e-6)\n",
223 | "torch.allclose(x2, (W1 + g @ e1.T * w_lr) @ (e1 + W1.T @ g * e_lr) / e2_std, atol=1e-6)"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 10,
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "data": {
233 | "text/plain": [
234 | "True"
235 | ]
236 | },
237 | "execution_count": 10,
238 | "metadata": {},
239 | "output_type": "execute_result"
240 | }
241 | ],
242 | "source": [
243 | "# t1 = W1 @ e1 (== x1)\n",
244 | "t2 = W1 @ W1.T @ g * e_lr\n",
245 | "t3 = g @ e1.T * w_lr @ e1\n",
246 | "t4 = g @ e1.T * w_lr @ W1.T @ g * e_lr\n",
247 | "torch.allclose(x2, (x1 + t2 + t3 + t4) / e2_std, atol=1e-5)"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "### Weight @ emb_update (t2)\n",
255 | "\n",
256 | "This is well-scaled under the original emb lr rule, but not under our lr rule - which isn't a great sign for our approach"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 11,
262 | "metadata": {},
263 | "outputs": [
264 | {
265 | "name": "stdout",
266 | "output_type": "stream",
267 | "text": [
268 | "rms(W1, g), e_lr=((tensor(0.0221), tensor(0.9882)), 1)\n",
269 | "rms(W1 @ W1.T)=tensor(0.0312)\n",
270 | "rms(W1.T @ g)=tensor(0.9977)\n",
271 | "rms(W1 @ W1.T @ g * e_lr / e2_std)=tensor(0.9857)\n"
272 | ]
273 | }
274 | ],
275 | "source": [
276 | "print(f\"{rms(W1, g), e_lr=}\")\n",
277 | "print(f\"{rms(W1 @ W1.T)=}\")\n",
278 | "print(f\"{rms(W1.T @ g)=}\")\n",
279 | "print(f\"{rms(W1 @ W1.T @ g * e_lr / e2_std)=}\")"
280 | ]
281 | },
282 | {
283 | "cell_type": "markdown",
284 | "metadata": {},
285 | "source": [
286 | "### Weight_update @ emb (t3)\n",
287 | "\n",
288 | "This is well-scaled under the original emb lr rule and our rule"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 12,
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "rms(g, e1)=(tensor(0.9882), tensor(0.9984))\n",
301 | "rms(g @ e1.T)=tensor(0.9866)\n",
302 | "rms(e1.T @ e1 * w_lr)=tensor(0.9968)\n",
303 | "rms(g @ e1.T * w_lr @ e1)=tensor(0.9850)\n"
304 | ]
305 | }
306 | ],
307 | "source": [
308 | "print(f\"{rms(g, e1)=}\")\n",
309 | "print(f\"{rms(g @ e1.T)=}\")\n",
310 | "print(f\"{rms(e1.T @ e1 * w_lr)=}\")\n",
311 | "print(f\"{rms(g @ e1.T * w_lr @ e1)=}\")"
312 | ]
313 | },
314 | {
315 | "cell_type": "markdown",
316 | "metadata": {},
317 | "source": [
318 | "### Weight_update @ emb_update (t4)\n",
319 | "\n",
320 | "This vanishes with width under the original emb lr and our rule. Probably a good thing?"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": 13,
326 | "metadata": {},
327 | "outputs": [
328 | {
329 | "name": "stdout",
330 | "output_type": "stream",
331 | "text": [
332 | "rms(g @ e1.T @ W1.T @ g)=tensor(46.5558)\n",
333 | "rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=tensor(0.0227)\n"
334 | ]
335 | }
336 | ],
337 | "source": [
338 | "print(f\"{rms(g @ e1.T @ W1.T @ g)=}\")\n",
339 | "print(f\"{rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=}\")"
340 | ]
341 | }
342 | ],
343 | "metadata": {
344 | "kernelspec": {
345 | "display_name": ".venv",
346 | "language": "python",
347 | "name": "python3"
348 | },
349 | "language_info": {
350 | "codemirror_mode": {
351 | "name": "ipython",
352 | "version": 3
353 | },
354 | "file_extension": ".py",
355 | "mimetype": "text/x-python",
356 | "name": "python",
357 | "nbconvert_exporter": "python",
358 | "pygments_lexer": "ipython3",
359 | "version": "3.11.9"
360 | }
361 | },
362 | "nbformat": 4,
363 | "nbformat_minor": 2
364 | }
365 |
--------------------------------------------------------------------------------
/dev:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
3 |
4 | """Dev task launcher."""
5 |
6 | import argparse
7 | import datetime
8 | import os
9 | import subprocess
10 | import sys
11 | from pathlib import Path
12 | from typing import Any, Callable, Iterable, List, Optional, TypeVar
13 |
14 | # Utilities
15 |
16 |
17 | def run(command: Iterable[Any]) -> None:
18 | """Run a command, terminating on failure."""
19 | cmd = [str(arg) for arg in command if arg is not None]
20 | print("$ " + " ".join(cmd), file=sys.stderr)
21 | environ = os.environ.copy()
22 | environ["PYTHONPATH"] = f"{os.getcwd()}:{environ.get('PYTHONPATH', '')}"
23 | exit_code = subprocess.call(cmd, env=environ)
24 | if exit_code:
25 | sys.exit(exit_code)
26 |
27 |
28 | T = TypeVar("T")
29 |
30 |
31 | def cli(*args: Any, **kwargs: Any) -> Callable[[T], T]:
32 | """Declare a CLI command / arguments for that command."""
33 |
34 | def wrap(func: T) -> T:
35 | if not hasattr(func, "cli_args"):
36 | setattr(func, "cli_args", [])
37 | if args or kwargs:
38 | getattr(func, "cli_args").append((args, kwargs))
39 | return func
40 |
41 | return wrap
42 |
43 |
44 | # Commands
45 |
46 | PYTHON_ROOTS = ["unit_scaling", "dev", "examples"]
47 |
48 |
49 | @cli("-k", "--filter")
50 | def tests(filter: Optional[str]) -> None:
51 | """run Python tests"""
52 | run(
53 | [
54 | "python",
55 | "-m",
56 | "pytest",
57 | "unit_scaling",
58 | None if filter else "--cov=unit_scaling",
59 | *(["-k", filter] if filter else []),
60 | ]
61 | )
62 |
63 |
64 | @cli("commands", nargs="*")
65 | def python(commands: List[Any]) -> None:
66 | """run Python with the current directory on PYTHONPATH, for development"""
67 | run(["python"] + commands)
68 |
69 |
70 | @cli()
71 | def lint() -> None:
72 | """run static analysis"""
73 | run(["python", "-m", "flake8", *PYTHON_ROOTS])
74 | run(["python", "-m", "mypy", *PYTHON_ROOTS])
75 |
76 |
77 | @cli("--check", action="store_true")
78 | def format(check: bool) -> None:
79 | """autoformat all sources"""
80 | run(["python", "-m", "black", "--check" if check else None, *PYTHON_ROOTS])
81 | run(["python", "-m", "isort", "--check" if check else None, *PYTHON_ROOTS])
82 |
83 |
84 | @cli()
85 | def copyright() -> None:
86 | """check for Graphcore copyright headers on relevant files"""
87 | command = (
88 | f"find {' '.join(PYTHON_ROOTS)} -type f -not -name *.pyc -not -name *.json"
89 | " -not -name .gitignore -not -name *_version.py"
90 | " | xargs grep -L 'Copyright (c) 202. Graphcore Ltd[.] All rights reserved[.]'"
91 | )
92 | print(f"$ {command}", file=sys.stderr)
93 | # Note: grep exit codes are not consistent between versions, so we don't use
94 | # check=True
95 | output = (
96 | subprocess.run(
97 | command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
98 | )
99 | .stdout.decode()
100 | .strip()
101 | )
102 | if output:
103 | print(
104 | "Error - failed copyright header check in:\n "
105 | + output.replace("\n", "\n "),
106 | file=sys.stderr,
107 | )
108 | print("Template(s):")
109 | comment_prefixes = {
110 | {".cpp": "//"}.get(Path(f).suffix, "#") for f in output.split("\n")
111 | }
112 | for prefix in comment_prefixes:
113 | print(
114 | (
115 | f"{prefix} Copyright (c) {datetime.datetime.now().year}"
116 | " Graphcore Ltd. All rights reserved."
117 | ),
118 | file=sys.stderr,
119 | )
120 | sys.exit(1)
121 |
122 |
123 | @cli()
124 | def doc() -> None:
125 | """generate API documentation"""
126 | subprocess.call(["rm", "-r", "docs/generated", "docs/_build"])
127 | run(
128 | [
129 | "make",
130 | "-C",
131 | "docs",
132 | "html",
133 | ]
134 | )
135 |
136 |
137 | @cli(
138 | "-s",
139 | "--skip",
140 | nargs="*",
141 | default=[],
142 | choices=["tests", "lint", "format", "copyright", "doc"],
143 | help="commands to skip",
144 | )
145 | def ci(skip: List[str] = []) -> None:
146 | """run all continuous integration tests & checks"""
147 | if "tests" not in skip:
148 | tests(filter=None)
149 | if "lint" not in skip:
150 | lint()
151 | if "format" not in skip:
152 | format(check=True)
153 | if "copyright" not in skip:
154 | copyright()
155 | if "doc" not in skip:
156 | doc()
157 |
158 |
159 | # Script
160 |
161 |
162 | def _main() -> None:
163 | # Build an argparse command line by finding globals in the current module
164 | # that are marked via the @cli() decorator. Each one becomes a subcommand
165 | # running that function, usage "$ ./dev fn_name ...args".
166 | parser = argparse.ArgumentParser(description=__doc__)
167 | parser.set_defaults(command=ci)
168 |
169 | subs = parser.add_subparsers()
170 | for key, value in globals().items():
171 | if hasattr(value, "cli_args"):
172 | sub = subs.add_parser(key.replace("_", "-"), help=value.__doc__)
173 | for args, kwargs in value.cli_args:
174 | sub.add_argument(*args, **kwargs)
175 | sub.set_defaults(command=value)
176 |
177 | cli_args = vars(parser.parse_args())
178 | command = cli_args.pop("command")
179 | command(**cli_args)
180 |
181 |
182 | if __name__ == "__main__":
183 | _main()
184 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Read the Docs Inc. All rights reserved.
2 |
3 | # Minimal makefile for Sphinx documentation
4 | #
5 |
6 | # You can set these variables from the command line, and also
7 | # from the environment for the first two.
8 | SPHINXOPTS ?=
9 | SPHINXBUILD ?= sphinx-build
10 | SOURCEDIR = .
11 | BUILDDIR = _build
12 |
13 | # Put it first so that "make" without argument is like "make help".
14 | help:
15 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
16 |
17 | .PHONY: help Makefile
18 |
19 | # Catch-all target: route all unknown targets to Sphinx using the new
20 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
21 | %: Makefile
22 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
23 |
--------------------------------------------------------------------------------
/docs/_static/animation.html:
--------------------------------------------------------------------------------
1 |
5 |
6 |
--------------------------------------------------------------------------------
/docs/_static/scales.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graphcore-research/unit-scaling/215ba6395230c375156aa139dcd4db950ad19439/docs/_static/scales.png
--------------------------------------------------------------------------------
/docs/_templates/custom-class-template.rst:
--------------------------------------------------------------------------------
1 | ..
2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
3 | # Copyright (c) 2007-2023 by the Sphinx team. All rights reserved.
4 |
5 | {{ fullname | escape | underline}}
6 |
7 | .. currentmodule:: {{ module }}
8 |
9 | .. autoclass:: {{ objname }}
10 | :members:
11 | :inherited-members: Module
12 | :exclude-members: extra_repr, forward
13 |
--------------------------------------------------------------------------------
/docs/_templates/custom-module-template.rst:
--------------------------------------------------------------------------------
1 | ..
2 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
3 | # Copyright (c) 2007-2023 by the Sphinx team. All rights reserved.
4 |
5 | {{ fullname | escape | underline}}
6 |
7 | .. automodule:: {{ fullname }}
8 |
9 | {% block aliases %}
10 | {% if fullname == 'unit_scaling.constraints' %}
11 | .. rubric:: {{ _('Type Aliases') }}
12 |
13 | .. autosummary::
14 | :toctree:
15 |
16 | BinaryConstraint
17 | TernaryConstraint
18 | VariadicConstraint
19 | {% endif %}
20 | {% endblock %}
21 |
22 | {% block attributes %}
23 | {% if attributes %}
24 | .. rubric:: {{ _('Module Attributes') }}
25 |
26 | .. autosummary::
27 | :toctree:
28 | {% for item in attributes %}
29 | {{ item }}
30 | {%- endfor %}
31 | {% endif %}
32 | {% endblock %}
33 |
34 | {% block functions %}
35 | .. rubric:: {{ _('Functions') }}
36 |
37 | .. autosummary::
38 | :toctree:
39 | {% for item in functions %}
40 | {{ item }}
41 | {%- endfor %}
42 | {% endblock %}
43 |
44 | {% block classes %}
45 | {% if classes %}
46 | .. rubric:: {{ _('Classes') }}
47 |
48 | .. autosummary::
49 | :toctree:
50 | :template: custom-class-template.rst
51 | {% for item in classes %}
52 | {{ item }}
53 | {%- endfor %}
54 | {% endif %}
55 | {% endblock %}
56 |
57 | {% block exceptions %}
58 | {% if exceptions %}
59 | .. rubric:: {{ _('Exceptions') }}
60 |
61 | .. autosummary::
62 | :toctree:
63 | {% for item in exceptions %}
64 | {{ item }}
65 | {%- endfor %}
66 | {% endif %}
67 | {% endblock %}
68 |
69 | {% block modules %}
70 | {% if modules %}
71 | .. rubric:: Modules
72 |
73 | .. autosummary::
74 | :toctree:
75 | :template: custom-module-template.rst
76 | :recursive:
77 | {% for item in modules %}
78 | {% if "test" not in item and "docs" not in item %}
79 | {{ item }}
80 | {% endif %}
81 | {%- endfor %}
82 | {% endif %}
83 | {% endblock %}
84 |
--------------------------------------------------------------------------------
/docs/api_reference.rst:
--------------------------------------------------------------------------------
1 | API reference
2 | =============
3 |
4 | :code:`unit-scaling` is implemented using thin wrappers around existing :code:`torch.nn`
5 | classes and functions. Documentation also inherits from the standard PyTorch docs, with
6 | modifications for scaling. Note that some docs may no longer be relevant but are
7 | nevertheless inherited.
8 |
9 | The API is built to mirror :code:`torch.nn` as closely as possible, such that PyTorch
10 | classes and functions can easily be swapped-out for their unit-scaled equivalents.
11 |
12 | For PyTorch code which uses the following imports:
13 |
14 | .. code-block::
15 |
16 | from torch import nn
17 | from torch.nn import functional as F
18 |
19 | Unit scaling can be applied by first adding:
20 |
21 | .. code-block::
22 |
23 | import unit_scaling as uu
24 | from unit_scaling import functional as U
25 |
26 | and then replacing the letters :code:`nn` with :code:`uu` and
27 | :code:`F` with :code:`U`, for those classes/functions to be unit-scaled
28 | (assuming they are supported).
29 |
30 | Click below for the full documentation:
31 |
32 | .. autosummary::
33 | :toctree: generated
34 | :template: custom-module-template.rst
35 | :recursive:
36 |
37 | unit_scaling
38 | unit_scaling.analysis
39 | unit_scaling.constraints
40 | unit_scaling.formats
41 | unit_scaling.functional
42 | unit_scaling.optim
43 | unit_scaling.scale
44 | unit_scaling.transforms
45 | unit_scaling.transforms.utils
46 | unit_scaling.utils
47 | unit_scaling.core.functional
48 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | # Copyright (c) 2022 Read the Docs Inc. All rights reserved.
3 |
4 | # Configuration file for the Sphinx documentation builder.
5 | #
6 | # For the full list of built-in configuration values, see the documentation:
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
8 |
9 | # Setup based on https://example-sphinx-basic.readthedocs.io
10 | import os
11 | import sys
12 | from unit_scaling._version import __version__
13 |
14 | sys.path.insert(0, os.path.abspath(".."))
15 |
16 | # -- Project information -----------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
18 |
19 | project = "unit-scaling"
20 | copyright = "(c) 2023 Graphcore Ltd. All rights reserved"
21 | author = "Charlie Blake, Douglas Orr"
22 | version = __version__
23 |
24 | # -- General configuration ---------------------------------------------------
25 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
26 |
27 | extensions = [
28 | "sphinx.ext.duration",
29 | "sphinx.ext.doctest",
30 | "sphinx.ext.autodoc",
31 | "sphinx.ext.autosummary",
32 | "sphinx.ext.intersphinx",
33 | "sphinx.ext.napoleon", # support for google-style docstrings
34 | "myst_parser", # support for including markdown files in .rst files (e.g. readme)
35 | "sphinx.ext.viewcode", # adds source code to docs
36 | "sphinx.ext.autosectionlabel", # links to sections in the same document
37 | "sphinx.ext.mathjax", # equations
38 | ]
39 |
40 | autosummary_generate = True
41 | autosummary_imported_members = True
42 | autosummary_ignore_module_all = False
43 | napoleon_google_docstring = True
44 | napoleon_numpy_docstring = False
45 |
46 | numfig_format = {
47 | "section": "Section {number}. {name}",
48 | "figure": "Fig. %s",
49 | "table": "Table %s",
50 | "code-block": "Listing %s",
51 | }
52 |
53 | intersphinx_mapping = {
54 | "rtd": ("https://docs.readthedocs.io/en/stable/", None),
55 | "python": ("https://docs.python.org/3/", None),
56 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
57 | "pytorch": ("https://pytorch.org/docs/stable/", None),
58 | }
59 | intersphinx_disabled_domains = ["std"]
60 |
61 | autodoc_type_aliases = {
62 | "BinaryConstraint": "unit_scaling.constraints.BinaryConstraint",
63 | "TernaryConstraint": "unit_scaling.constraints.TernaryConstraint",
64 | "VariadicConstraint": "unit_scaling.constraints.VariadicConstraint",
65 | } # make docgen output name of alias rather than definition.
66 |
67 | templates_path = ["_templates"]
68 |
69 | # -- Options for EPUB output
70 | epub_show_urls = "footnote"
71 |
72 | # List of patterns, relative to source directory, that match files and
73 | # directories to ignore when looking for source files.
74 | # This pattern also affects html_static_path and html_extra_path.
75 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
76 |
77 | html_favicon = "_static/scales.png"
78 |
79 | # -- Options for HTML output -------------------------------------------------
80 |
81 | # The theme to use for HTML and HTML Help pages. See the documentation for
82 | # a list of builtin themes.
83 | #
84 | html_theme = "sphinx_rtd_theme"
85 |
86 | # Add any paths that contain custom static files (such as style sheets) here,
87 | # relative to this directory. They are copied after the builtin static files,
88 | # so a file named "default.css" will overwrite the builtin "default.css".
89 | html_static_path = ["_static"]
90 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Unit Scaling
2 | ============
3 |
4 | Welcome to the :code:`unit-scaling` library. This library is designed to facilitate
5 | the use of the *unit scaling* and *u-µP* methods, as outlined in the papers
6 | `Unit Scaling: Out-of-the-Box Low-Precision Training (ICML, 2023)
7 | `_ and
8 | `u-μP: The Unit-Scaled Maximal Update Parametrization
9 | `_.
10 |
11 | For a demonstration of the library, see `u-μP using the unit_scaling library
12 | `_ — a notebook showing the definition and training of a u-µP language model, comparing against Standard Parametrization (SP).
13 |
14 | Installation
15 | ------------
16 |
17 | To install :code:`unit-scaling`, run:
18 |
19 | .. code-block::
20 |
21 | pip install unit-scaling
22 |
23 | Getting Started
24 | ---------------
25 |
26 | We recommend that new users get started with :numref:`User Guide`.
27 |
28 | A reference outlining our API can be found at :numref:`API reference`.
29 |
30 | The following video gives a broad overview of the workings of unit scaling.
31 |
32 | .. raw:: html
33 | :file: _static/animation.html
34 |
35 | .. Note:: The library is currently in its *beta* release.
36 | Some features have yet to be implemented and occasional bugs may be present.
37 | We're keen to help users with any problems they encounter.
38 |
39 | `The following slides `_ also give an overview of u-µP.
40 |
41 | Development
42 | -----------
43 |
44 | For those who wish to develop on the :code:`unit-scaling` codebase, clone or fork our
45 | `GitHub repo `_ and follow the
46 | instructions in our :doc:`developer guide `.
47 |
48 | .. toctree::
49 | :caption: Contents
50 | :numbered:
51 | :maxdepth: 3
52 |
53 | User guide
54 | Developer guide
55 | Limitations
56 | API reference
57 |
--------------------------------------------------------------------------------
/docs/limitations.rst:
--------------------------------------------------------------------------------
1 | Limitations
2 | ===========
3 |
4 | :code:`unit-scaling` is a new library and (despite our best efforts!) we can't guarantee
5 | it will be bug-free or feature-complete. We're keen to assist anyone who wants to use
6 | the library, and help them work through any issues.
7 |
8 | Known limitations of the library include:
9 |
10 | 1. **Op coverage:** we've currently focussed on adding common transformer operations — other ops may be missing (though we can add most requested ops without difficulty).
11 | 2. **Using transforms with torch.compile:** currently our transforms (for example :code:`unit_scale`, :code:`simulate_fp8`) can't be used directly with :code:`torch.compile`. We provide a special compilation function to get around this: :code:`unit_scaling.transforms.compile` (see docs for more details, and the usage sketch at the end of this page), though this only works with :code:`unit_scale` and not :code:`simulate_fp8`.
12 | 3. **Distributed training:** although we suspect distributed training will still work reasonably well with the current library, we haven't tested this.
13 |
14 | This list is not exhaustive and we encourage you to get in touch if you have
15 | feature-requests not listed here.
16 |
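17 | Regarding limitation 2, the following is a minimal usage sketch of the workaround
18 | (assuming :code:`model` is a standard :code:`torch.nn.Module`; see the API reference
19 | for the definitive signatures):
20 |
21 | .. code-block:: python
22 |
23 |     import unit_scaling.transforms as ust
24 |
25 |     model = ust.unit_scale(model)  # apply the unit-scaling graph transform
26 |     model = ust.compile(model)     # use in place of torch.compile
27 |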
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | :: Copyright (c) 2022 Read the Docs Inc. All rights reserved.
2 | @ECHO OFF
3 |
4 | pushd %~dp0
5 |
6 | REM Command file for Sphinx documentation
7 |
8 | if "%SPHINXBUILD%" == "" (
9 | set SPHINXBUILD=sphinx-build
10 | )
11 | set SOURCEDIR=.
12 | set BUILDDIR=_build
13 |
14 | %SPHINXBUILD% >NUL 2>NUL
15 | if errorlevel 9009 (
16 | echo.
17 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
18 | echo.installed, then set the SPHINXBUILD environment variable to point
19 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
20 | echo.may add the Sphinx directory to PATH.
21 | echo.
22 | echo.If you don't have Sphinx installed, grab it from
23 | echo.https://www.sphinx-doc.org/
24 | exit /b 1
25 | )
26 |
27 | if "%1" == "" goto help
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/posts/almost_scaled_dot_product_attention.md:
--------------------------------------------------------------------------------
1 | # Almost-scaled dot-product attention
2 |
3 | **This post [has moved](https://graphcore-research.github.io/posts/almost_scaled/)**.
4 |
5 | Note that the approach and equations described in this post are legacy and do not reflect the current implementation of u-μP. Please see the code for a definitive reference.
6 |
--------------------------------------------------------------------------------
/docs/u-muP_slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graphcore-research/unit-scaling/215ba6395230c375156aa139dcd4db950ad19439/docs/u-muP_slides.pdf
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
2 |
--------------------------------------------------------------------------------
/examples/scale_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import torch
3 |
4 | import unit_scaling as uu
5 | from unit_scaling.utils import analyse_module
6 |
7 | print("=== Unit-scaled Linear ===\n")
8 |
9 | batch_size = 2**8
10 | hidden_size = 2**10
11 | out_size = 2**10
12 | input = torch.randn(batch_size, hidden_size).requires_grad_()
13 | backward = torch.randn(batch_size, out_size)
14 |
15 | annotated_code = analyse_module(uu.Linear(hidden_size, out_size), input, backward)
16 | print(annotated_code)
17 |
18 | print("=== Unit-scaled MLP ===\n")
19 |
20 | batch_size = 2**8
21 | hidden_size = 2**10
22 | input = torch.randn(batch_size, hidden_size).requires_grad_()
23 | backward = torch.randn(batch_size, hidden_size)
24 |
25 | annotated_code = analyse_module(uu.MLP(hidden_size), input, backward)
26 | print(annotated_code)
27 |
28 | print("=== Unit-scaled MHSA ===\n")
29 |
30 | batch_size = 2**8
31 | seq_len = 2**6
32 | hidden_size = 2**6
33 | heads = 4
34 | dropout_p = 0.1
35 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
36 | backward = torch.randn(batch_size, seq_len, hidden_size)
37 |
38 | annotated_code = analyse_module(
39 | uu.MHSA(hidden_size, heads, is_causal=False, dropout_p=dropout_p), input, backward
40 | )
41 | print(annotated_code)
42 |
43 | print("=== Unit-scaled Transformer Layer ===\n")
44 |
45 | batch_size = 2**8
46 | seq_len = 2**6
47 | hidden_size = 2**6
48 | heads = 4
49 | dropout_p = 0.1
50 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
51 | backward = torch.randn(batch_size, seq_len, hidden_size)
52 |
53 | annotated_code = analyse_module(
54 | uu.TransformerLayer(
55 | hidden_size,
56 | heads,
57 | mhsa_tau=0.1,
58 | mlp_tau=1.0,
59 | is_causal=False,
60 | dropout_p=dropout_p,
61 | ),
62 | input,
63 | backward,
64 | )
65 | print(annotated_code)
66 |
67 | print("=== Unit-scaled Full Transformer Decoder ===\n")
68 |
69 | batch_size = 2**8
70 | seq_len = 2**6
71 | hidden_size = 2**6
72 | vocab_size = 2**12
73 | layers = 2
74 | heads = 4
75 | dropout_p = 0.1
76 |
77 | seq = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len + 1))
78 | input_idxs = seq[:, :-1]
79 | labels = torch.roll(seq, -1, 1)[:, :-1]
80 |
81 | annotated_code = analyse_module(
82 | uu.TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p),
83 | (input_idxs, labels),
84 | )
85 | print(annotated_code)
86 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Configuration inspired by official pypa example:
2 | # https://github.com/pypa/sampleproject/blob/main/pyproject.toml
3 |
4 | [build-system]
5 | requires = ["setuptools>=68.2.2", "setuptools-scm"]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "unit-scaling"
10 | description = "A library for unit scaling in PyTorch, based on the paper 'u-muP: The Unit-Scaled Maximal Update Parametrization.'"
11 | readme = "README.md"
12 | authors = [
13 | { name = "Charlie Blake", email = "charlieb@graphcore.ai" },
14 | { name = "Douglas Orr", email = "douglaso@graphcore.ai" },
15 | ]
16 | requires-python = ">=3.9"
17 | classifiers = [
18 | "Development Status :: 4 - Beta",
19 | "Intended Audience :: Developers",
20 | "Intended Audience :: Science/Research",
21 | "License :: OSI Approved :: Apache Software License",
22 | "Programming Language :: Python :: 3.9",
23 | "Programming Language :: Python :: 3.10",
24 | "Programming Language :: Python :: 3.11",
25 | "Programming Language :: Python :: 3.12",
26 | "Programming Language :: Python :: 3.13",
27 | "Programming Language :: Python :: 3.14",
28 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
29 | ]
30 | dependencies = [
31 | "datasets",
32 | "docstring-parser",
33 | "einops",
34 | "numpy<2.0.0",
35 | "seaborn",
36 | "tabulate",
37 | "torch>=2.2",
38 | ]
39 | dynamic = ["version"]
40 |
41 | [project.urls]
42 | "Homepage" = "https://github.com/graphcore-research/unit-scaling/#readme"
43 | "Bug Reports" = "https://github.com/graphcore-research/unit-scaling/issues"
44 | "Source" = "https://github.com/graphcore-research/unit-scaling/"
45 |
46 | [project.optional-dependencies]
47 | dev = ["check-manifest"]
48 | test = ["pytest"]
49 |
50 | [tool.setuptools]
51 | packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"]
52 |
53 | [tool.setuptools.dynamic]
54 | version = {attr = "unit_scaling._version.__version__"}
55 |
56 | [tool.setuptools_scm]
57 | version_file = "unit_scaling/_version.py"
58 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # Look in pytorch-cpu first, then pypi second
2 | --index-url https://download.pytorch.org/whl/cpu
3 | --extra-index-url=https://pypi.org/simple
4 |
5 | # Same as pyproject.toml, but with versions locked-in
6 | datasets==3.1.0
7 | docstring-parser==0.16
8 | einops==0.8.0
9 | numpy==1.26.4
10 | seaborn==0.13.2
11 | tabulate==0.9.0
12 | torch==2.5.1+cpu
13 |
14 | # Additional dev requirements
15 | black==24.10.0
16 | flake8==7.1.1
17 | isort==5.13.2
18 | mypy==1.13.0
19 | myst-parser==4.0.0
20 | pandas-stubs==2.2.3.241009
21 | pytest==8.3.3
22 | pytest-cov==6.0.0
23 | setuptools==70.0.0
24 | sphinx==8.1.3
25 | sphinx-rtd-theme==3.0.1
26 | transformers==4.46.1
27 | triton==3.1.0
28 | types-Pygments==2.18.0.20240506
29 | types-tabulate==0.9.0.20240106
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [mypy]
2 | pretty = true
3 | show_error_codes = true
4 | strict = true
5 | check_untyped_defs = true
6 |
7 | # As torch.fx doesn't explicitly export many of its useful modules.
8 | [mypy-torch.fx]
9 | implicit_reexport = True
10 |
11 | [flake8]
12 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html
13 | max-line-length = 88
14 | extend-ignore = E203,E731
15 |
16 | [isort]
17 | profile = black
18 |
19 | [tool:pytest]
20 | addopts = --no-cov-on-fail
21 |
22 | [coverage:report]
23 | # fail_under = 100
24 | skip_covered = true
25 | show_missing = true
26 | exclude_lines =
27 | pragma: no cover
28 | raise NotImplementedError
29 | assert False
30 |
--------------------------------------------------------------------------------
/unit_scaling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Unit-scaled versions of common torch.nn modules."""
4 |
5 | # This all has to be done manually to keep mypy happy.
6 | # Removing the `--no-implicit-reexport` option ought to fix this, but doesn't appear to.
7 |
8 | from . import core, functional, optim, parameter
9 | from ._modules import (
10 | GELU,
11 | MHSA,
12 | MLP,
13 | Conv1d,
14 | CrossEntropyLoss,
15 | DepthModuleList,
16 | DepthSequential,
17 | Dropout,
18 | Embedding,
19 | LayerNorm,
20 | Linear,
21 | LinearReadout,
22 | RMSNorm,
23 | SiLU,
24 | Softmax,
25 | TransformerDecoder,
26 | TransformerLayer,
27 | )
28 | from ._version import __version__
29 | from .analysis import visualiser
30 | from .core.functional import transformer_residual_scaling_rule
31 | from .parameter import MupType, Parameter
32 |
33 | __all__ = [
34 | # Modules
35 | "Conv1d",
36 | "CrossEntropyLoss",
37 | "DepthModuleList",
38 | "DepthSequential",
39 | "Dropout",
40 | "Embedding",
41 | "GELU",
42 | "LayerNorm",
43 | "Linear",
44 | "LinearReadout",
45 | "MHSA",
46 | "MLP",
47 | "MupType",
48 | "RMSNorm",
49 | "SiLU",
50 | "Softmax",
51 | "TransformerDecoder",
52 | "TransformerLayer",
53 | # Submodules
54 | "core",
55 | "functional",
56 | "optim",
57 | "parameter",
58 | # Functions
59 | "Parameter",
60 | "transformer_residual_scaling_rule",
61 | "visualiser",
62 | "__version__",
63 | ]
64 |
--------------------------------------------------------------------------------
/unit_scaling/_internal_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import inspect
4 | import sys
5 | from typing import List
6 |
7 |
8 | def generate__all__(module_name: str, include_imports: bool = False) -> List[str]:
9 | """Generates the contents of __all__ by extracting every public function/class/etc.
10 | except those imported from other modules. Necessary for Sphinx docs."""
11 | module = sys.modules[module_name]
12 | all = []
13 | for name, member in inspect.getmembers(module):
14 | # Skip members imported from other modules and private members
15 | is_local = inspect.getmodule(member) == module
16 | if (include_imports or is_local) and not name.startswith("_"):
17 | all.append(name)
18 | return all
19 |
--------------------------------------------------------------------------------
/unit_scaling/constraints.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Common scale-constraints used in unit-scaled operations."""
4 |
5 | from __future__ import annotations # required for docs to alias type annotations
6 |
7 | import sys
8 | from math import pow, prod
9 | from typing import Optional, Tuple
10 |
11 | from ._internal_utils import generate__all__
12 |
13 |
14 | def gmean(*scales: float) -> float:
15 | """Computes the geometric mean of the provided scales. Recommended for unit scaling.
16 |
17 | Args:
18 | scales (*float): the group of constrained scales
19 |
20 | Returns:
21 | float: the geometric mean.
22 | """
23 | return pow(prod(scales), (1 / len(scales)))
24 |
25 |
26 | def hmean(*scales: float) -> float:
27 | """Computes the harmonic mean of the provided scales. Used in Xavier/Glorot scaling.
28 |
29 | Args:
30 | scales (*float): the group of constrained scales
31 |
32 | Returns:
33 | float: the harmonic mean.
34 | """
35 | return 1 / (sum(1 / s for s in scales) / len(scales))
36 |
37 |
38 | def amean(*scales: float) -> float:
39 | """Computes the arithmetic mean of the provided scales.
40 |
41 | Args:
42 | scales (*float): the group of constrained scales
43 |
44 | Returns:
45 | float: the arithmetic mean.
46 | """
47 | return sum(scales) / len(scales)
48 |
49 |
50 | def to_output_scale(output_scale: float, *grad_input_scales: float) -> float:
51 | """Assumes an output scale is provided and any number of grad input scales:
52 | `(output_scale, *grad_input_scales)`. Selects only `output_scale` as the chosen
53 | scaling factor.
54 |
55 | Args:
56 | output_scale (float): the scale of the op's output
57 | grad_input_scales (*float): the scales of the op's input gradients
58 |
59 | Returns:
60 | float: equal to `output_scale`
61 | """
62 | return output_scale
63 |
64 |
65 | def to_grad_input_scale(output_scale: float, grad_input_scale: float) -> float:
66 | """Assumes two provided scales: `(output_scale, grad_input_scale)`. Selects only
67 | `grad_input_scale` as the chosen scaling factor.
68 |
69 | Args:
70 | output_scale (float): the scale of the op's output
71 | grad_input_scale (float): the scale of the op's input gradient
72 |
73 | Returns:
74 | float: equal to `grad_input_scale`
75 | """
76 | return grad_input_scale
77 |
78 |
79 | def to_left_grad_scale(
80 | output_scale: float, left_grad_scale: float, right_grad_scale: float
81 | ) -> float:
82 | """Assumes three provided scales:
83 | `(output_scale, left_grad_scale, right_grad_scale)`. Selects only `left_grad_scale`
84 | as the chosen scaling factor.
85 |
86 | Args:
87 | output_scale (float): the scale of the op's output
88 | left_grad_scale (float): the scale of the op's left input gradient
89 | right_grad_scale (float): the scale of the op's right input gradient
90 |
91 | Returns:
92 | float: equal to `left_grad_scale`
93 | """
94 | return left_grad_scale
95 |
96 |
97 | def to_right_grad_scale(
98 | output_scale: float, left_grad_scale: float, right_grad_scale: float
99 | ) -> float:
100 | """Assumes three provided scales:
101 | `(output_scale, left_grad_scale, right_grad_scale)`. Selects only `right_grad_scale`
102 | as the chosen scaling factor.
103 |
104 | Args:
105 | output_scale (float): the scale of the op's output
106 | left_grad_scale (float): the scale of the op's left input gradient
107 | right_grad_scale (float): the scale of the op's right input gradient
108 |
109 | Returns:
110 | float: equal to `right_grad_scale`
111 | """
112 | return right_grad_scale
113 |
114 |
115 | def apply_constraint(
116 | constraint_name: Optional[str], *scales: float
117 | ) -> Tuple[float, ...]:
118 | """Retrieves the constraint function corresponding to `constraint_name` and applies
119 | it to the group of scales. This name must be that of one of the functions defined in
120 | this module.
121 |
122 | Args:
123 | constraint_name (Optional[str]): The name of the constraint function to be used.
124 |
125 | Raises:
126 | ValueError: if `constraint_name` is not that of a function in this module.
127 |
128 | Returns:
129 | Tuple[float, ...]: the scales after the constraint has been applied.
130 | """
131 | if constraint_name is None or constraint_name == "":
132 | return scales
133 | constraint = getattr(sys.modules[__name__], constraint_name, None)
134 | if constraint is None:
135 | raise ValueError(
136 | f"Constraint: {constraint_name} is not a valid constraint (see"
137 | " unit_scaling.constraints for available options)."
138 | )
139 | scale = constraint(*scales)
140 | return tuple(scale for _ in scales)
141 |
142 |
143 | __all__ = generate__all__(__name__)
144 |
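145 | # Illustrative sketch (not part of the library API): how these helpers combine the
146 | # forward and backward scales of a binary op into a single shared factor.
147 | #
148 | #     >>> gmean(4.0, 1.0)
149 | #     2.0
150 | #     >>> apply_constraint("gmean", 4.0, 1.0)
151 | #     (2.0, 2.0)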
--------------------------------------------------------------------------------
/unit_scaling/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | """Core components for advanced library usage."""
4 |
5 | from . import functional
6 |
7 | __all__ = ["functional"]
8 |
--------------------------------------------------------------------------------
/unit_scaling/core/functional.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | """Core functionality for implementing `unit_scaling.functional`."""
4 |
5 | import math
6 | from typing import Any, Callable, Optional, Tuple
7 |
8 | from torch import Tensor
9 |
10 | from .._internal_utils import generate__all__
11 | from ..constraints import apply_constraint
12 | from ..docs import binary_constraint_docstring, format_docstring
13 | from ..scale import scale_bwd, scale_fwd
14 |
15 |
16 | @format_docstring(binary_constraint_docstring)
17 | def scale_elementwise(
18 | f: Callable[..., Tensor],
19 | output_scale: float,
20 | grad_input_scale: float,
21 | constraint: Optional[str] = "to_output_scale",
22 | ) -> Callable[..., Tensor]:
23 | """Transforms an element-wise function into a scaled version.
24 |
25 | Args:
26 | f (Callable[..., Tensor]): the element-wise function to be scaled. Should take
27 | as its first input a `Tensor`, followed by `*args, **kwargs`.
28 | output_scale (float): the scale to be applied to the output
29 | grad_input_scale (float): the scale to be applied to the grad of the input
30 | {0}
31 |
32 | Returns:
33 | Callable[..., Tensor]: the scaled function
34 | """
35 | output_scale, grad_input_scale = apply_constraint(
36 | constraint, output_scale, grad_input_scale
37 | )
38 |
39 | def scaled_f(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
40 | input = scale_bwd(input, grad_input_scale)
41 | output = f(input, *args, **kwargs)
42 | return scale_fwd(output, output_scale)
43 |
44 | return scaled_f
45 |
46 |
47 | def logarithmic_interpolation(alpha: float, lower: float, upper: float) -> float:
48 | """Interpolate between lower and upper with logarithmic spacing (constant ratio).
49 |
50 | For example::
51 |
52 | logarithmic_interpolation(alpha=0.0, lower=1/1000, upper=1/10) == 1/1000
53 | logarithmic_interpolation(alpha=0.5, lower=1/1000, upper=1/10) == 1/100
54 | logarithmic_interpolation(alpha=1.0, lower=1/1000, upper=1/10) == 1/10
55 |
56 | Args:
57 | alpha (float): interpolation weight (0=lower, 1=upper)
58 | lower (float): lower limit (alpha=0), must be > 0
59 | upper (float): upper limit (alpha=1), must be > 0
60 |
61 | Returns:
62 | float: interpolated value
63 | """
64 | return math.exp(alpha * math.log(upper) + (1 - alpha) * math.log(lower))
65 |
66 |
67 | def rms(
68 | x: Tensor,
69 | dims: Optional[Tuple[int, ...]] = None,
70 | keepdim: bool = False,
71 | eps: float = 0.0,
72 | ) -> Tensor:
73 | """Compute the RMS :math:`\\sqrt{\\mathrm{mean}(x^2) + \\epsilon}` of a tensor."""
74 | mean = x.float().pow(2).mean(dims, keepdim=keepdim)
75 | if eps:
76 | mean = mean + eps
77 | return mean.sqrt().to(x.dtype)
78 |
79 |
80 | ResidualScalingFn = Callable[[int, int], float]
81 |
82 |
83 | def transformer_residual_scaling_rule(
84 | residual_mult: float = 1.0, residual_attn_ratio: float = 1.0
85 | ) -> ResidualScalingFn:
86 | """Compute the residual tau ratios for the default transformer rule.
87 |
88 | For a transformer stack that starts with embedding, then alternates
89 | between attention and MLP layers, this rule ensures:
90 |
91 | - Every attention layer contributes the same scale.
92 | - Every MLP layer contributes the same scale.
93 | - The ratio of the average (variance) contribution of all attention
94 | and all MLP layers to the embedding layer is `residual_mult`.
95 | - The ratio of Attn to MLP contributions is `residual_attn_ratio`.
96 |
97 | If both hyperparameters are set to 1.0, the total contribution of
98 | embedding, attention and MLP layers are all equal.
99 |
100 | This scheme is described in Appendix G of the u-μP paper.
101 |
102 | Args:
103 | residual_mult (float, optional): contribution of residual layers
104 | (relative to an initial/embedding layer).
105 | residual_attn_ratio (float, optional): contribution of attn
106 | layers relative to FFN layers.
107 |
108 | Returns:
109 | :code:`fn(index, layers) -> tau` : a function for calculating tau
110 | at a given depth.
111 | """
112 | alpha_mlp = residual_mult * (2 / (1 + residual_attn_ratio**2)) ** 0.5
113 | alpha_attn = residual_attn_ratio * alpha_mlp
114 |
115 | def _tau(index: int, layers: int) -> float:
116 | n_attn = (index + 1) // 2
117 | n_mlp = index // 2
118 | tau = (alpha_attn if (index % 2) == 0 else alpha_mlp) / (
119 | layers / 2 + n_attn * alpha_attn**2 + n_mlp * alpha_mlp**2
120 | ) ** 0.5
121 | return tau # type:ignore[no-any-return]
122 |
123 | return _tau
124 |
125 |
126 | __all__ = generate__all__(__name__)
127 |
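128 | # Illustrative sketch (not part of the module): tau values from the default residual
129 | # scaling rule for a four-layer (attn, mlp, attn, mlp) stack. With both hyperparameters
130 | # left at 1.0, tau reduces to (layers / 2 + index) ** -0.5:
131 | #
132 | #     >>> rule = transformer_residual_scaling_rule()
133 | #     >>> [round(rule(index, layers=4), 3) for index in range(4)]
134 | #     [0.707, 0.577, 0.5, 0.447]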
--------------------------------------------------------------------------------
/unit_scaling/docs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import inspect
4 | from functools import wraps
5 | from itertools import zip_longest
6 | from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar
7 |
8 | from docstring_parser.google import (
9 | DEFAULT_SECTIONS,
10 | GoogleParser,
11 | Section,
12 | SectionType,
13 | compose,
14 | )
15 |
16 | T = TypeVar("T")
17 |
18 |
19 | def _validate(
20 | f: Callable[..., T], unsupported_args: Iterable[str] = {}
21 | ) -> Callable[..., T]:
22 | """Wraps the supplied function in a check to ensure its arguments aren't in the
23 | unsupported args list. Unsupported args are by nature optional (they have
24 | a default value). It is assumed this default is valid, but all other values are
25 | invalid."""
26 |
27 | argspec = inspect.getfullargspec(f)
28 |
29 | # argspec.defaults is a tuple of default arguments. These may begin at an offset
30 | # relative to argspec.args due to args without a default. To zip these properly the
31 | # lists are reversed, zipped, and un-reversed, with missing values filled with `...`
32 | rev_args = reversed(argspec.args)
33 | rev_defaults = reversed(argspec.defaults) if argspec.defaults else []
34 | rev_arg_default_pairs = list(zip_longest(rev_args, rev_defaults, fillvalue=...))
35 | default_kwargs = dict(reversed(rev_arg_default_pairs))
36 |
37 | for arg in unsupported_args:
38 | if arg not in default_kwargs:
39 | raise ValueError(f"unsupported arg '{arg}' is not valid.")
40 | if default_kwargs[arg] is ...:
41 | raise ValueError(f"unsupported arg '{arg}' has no default value")
42 |
43 | @wraps(f)
44 | def _validate_args_supported(*args: Any, **kwargs: Any) -> T:
45 | arg_values = dict(zip(argspec.args, args))
46 | full_kwargs = {**arg_values, **kwargs}
47 | for arg_name, arg_value in full_kwargs.items():
48 | if arg_name in unsupported_args:
49 | arg_default_value = default_kwargs[arg_name]
50 | if arg_value != arg_default_value:
51 | raise ValueError(
52 | f"Support for the '{arg_name}' argument has not been"
53 | " implemented for the unit-scaling library."
54 | " Please remove it or replace it with its default value."
55 | )
56 | return f(*args, **kwargs)
57 |
58 | return _validate_args_supported
59 |
60 |
61 | def _get_docstring_from_target(
62 | source: T,
63 | target: Any,
64 | short_description: Optional[str] = None,
65 | add_args: Optional[List[str]] = None,
66 | unsupported_args: Iterable[str] = {},
67 | ) -> T:
68 | """Takes the docstring from `target`, modifies it, and applies it to `source`."""
69 |
70 | # Make the parser aware of the Shape and Examples sections (standard in torch docs)
71 | parser_sections = DEFAULT_SECTIONS + [
72 | Section("Shape", "shape", SectionType.SINGULAR),
73 | Section("Examples:", "examples", SectionType.SINGULAR),
74 | ]
75 | parser = GoogleParser(sections=parser_sections)
76 | docstring = parser.parse(target.__doc__)
77 | docstring.short_description = short_description
78 | if docstring.long_description:
79 | docstring.long_description += "\n" # fixes "Args:" section merging
80 |
81 | for param in docstring.params:
82 | if param.arg_name in unsupported_args and param.description is not None:
83 | param.description = (
84 | "**[not supported by unit-scaling]** " + param.description
85 | )
86 |
87 | if add_args:
88 | for arg_str in add_args:
89 | # Parse the additional args strings and add them to the docstring object
90 | param_meta = parser._build_meta(arg_str, "Args")
91 | docstring.meta.append(param_meta)
92 |
93 | source.__doc__ = compose(docstring) # docstring object to actual string
94 | return source
95 |
96 |
97 | def inherit_docstring(
98 | short_description: Optional[str] = None,
99 | add_args: Optional[List[str]] = None,
100 | unsupported_args: Iterable[str] = {},
101 | ) -> Callable[[Type[T]], Type[T]]:
102 | """Returns a decorator which causes the wrapped class to inherit its parent
103 | docstring, with the specified modifications applied.
104 |
105 | Args:
106 | short_description (Optional[str], optional): Replaces the top one-line
107 | description in the parent docstring with the one supplied. Defaults to None.
108 | add_args (Optional[List[str]], optional): Appends the supplied argument strings
109 | to the list of arguments. Defaults to None.
110 | unsupported_args (Iterable[str]): A list of arguments which are not supported.
111 | Documentation is updated and runtime checks added to enforce this.
112 |
113 | Returns:
114 | Callable[[Type], Type]: The decorator used to wrap the child class.
115 | """
116 |
117 | def decorator(cls: Type[T]) -> Type[T]:
118 | parent = cls.mro()[1]
119 | source = _get_docstring_from_target(
120 | source=cls,
121 | target=parent,
122 | short_description=short_description,
123 | add_args=add_args,
124 | unsupported_args=unsupported_args,
125 | )
126 | source.__init__ = _validate(source.__init__, unsupported_args) # type: ignore
127 | return source
128 |
129 | return decorator
130 |
131 |
132 | def docstring_from(
133 | target: Callable[..., T],
134 | short_description: Optional[str] = None,
135 | add_args: Optional[List[str]] = None,
136 | unsupported_args: Iterable[str] = {},
137 | ) -> Callable[[Callable[..., T]], Callable[..., T]]:
138 | """Returns a decorator which causes the wrapped object to take the docstring from
139 | the target object, with the specified modifications applied.
140 |
141 | Args:
142 | target (Any): The object to take the docstring from.
143 | short_description (Optional[str], optional): Replaces the top one-line
144 | description in the parent docstring with the one supplied. Defaults to None.
145 | add_args (Optional[List[str]], optional): Appends the supplied argument strings
146 | to the list of arguments. Defaults to None.
147 | unsupported_args (Iterable[str]): A list of arguments which are not supported.
148 | Documentation is updated and runtime checks added to enforce this.
149 |
150 | Returns:
151 | Callable[[Callable], Callable]: The decorator used to wrap the child object.
152 | """
153 |
154 | def decorator(source: Callable[..., T]) -> Callable[..., T]:
155 | source = _get_docstring_from_target(
156 | source=source,
157 | target=target,
158 | short_description=short_description,
159 | add_args=add_args,
160 | unsupported_args=unsupported_args,
161 | )
162 | return _validate(source, unsupported_args)
163 |
164 | return decorator
165 |
166 |
167 | def format_docstring(*args: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
168 | """Returns a decorator that applies `cls.__doc__.format(*args)` to the target class.
169 |
170 | Args:
171 | args (*str): The arguments to be passed to the docstring's `.format()` method.
172 |
173 | Returns:
174 | Callable[[Type], Type]: A decorator to format the docstring.
175 | """
176 |
177 | def f(cls: T) -> T:
178 | if isinstance(cls.__doc__, str):
179 | cls.__doc__ = cls.__doc__.format(*args)
180 | return cls
181 |
182 | return f
183 |
184 |
185 | binary_constraint_docstring = (
186 | "constraint (Optional[str], optional): The name of the constraint function to be"
187 | " applied to the outputs & input gradient. In this case, the constraint name must"
188 | " be one of:"
189 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale', 'to_grad_input_scale']"
190 | " (see `unit_scaling.constraints` for details on these constraint functions)."
191 | " Defaults to `gmean`."
192 | )
193 |
194 | ternary_constraint_docstring = (
195 | "constraint (Optional[str], optional): The name of the constraint function to be"
196 | " applied to the outputs & input gradients. In this case, the constraint name must"
197 | " be one of:"
198 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale', 'to_left_grad_scale',"
199 | " to_right_grad_scale]"
200 | " (see `unit_scaling.constraints` for details on these constraint functions)."
201 | " Defaults to `gmean`."
202 | )
203 |
204 | variadic_constraint_docstring = (
205 | "constraint (Optional[str], optional): The name of the constraint function to be"
206 | " applied to the outputs & input gradients. In this case, the constraint name must"
207 | " be one of:"
208 | " [None, 'gmean', 'hmean', 'amean', 'to_output_scale']"
209 | " (see `unit_scaling.constraints` for details on these constraint functions)."
210 | " Defaults to `gmean`."
211 | )
212 |
213 |
214 | def mult_docstring(name: str = "mult") -> str:
215 | return (
216 | f"{name} (float, optional): a multiplier to be applied to change the shape"
217 | " of a nonlinear function. Typically, high multipliers (> 1) correspond to a"
218 | " 'sharper' (low temperature) function, while low multipliers (< 1) correspond"
219 | " to a 'flatter' (high temperature) function."
220 | )
221 |
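222 | # Illustrative sketch (hypothetical class, not this library's actual definition): how
223 | # these helpers are typically combined when wrapping a torch module.
224 | #
225 | #     @inherit_docstring(
226 | #         short_description="A unit-scaled version of torch.nn.Dropout.",
227 | #         unsupported_args=["inplace"],  # `inplace` has a default in the parent class
228 | #     )
229 | #     class Dropout(torch.nn.Dropout):
230 | #         ...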
--------------------------------------------------------------------------------
/unit_scaling/formats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Classes for simulating (non-standard) number formats."""
4 |
5 | from dataclasses import dataclass
6 | from typing import Tuple, cast
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from ._internal_utils import generate__all__
12 |
13 | Shape = Tuple[int, ...]
14 |
15 |
16 | @dataclass
17 | class FPFormat:
18 | """Generic representation of a floating-point number format."""
19 |
20 | exponent_bits: int
21 | mantissa_bits: int
22 | rounding: str = "stochastic" # "stochastic|nearest"
23 | srbits: int = 0 # Number of bits for stochastic rounding, zero => use all bits
24 |
25 | def __post_init__(self) -> None:
26 | assert self.exponent_bits >= 2, "FPFormat requires at least 2 exponent bits"
27 | assert (
28 | self.srbits == 0 or self.rounding == "stochastic"
29 | ), "Nonzero srbits for non-stochastic rounding"
30 | if self.srbits == 0 and self.rounding == "stochastic":
31 | self.srbits = 23 - self.mantissa_bits
32 |
33 | @property
34 | def bits(self) -> int:
35 | """The number of bits used by the format."""
36 | return 1 + self.exponent_bits + self.mantissa_bits
37 |
38 | def __str__(self) -> str: # pragma: no cover
39 | return (
40 | f"E{self.exponent_bits}M{self.mantissa_bits}-"
41 | + dict(stochastic="SR", nearest="RN")[self.rounding]
42 | )
43 |
44 | @property
45 | def max_absolute_value(self) -> float:
46 | """The maximum absolute value representable by the format."""
47 | max_exponent = 2 ** (self.exponent_bits - 1) - 1
48 | return cast(float, 2**max_exponent * (2 - 2**-self.mantissa_bits))
49 |
50 | @property
51 | def min_absolute_normal(self) -> float:
52 | """The minimum absolute normal value representable by the format."""
53 | min_exponent = 1 - 2 ** (self.exponent_bits - 1)
54 | return cast(float, 2**min_exponent)
55 |
56 | @property
57 | def min_absolute_subnormal(self) -> float:
58 | """The minimum absolute subnormal value representable by the format."""
59 | return self.min_absolute_normal * 2.0**-self.mantissa_bits
60 |
61 | def quantise(self, x: Tensor) -> Tensor:
62 | """Non-differentiably quantise the given tensor in this format."""
63 | absmax = self.max_absolute_value
64 | downscale = 2.0 ** (127 - 2 ** (self.exponent_bits - 1))
65 | mask = torch.tensor(2 ** (23 - self.mantissa_bits) - 1, device=x.device)
66 | if self.rounding == "stochastic":
67 | srbitsbar = 23 - self.mantissa_bits - self.srbits
68 | offset = (
69 | torch.randint(
70 | 0, 2**self.srbits, x.shape, dtype=torch.int32, device=x.device
71 | )
72 | << srbitsbar
73 | )
74 | # Correct for bias. We can do this only for srbits < 23-mantissa_bits,
75 | # but it is only likely to matter when srbits is small.
76 | if srbitsbar > 0:
77 | offset += 1 << (srbitsbar - 1)
78 |
79 | elif self.rounding == "nearest":
80 | offset = mask // 2
81 | else: # pragma: no cover
82 | raise ValueError(
83 | f'Unexpected FPFormat(rounding="{self.rounding}"),'
84 | ' expected "stochastic" or "nearest"'
85 | )
86 | q = x.to(torch.float32)
87 | q = torch.clip(q, -absmax, absmax)
88 | q /= downscale
89 | q = ((q.view(torch.int32) + offset) & ~mask).view(torch.float32)
90 | q *= downscale
91 | return q.to(x.dtype)
92 |
93 | def quantise_fwd(self, x: Tensor) -> Tensor:
94 | """Quantise the given tensor in the forward pass only."""
95 |
96 | class QuantiseForward(torch.autograd.Function):
97 | @staticmethod
98 | def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
99 | return self.quantise(x)
100 |
101 | @staticmethod
102 | def backward( # type:ignore[override]
103 | ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
104 | ) -> Tensor:
105 | return grad_y
106 |
107 | return QuantiseForward.apply(x) # type: ignore
108 |
109 | def quantise_bwd(self, x: Tensor) -> Tensor:
110 | """Quantise the given tensor in the backward pass only."""
111 |
112 | class QuantiseBackward(torch.autograd.Function):
113 | @staticmethod
114 | def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
115 | return x
116 |
117 | @staticmethod
118 | def backward( # type:ignore[override]
119 | ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
120 | ) -> Tensor:
121 | return self.quantise(grad_y)
122 |
123 | return QuantiseBackward.apply(x) # type: ignore
124 |
125 |
126 | def format_to_tuple(format: FPFormat) -> Tuple[int, int]:
127 | """Convert the format into a tuple of `(exponent_bits, mantissa_bits)`"""
128 | return (format.exponent_bits, format.mantissa_bits)
129 |
130 |
131 | def tuple_to_format(t: Tuple[int, int]) -> FPFormat:
132 | """Given a tuple of `(exponent_bits, mantissa_bits)` returns the corresponding
133 | :class:`FPFormat`"""
134 | return FPFormat(*t)
135 |
136 |
137 | __all__ = generate__all__(__name__)
138 |
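139 | # Illustrative sketch (not part of the module): simulating FP8 E4M3 round-to-nearest on
140 | # activations in the forward pass only, leaving the gradient untouched.
141 | #
142 | #     fmt = FPFormat(exponent_bits=4, mantissa_bits=3, rounding="nearest")
143 | #     y = fmt.quantise_fwd(x)  # `x` is an assumed float32 activation tensor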
--------------------------------------------------------------------------------
/unit_scaling/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | """Optimizer wrappers that apply scaling rules for u-muP.
4 |
5 | Provides :class:`Adam`, :class:`AdamW`, :class:`SGD` as out-of-the-box
6 | optimizers.
7 |
8 | Alternatively, :func:`scaled_parameters` provides finer control by
9 | transforming a parameter group for any downstream optimizer, given a
10 | function that defines the LR scaling rules.
11 | """
12 |
13 | # mypy: disable-error-code="no-any-return"
14 |
15 | from typing import Any, Callable, Optional, Union
16 |
17 | import torch
18 | from torch import Tensor
19 | from torch.optim.optimizer import ParamsT
20 |
21 | from ._internal_utils import generate__all__
22 | from .docs import inherit_docstring
23 | from .parameter import ParameterData, has_parameter_data
24 |
25 |
26 | def lr_scale_for_depth(param: ParameterData) -> float:
27 | """Calculate the LR scaling factor for depth only."""
28 | if param.mup_scaling_depth is None:
29 | return 1
30 | return param.mup_scaling_depth**-0.5
31 |
32 |
33 | def _get_fan_in(param: ParameterData) -> int:
34 | # Note: the "fan_in" of an embedding layer is the hidden (output) dimension
35 | if len(param.shape) == 1:
36 | return param.shape[0]
37 | if len(param.shape) == 2:
38 | return param.shape[1]
39 | if len(param.shape) == 3:
40 | return param.shape[1] * param.shape[2]
41 | raise ValueError(
42 | f"Cannot get fan_in of `ndim >= 4` param, shape={tuple(param.shape)}"
43 | )
44 |
45 |
46 | def lr_scale_func_sgd(
47 | readout_constraint: Optional[str],
48 | ) -> Callable[[ParameterData], float]:
49 | """Calculate the LR scaling factor for :class:`torch.optim.SGD`."""
50 |
51 | if readout_constraint is None:
52 | # If there is no readout constraint we will have unit-scaled gradients and hence
53 | # unit-scaled weight updates. In this case the scaling rules are the same as
54 | # for Adam, which naturally has unit-scaled weight updates.
55 | return lr_scale_func_adam
56 | elif readout_constraint == "to_output_scale":
57 |
58 | def lr_scale_func_sgd_inner(param: ParameterData) -> float:
59 | scale = lr_scale_for_depth(param)
60 |
61 | if param.mup_type in ("bias", "norm"):
62 | return scale * param.shape[0]
63 | if param.mup_type == "weight":
64 | return scale * _get_fan_in(param) ** 0.5
65 | if param.mup_type == "output":
66 | return scale
67 | assert False, f"Unexpected mup_type {param.mup_type}"
68 |
69 | return lr_scale_func_sgd_inner
70 | else:
71 | assert False, f"Unhandled readout constraint: {readout_constraint}"
72 |
73 |
74 | def lr_scale_func_adam(param: ParameterData) -> float:
75 | """Calculate the LR scaling factor for :class:`torch.optim.Adam`
76 | and :class:`torch.optim.AdamW`.
77 | """
78 | scale = lr_scale_for_depth(param)
79 | if param.mup_type in ("bias", "norm"):
80 | return scale
81 | if param.mup_type == "weight":
82 | return scale * _get_fan_in(param) ** -0.5
83 | if param.mup_type == "output":
84 | return scale
85 | assert False, f"Unexpected mup_type {param.mup_type}"
86 |
87 |
88 | def scaled_parameters(
89 | params: ParamsT,
90 | lr_scale_func: Callable[[ParameterData], float],
91 | lr: Union[None, float, Tensor] = None,
92 | weight_decay: float = 0,
93 | independent_weight_decay: bool = True,
94 | allow_non_unit_scaling_params: bool = False,
95 | ) -> ParamsT:
96 | """Create optimizer-appropriate **lr-scaled** parameter groups.
97 |
98 | This method creates param_groups that apply the relevant scaling factors for u-muP
99 | models. For example::
100 |
101 | torch.optim.Adam(uu.optim.scaled_parameters(
102 | model.parameters(), uu.optim.lr_scale_func_adam, lr=1.0
103 | ))
104 |
105 | Args:
106 | params (ParamsT): an iterable of parameters or parameter groups, as passed to
107 | a torch optimizer.
108 | lr_scale_func (Callable): gets the optimizer-appropriate learning rate scale,
109 | based on a parameter tagged with `mup_type` and `mup_scaling_depth`. For
110 | example, :func:`lr_scale_func_sgd`.
111 | lr (float, optional): global learning rate (overridden by groups).
112 | weight_decay (float, optional): weight decay value (overridden by groups).
113 | independent_weight_decay (bool, optional): enable lr-independent weight decay,
114 | which performs an update per-step that does not depend on lr.
115 | allow_non_unit_scaling_params (bool, optional): by default, this method fails
116 | if passed any regular non-unit-scaled params; set to `True` to disable this
117 | check.
118 |
119 | Returns:
120 | ParamsT: for passing on to the optimizer.
121 | """
122 |
123 | result = []
124 | for entry in params:
125 | group = dict(params=[entry]) if isinstance(entry, Tensor) else entry.copy()
126 | group.setdefault("lr", lr) # type: ignore[arg-type]
127 | group.setdefault("weight_decay", weight_decay) # type: ignore[arg-type]
128 | if group["lr"] is None:
129 | raise ValueError(
130 | "scaled_params() requires lr to be provided,"
131 | " unless passing parameter groups which already have an lr"
132 | )
133 | for param in group["params"]:
134 | # Careful not to overwrite `lr` or `weight_decay`
135 | param_lr = group["lr"]
136 | if has_parameter_data(param): # type: ignore[arg-type]
137 | if isinstance(param_lr, Tensor):
138 | param_lr = param_lr.clone()
139 | param_lr *= lr_scale_func(param) # type: ignore[operator]
140 | elif not allow_non_unit_scaling_params:
141 | raise ValueError(
142 | "Non-unit-scaling parameter (no mup_type),"
143 | f" shape {tuple(param.shape)}"
144 | )
145 | param_weight_decay = group["weight_decay"]
146 | if independent_weight_decay:
147 | # Note: only independent of peak LR, not of schedule
148 | param_weight_decay /= float(param_lr) # type: ignore
149 |
150 | result.append(
151 | dict(
152 | params=[param],
153 | lr=param_lr,
154 | weight_decay=param_weight_decay,
155 | **{
156 | k: v
157 | for k, v in group.items()
158 | if k not in ("params", "lr", "weight_decay")
159 | },
160 | )
161 | )
162 | return result
163 |
164 |
165 | @inherit_docstring(
166 | short_description="An **lr-scaled** version of :class:`torch.optim.SGD` for u-muP."
167 | "`readout_constraint` should match the `constraint` arg used in `LinearReadout`."
168 | )
169 | class SGD(torch.optim.SGD):
170 |
171 | def __init__(
172 | self,
173 | params: ParamsT,
174 | lr: Union[float, Tensor] = 1e-3,
175 | *args: Any,
176 | weight_decay: float = 0,
177 | independent_weight_decay: bool = True,
178 | allow_non_unit_scaling_params: bool = False,
179 | readout_constraint: Optional[str] = None,
180 | **kwargs: Any,
181 | ) -> None:
182 | params = scaled_parameters(
183 | params,
184 | lr_scale_func_sgd(readout_constraint),
185 | lr=lr,
186 | weight_decay=weight_decay,
187 | independent_weight_decay=independent_weight_decay,
188 | allow_non_unit_scaling_params=allow_non_unit_scaling_params,
189 | )
190 | # No need to forward {lr, weight_decay}, as each group has these specified
191 | super().__init__(params, *args, **kwargs)
192 |
193 |
194 | @inherit_docstring(
195 | short_description="An **lr-scaled** version of :class:`torch.optim.Adam` for u-muP."
196 | )
197 | class Adam(torch.optim.Adam):
198 | def __init__(
199 | self,
200 | params: ParamsT,
201 | lr: Union[float, Tensor] = 1e-3,
202 | *args: Any,
203 | weight_decay: float = 0,
204 | independent_weight_decay: bool = True,
205 | allow_non_unit_scaling_params: bool = False,
206 | **kwargs: Any,
207 | ) -> None:
208 | params = scaled_parameters(
209 | params,
210 | lr_scale_func_adam,
211 | lr=lr,
212 | weight_decay=weight_decay,
213 | independent_weight_decay=independent_weight_decay,
214 | allow_non_unit_scaling_params=allow_non_unit_scaling_params,
215 | )
216 | # No need to forward {lr, weight_decay}, as each group has these specified
217 | super().__init__(params, *args, **kwargs)
218 |
219 |
220 | @inherit_docstring(
221 | short_description=(
222 | "An **lr-scaled** version of :class:`torch.optim.AdamW` for u-muP."
223 | )
224 | )
225 | class AdamW(torch.optim.AdamW):
226 | def __init__(
227 | self,
228 | params: ParamsT,
229 | lr: Union[float, Tensor] = 1e-3,
230 | *args: Any,
231 | weight_decay: float = 0,
232 | independent_weight_decay: bool = True,
233 | allow_non_unit_scaling_params: bool = False,
234 | **kwargs: Any,
235 | ) -> None:
236 | params = scaled_parameters(
237 | params,
238 | lr_scale_func_adam,
239 | lr=lr,
240 | weight_decay=weight_decay,
241 | independent_weight_decay=independent_weight_decay,
242 | allow_non_unit_scaling_params=allow_non_unit_scaling_params,
243 | )
244 | # No need to forward {lr, weight_decay}, as each group has these specified
245 | super().__init__(params, *args, **kwargs)
246 |
247 |
248 | __all__ = generate__all__(__name__)
249 |
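250 | # Illustrative sketch (assumed usage, mirroring the `scaled_parameters` docstring): the
251 | # wrapper optimizers apply the per-parameter u-muP LR rules internally, so a single
252 | # global learning rate is supplied.
253 | #
254 | #     opt = AdamW(model.parameters(), lr=1.0, weight_decay=1e-4)  # `model` assumed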
--------------------------------------------------------------------------------
/unit_scaling/parameter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | """Extends :class:`torch.nn.Parameter` with attributes for u-μP."""
4 |
5 | # mypy: disable-error-code="attr-defined, method-assign, no-untyped-call"
6 |
7 | from collections import OrderedDict
8 | from typing import Any, Dict, Literal, Optional, Protocol, TypeGuard
9 |
10 | import torch
11 | from torch import Tensor, nn
12 |
13 | MupType = Literal["weight", "bias", "norm", "output"]
14 |
15 |
16 | class ParameterData(Protocol):
17 | """Extra fields for :class:`torch.nn.Parameter`, tagging u-μP metadata.
18 |
19 | Objects supporting this protocol should implicitly also support
20 | :class:`torch.nn.Parameter`.
21 | """
22 |
23 | mup_type: MupType
24 | mup_scaling_depth: Optional[int]
25 | shape: torch.Size # repeated from nn.Parameter, for convenience
26 |
27 |
28 | def has_parameter_data(parameter: nn.Parameter) -> TypeGuard[ParameterData]:
29 | """Check that the parameter supports the :class:`ParameterData` protocol."""
30 | return (
31 | getattr(parameter, "mup_type", None) in MupType.__args__
32 | and hasattr(parameter, "mup_scaling_depth")
33 | and isinstance(parameter.mup_scaling_depth, (type(None), int))
34 | )
35 |
36 |
37 | def _parameter_deepcopy(self: nn.Parameter, memo: Dict[int, Any]) -> nn.Parameter:
38 | result: nn.Parameter = nn.Parameter.__deepcopy__(self, memo)
39 | result.mup_type = self.mup_type
40 | result.mup_scaling_depth = self.mup_scaling_depth
41 | return result
42 |
43 |
44 | def _rebuild_parameter_with_state(*args: Any, **kwargs: Any) -> nn.Parameter:
45 | p: nn.Parameter = torch._utils._rebuild_parameter_with_state(*args, **kwargs)
46 | p.__deepcopy__ = _parameter_deepcopy.__get__(p)
47 | p.__reduce_ex__ = _parameter_reduce_ex.__get__(p)
48 | return p
49 |
50 |
51 | def _parameter_reduce_ex(self: nn.Parameter, protocol: int) -> Any:
52 | # Based on `torch.nn.Parameter.__reduce_ex__`, filtering out the
53 | # dynamic methods __deepcopy__ and __reduce_ex__, as these
54 | # don't unpickle
55 | state = {
56 | k: v
57 | for k, v in torch._utils._get_obj_state(self).items()
58 | if k not in ["__deepcopy__", "__reduce_ex__"]
59 | }
60 | return (
61 | _rebuild_parameter_with_state,
62 | (self.data, self.requires_grad, OrderedDict(), state),
63 | )
64 |
65 |
66 | def Parameter(
67 | data: Tensor, mup_type: MupType, mup_scaling_depth: Optional[int] = None
68 | ) -> nn.Parameter:
69 | """Construct a u-μP parameter object, an annotated :class:`torch.nn.Parameter`.
70 |
71 | The returned parameter also supports the :class:`ParameterData` protocol::
72 |
73 | p = uu.Parameter(torch.zeros(10), mup_type="weight")
74 | assert p.mup_type == "weight"
75 | assert p.mup_scaling_depth is None
76 | """
77 | p = nn.Parameter(data)
78 | p.mup_type = mup_type
79 | p.mup_scaling_depth = mup_scaling_depth
80 | p.__deepcopy__ = _parameter_deepcopy.__get__(p)
81 | p.__reduce_ex__ = _parameter_reduce_ex.__get__(p)
82 | # Note: cannot override __repr__ as it's __class__.__repr__
83 | return p
84 |
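85 | # Illustrative sketch (not part of the module): downstream code (e.g. optimizers) can
86 | # branch on the ParameterData protocol before applying u-μP learning-rate rules.
87 | #
88 | #     p = Parameter(torch.zeros(16, 8), mup_type="weight", mup_scaling_depth=4)
89 | #     assert has_parameter_data(p)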
--------------------------------------------------------------------------------
/unit_scaling/scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Operations to enable different scaling factors in the forward and backward passes."""
4 |
5 | from __future__ import annotations # required for docs to alias type annotations
6 |
7 | from typing import Tuple
8 |
9 | import torch
10 | from torch import Tensor, fx
11 |
12 | from ._internal_utils import generate__all__
13 |
14 |
15 | class _ScaledGrad(torch.autograd.Function): # pragma: no cover
16 | """Enables a custom backward method which has a different scale to forward."""
17 |
18 | @staticmethod
19 | def forward(
20 | ctx: torch.autograd.function.FunctionCtx,
21 | X: Tensor,
22 | fwd_scale: float,
23 | bwd_scale: float,
24 | ) -> Tensor:
25 | # Special cases required for torch.fx tracing
26 | if isinstance(bwd_scale, fx.proxy.Proxy):
27 | ctx.save_for_backward(bwd_scale) # type: ignore
28 | elif isinstance(X, fx.proxy.Proxy):
29 | ctx.save_for_backward(torch.tensor(bwd_scale))
30 | else:
31 | ctx.save_for_backward(torch.tensor(bwd_scale, dtype=X.dtype))
32 | return fwd_scale * X
33 |
34 | @staticmethod
35 | def backward( # type:ignore[override]
36 | ctx: torch.autograd.function.FunctionCtx, grad_Y: Tensor
37 | ) -> Tuple[Tensor, None, None]:
38 | (bwd_scale,) = ctx.saved_tensors # type: ignore
39 | return bwd_scale * grad_Y, None, None
40 |
41 |
42 | def _scale(
43 | t: Tensor, fwd_scale: float = 1.0, bwd_scale: float = 1.0
44 | ) -> Tensor: # pragma: no cover
45 | """Given a tensor, applies a separate scale in the forward and backward pass."""
46 | return _ScaledGrad.apply(t, fwd_scale, bwd_scale) # type: ignore
47 |
48 |
49 | def scale_fwd(input: Tensor, scale: float) -> Tensor:
50 | """Applies a scalar multiplication to a tensor in only the forward pass.
51 |
52 | Args:
53 | input (Tensor): the tensor to be scaled.
54 | scale (float): the scale factor applied to the tensor in the forward pass.
55 |
56 | Returns:
57 | Tensor: scaled in the forward pass, but with its original grad.
58 | """
59 | return _scale(input, fwd_scale=scale)
60 |
61 |
62 | def scale_bwd(input: Tensor, scale: float) -> Tensor:
63 | """Applies a scalar multiplication to a tensor in only the backward pass.
64 |
65 | Args:
66 | input (Tensor): the tensor to be scaled.
67 | scale (float): the scale factor applied to the tensor in the backward pass.
68 |
69 | Returns:
70 | Tensor: unchanged in the forward pass, but with a scaled grad.
71 | """
72 | return _scale(input, bwd_scale=scale)
73 |
74 |
75 | __all__ = generate__all__(__name__)
76 |
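77 | # Illustrative sketch (not part of the module; `x` is an assumed tensor requiring grad):
78 | # the two halves of a scaled op.
79 | #
80 | #     y = scale_fwd(x, 2.0)  # forward: y == 2 * x;  backward: grad_x == grad_y
81 | #     z = scale_bwd(x, 0.5)  # forward: z == x;      backward: grad_x == 0.5 * grad_z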
--------------------------------------------------------------------------------
/unit_scaling/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
--------------------------------------------------------------------------------
/unit_scaling/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import random
4 |
5 | import numpy as np
6 | import pytest
7 | import torch
8 |
9 |
10 | @pytest.fixture(scope="function", autouse=True)
11 | def fix_seed() -> None:
12 | """For each test function, reset all random seeds."""
13 | random.seed(1472)
14 | np.random.seed(1472)
15 | torch.manual_seed(1472)
16 |
--------------------------------------------------------------------------------
/unit_scaling/tests/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
--------------------------------------------------------------------------------
/unit_scaling/tests/core/test_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | import pytest
4 | from torch import ones, randn, tensor
5 | from torch.testing import assert_close
6 |
7 | from ...core.functional import rms, scale_elementwise, transformer_residual_scaling_rule
8 | from ..helper import unit_backward
9 |
10 |
11 | def test_scale_elementwise_no_constraint() -> None:
12 | input = randn(2**10, requires_grad=True)
13 | f = lambda x: x
14 | scaled_f = scale_elementwise(
15 | f, output_scale=2.5, grad_input_scale=0.5, constraint=None
16 | )
17 | output = scaled_f(input)
18 | unit_backward(output)
19 |
20 | assert output.std().detach() == pytest.approx(2.5, rel=0.1)
21 | assert input.grad.std().detach() == pytest.approx(0.5, rel=0.1) # type: ignore
22 |
23 |
24 | def test_scale_elementwise_for_output() -> None:
25 | input = randn(2**10, requires_grad=True)
26 | f = lambda x: x
27 | scaled_f = scale_elementwise(
28 | f, output_scale=2.5, grad_input_scale=0.5, constraint="to_output_scale"
29 | )
30 | output = scaled_f(input)
31 | unit_backward(output)
32 |
33 | assert output.std().detach() == pytest.approx(2.5, rel=0.1)
34 | assert input.grad.std().detach() == pytest.approx(2.5, rel=0.1) # type: ignore
35 |
36 |
37 | def test_scale_elementwise_for_grad_input() -> None:
38 | input = randn(2**10, requires_grad=True)
39 | f = lambda x: x
40 | scaled_f = scale_elementwise(
41 | f, output_scale=2.5, grad_input_scale=0.5, constraint="to_grad_input_scale"
42 | )
43 | output = scaled_f(input)
44 | unit_backward(output)
45 |
46 | assert output.std().detach() == pytest.approx(0.5, rel=0.1)
47 | assert input.grad.std().detach() == pytest.approx(0.5, rel=0.1) # type: ignore
48 |
49 |
50 | def test_rms() -> None:
51 | output = rms(-4 + 3 * randn(2**12))
52 | assert output.item() == pytest.approx(5, rel=0.1)
53 |
54 | x = tensor([[2, -2, -2, -2], [0, 2, 0, 0]]).float()
55 | output = rms(x, dims=(1,))
56 | assert_close(output, tensor([2.0, 1.0]))
57 |
58 | output = rms(tensor([0.0, 0.0, 0.0]), eps=1 / 16)
59 | assert output.item() == pytest.approx(1 / 4, rel=0.1)
60 |
61 |
62 | @pytest.mark.parametrize(
63 | ["residual_mult", "residual_attn_ratio", "layers"],
64 | [
65 | (1.0, 1.0, 4),
66 | (0.5, 1.0, 4),
67 | (1.0, 2.0, 4),
68 | (0.5, 1 / 3, 6),
69 | ],
70 | )
71 | def test_transformer_residual_scaling_rule(
72 | residual_mult: float, residual_attn_ratio: float, layers: int
73 | ) -> None:
74 | scaling_rule = transformer_residual_scaling_rule(
75 | residual_mult=residual_mult,
76 | residual_attn_ratio=residual_attn_ratio,
77 | )
78 | scales = ones(layers + 1)
79 | for n in range(layers):
80 | tau = scaling_rule(n, layers)
81 | scales[n + 1] *= tau
82 | scales[: n + 2] /= (1 + tau**2) ** 0.5
83 | embedding_scale = scales[0]
84 | attn_scales = scales[1:][::2]
85 | mlp_scales = scales[2:][::2]
86 |
87 | # Basic properties
88 | assert_close(scales.pow(2).sum(), tensor(1.0))
89 |
90 | s_embedding = embedding_scale
91 | s_attn = attn_scales.pow(2).sum().sqrt()
92 | s_mlp = mlp_scales.pow(2).sum().sqrt()
93 | s_attn_mlp_average = ((s_attn**2 + s_mlp**2) / 2).sqrt()
94 |
95 | assert_close(s_attn_mlp_average / s_embedding, tensor(residual_mult))
96 | assert_close(s_attn / s_mlp, tensor(residual_attn_ratio))
97 |
98 | # Per-layer scales are equal
99 | assert_close(attn_scales, attn_scales[:1].broadcast_to(attn_scales.shape))
100 | assert_close(mlp_scales, mlp_scales[:1].broadcast_to(mlp_scales.shape))
101 |
--------------------------------------------------------------------------------
/unit_scaling/tests/helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | from typing import Optional
4 |
5 | import pytest
6 | import torch
7 | from torch import Tensor, randn
8 |
9 | from ..core.functional import rms
10 |
11 |
12 | def unit_backward(tensor: Tensor) -> Tensor:
13 | """Applies the `backward()` method with a unit normal tensor as input.
14 |
15 | Args:
16 | tensor (Tensor): tensor to have `backward()` applied.
17 |
18 | Returns:
19 | Tensor: the unit normal gradient tensor fed into `backward()`.
20 | """
21 | gradient = randn(*tensor.shape)
22 | tensor.backward(gradient) # type: ignore
23 | return gradient
24 |
25 |
26 | def assert_scale(
27 | *tensors: Optional[Tensor], target: float, abs: float = 0.1, stat: str = "std"
28 | ) -> None:
29 | for t in tensors:
30 | assert t is not None
31 | t = t.detach()
32 | approx_target = pytest.approx(target, abs=abs)
33 | stat_value = dict(rms=rms, std=torch.std)[stat](t) # type:ignore[operator]
34 | assert (
35 | stat_value == approx_target
36 | ), f"{stat}={stat_value:.3f}, shape={list(t.shape)}"
37 |
38 |
39 | def assert_not_scale(
40 | *tensors: Optional[Tensor], target: float, abs: float = 0.1, stat: str = "std"
41 | ) -> None:
42 | for t in tensors:
43 | assert t is not None
44 | t = t.detach()
45 | approx_target = pytest.approx(target, abs=abs)
46 | stat_value = dict(rms=t.pow(2).mean().sqrt(), std=t.std())[stat]
47 | assert (
48 | stat_value != approx_target
49 | ), f"{stat}={stat_value:.3f}, shape={list(t.shape)}"
50 |
51 |
52 | def assert_unit_scaled(
53 | *tensors: Optional[Tensor], abs: float = 0.1, stat: str = "std"
54 | ) -> None:
55 | return assert_scale(*tensors, target=1.0, abs=abs, stat=stat)
56 |
57 |
58 | def assert_not_unit_scaled(
59 | *tensors: Optional[Tensor], abs: float = 0.1, stat: str = "std"
60 | ) -> None:
61 | return assert_not_scale(*tensors, target=1.0, abs=abs, stat=stat)
62 |
63 |
64 | def assert_zeros(*tensors: Optional[Tensor]) -> None:
65 | for t in tensors:
66 | assert t is not None
67 | t = t.detach()
68 | assert torch.all(t == 0)
69 |
70 |
71 | def assert_non_zeros(*tensors: Optional[Tensor]) -> None:
72 | for t in tensors:
73 | assert t is not None
74 | t = t.detach()
75 | assert torch.any(t != 0)
76 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | from typing import Tuple
4 |
5 | import torch.nn.functional as F
6 | from torch import Size, Tensor, nn, randn
7 | from transformers import AutoTokenizer # type: ignore[import-untyped]
8 |
9 | from ..analysis import _create_batch, _example_seqs, example_batch, plot, visualiser
10 | from ..transforms import track_scales
11 |
12 |
13 | def test_example_seqs() -> None:
14 | batch_size, min_seq_len = 3, 1024
15 | seqs = _example_seqs(batch_size, min_seq_len)
16 | assert len(seqs) == batch_size, len(seqs)
17 | for s in seqs:
18 | assert isinstance(s, str)
19 | assert not s.isspace()
20 | assert len(s) >= min_seq_len
21 |
22 |
23 | def test_create_batch() -> None:
24 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
25 | batch_size, seq_len = 3, 256
26 | seqs = _example_seqs(batch_size, min_seq_len=seq_len * 4)
27 | input_idxs, attn_mask, labels = _create_batch(tokenizer, seqs, seq_len)
28 |
29 | assert isinstance(input_idxs, Tensor)
30 | assert isinstance(attn_mask, Tensor)
31 | assert isinstance(labels, Tensor)
32 | assert input_idxs.shape == Size([batch_size, seq_len])
33 | assert attn_mask.shape == Size([batch_size, seq_len])
34 | assert labels.shape == Size([batch_size, seq_len])
35 |
36 |
37 | def test_example_batch() -> None:
38 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
39 | batch_size, seq_len = 3, 256
40 | input_idxs, attn_mask, labels = example_batch(tokenizer, batch_size, seq_len)
41 |
42 | assert isinstance(input_idxs, Tensor)
43 | assert isinstance(attn_mask, Tensor)
44 | assert isinstance(labels, Tensor)
45 | assert input_idxs.shape == Size([batch_size, seq_len])
46 | assert attn_mask.shape == Size([batch_size, seq_len])
47 | assert labels.shape == Size([batch_size, seq_len])
48 |
49 |
50 | def test_plot() -> None:
51 | class Model(nn.Module):
52 | def __init__(self, dim: int) -> None:
53 | super().__init__()
54 | self.dim = dim
55 | self.linear = nn.Linear(dim, dim // 2)
56 |
57 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover
58 | y = F.relu(x)
59 | z = self.linear(y)
60 | return z.sum() # type: ignore[no-any-return]
61 |
62 | b, dim = 2**4, 2**8
63 | input = randn(b, dim)
64 | model = Model(dim)
65 | model = track_scales(model)
66 | loss = model(input)
67 | loss.backward()
68 |
69 | graph = model.scales_graph()
70 | axes = plot(graph, "demo", xmin=2**-20, xmax=2**10)
71 | assert axes
72 |
73 |
74 | def test_visualiser() -> None:
75 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
76 |
77 | class Model(nn.Module):
78 | def __init__(self, n_embed: int, dim: int) -> None:
79 | super().__init__()
80 | self.embedding = nn.Embedding(n_embed, dim)
81 | self.linear = nn.Linear(dim, n_embed)
82 |
83 | def forward(
84 | self, inputs: Tensor, labels: Tensor
85 | ) -> Tuple[Tensor, Tensor]: # pragma: no cover
86 | x = self.embedding(inputs)
87 | x = self.linear(x)
88 | loss = F.cross_entropy(x.view(-1, x.size(-1)), labels.view(-1))
89 | return x, loss
90 |
91 | axes = visualiser(
92 | model=Model(n_embed=tokenizer.vocab_size, dim=128),
93 | tokenizer=tokenizer,
94 | batch_size=16,
95 | seq_len=256,
96 | )
97 | assert axes
98 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_constraints.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import pytest
4 |
5 | from ..constraints import (
6 | amean,
7 | apply_constraint,
8 | gmean,
9 | hmean,
10 | to_grad_input_scale,
11 | to_output_scale,
12 | )
13 |
14 |
15 | def test_gmean() -> None:
16 | assert gmean(1.0) == 1.0
17 | assert gmean(8.0, 2.0) == 4.0
18 | assert gmean(12.75, 55.0, 0.001) == pytest.approx(0.8884322)
19 |
20 |
21 | def test_hmean() -> None:
22 | assert hmean(1.0) == 1.0
23 | assert hmean(8.0, 2.0) == 3.2
24 | assert hmean(12.75, 55.0, 0.001) == pytest.approx(0.00299971)
25 |
26 |
27 | def test_amean() -> None:
28 | assert amean(1.0) == 1.0
29 | assert amean(8.0, 2.0) == 5.0
30 | assert amean(12.75, 55.0, 0.001) == pytest.approx(22.583667)
31 |
32 |
33 | def test_to_output_scale() -> None:
34 | assert to_output_scale(2, 3) == 2
35 | assert to_output_scale(2, 3, 4) == 2
36 |
37 |
38 | def test_to_grad_input_scale() -> None:
39 | assert to_grad_input_scale(2, 3) == 3
40 |
41 |
42 | def test_apply_constraint() -> None:
43 | assert apply_constraint("gmean", 1.0, 4.0, 2.0) == (2.0, 2.0, 2.0)
44 | assert apply_constraint("hmean", 8.0, 2.0) == (3.2, 3.2)
45 | with pytest.raises(ValueError):
46 | apply_constraint("invalid", 8.0, 2.0)
47 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_docs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import pytest
4 |
5 | from ..docs import _validate
6 |
7 |
8 | def f(a, b: int, c="3", d: float = 4.0) -> str: # type: ignore
9 | return f"{a} {b} {c} {d}"
10 |
11 |
12 | def test_validate_no_args() -> None:
13 | def g() -> int:
14 | return 0
15 |
16 | valid_g = _validate(g)
17 | assert valid_g() == 0
18 |
19 |
20 | def test_validate_positional_args() -> None:
21 | # Works with no unsupported args
22 | valid_f = _validate(f)
23 | assert valid_f(None, 2) == "None 2 3 4.0"
24 | assert valid_f(None, 2, "3", 4.5) == "None 2 3 4.5"
25 |
26 | # Works with some unsupported args
27 | valid_f = _validate(f, unsupported_args=["c", "d"])
28 |
29 | # Works if unsupported args are not present or equal default
30 | assert valid_f(None, 2) == "None 2 3 4.0"
31 | assert valid_f(None, 2, "3", 4.0) == "None 2 3 4.0"
32 |
33 | # Doesn't work if non-default unsupported args provided
34 | with pytest.raises(ValueError) as e:
35 | valid_f(None, 2, "3.4")
36 | assert "argument has not been implemented" in str(e.value)
37 | with pytest.raises(ValueError) as e:
38 | valid_f(None, 2, "3", 4.5)
39 | assert "argument has not been implemented" in str(e.value)
40 |
41 |
42 | def test_validate_positional_kwargs() -> None:
43 | # Works with no unsupported args
44 | valid_f = _validate(f)
45 | assert valid_f(None, 2) == "None 2 3 4.0"
46 | assert valid_f(None, 2, c="3", d=4.5) == "None 2 3 4.5"
47 |
48 | # Works with some unsupported args
49 | valid_f = _validate(f, unsupported_args=["c", "d"])
50 |
51 | # Works if unsupported args are not present or equal default
52 | assert valid_f(None, 2) == "None 2 3 4.0"
53 | assert valid_f(None, 2, c="3")
54 |
55 | # Doesn't work if non-default unsupported args provided
56 | with pytest.raises(ValueError) as e:
57 | valid_f(None, 2, c="3.4")
58 | assert "argument has not been implemented" in str(e.value)
59 | with pytest.raises(ValueError) as e:
60 | valid_f(None, 2, d=4.5)
61 | assert "argument has not been implemented" in str(e.value)
62 |
63 |
64 | def test_validate_invalid_arg() -> None:
65 | with pytest.raises(ValueError) as e:
66 | _validate(f, unsupported_args=["z"])
67 | assert "is not valid" in str(e.value)
68 |
69 |
70 | def test_validate_no_default_arg() -> None:
71 | with pytest.raises(ValueError) as e:
72 | _validate(f, unsupported_args=["b"])
73 | assert "has no default value" in str(e.value)
74 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_formats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import collections
4 |
5 | import torch
6 |
7 | from ..formats import FPFormat
8 |
9 |
10 | def test_fp_format() -> None:
11 | fmt = FPFormat(2, 1)
12 | assert fmt.bits == 4
13 | assert str(fmt) == "E2M1-SR"
14 | assert fmt.max_absolute_value == 3
15 | assert fmt.min_absolute_normal == 0.5
16 | assert fmt.min_absolute_subnormal == 0.25
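    # E2M1 (2 exponent bits, 1 mantissa bit) represents exactly the values
    # +/-{0, 0.25, 0.5, 0.75, 1, 1.5, 2, 3}, so quantisation maps onto this set.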
17 | assert set(fmt.quantise_fwd(torch.linspace(-4, 4, steps=100)).tolist()) == {
18 | sx for x in [0, 0.25, 0.5, 0.75, 1, 1.5, 2, 3] for sx in [x, -x]
19 | }
20 | assert set(
21 | FPFormat(3, 0).quantise_fwd(torch.linspace(-10, 10, steps=1000)).abs().tolist()
22 | ) == {0, 0.125, 0.25, 0.5, 1, 2, 4, 8}
23 |
24 |
25 | def test_fp_format_rounding() -> None:
26 | n = 10000
27 | x = -1.35
28 |
29 | y_nearest = FPFormat(2, 1, rounding="nearest").quantise(torch.full((n,), x))
30 | assert collections.Counter(y_nearest.tolist()) == {-1.5: n}
31 |
32 | for srbits in (0, 13):
33 | srformat = FPFormat(2, 1, rounding="stochastic", srbits=srbits)
34 | y_stochastic = srformat.quantise(torch.full((n,), x))
35 | count = collections.Counter(y_stochastic.tolist())
36 | assert count.keys() == {-1.5, -1.0}
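        # Stochastic rounding should pick -1.5 with probability equal to the
        # fractional distance (1.35 - 1.0) / 0.5 = 0.7, within a 3-sigma margin.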
37 | expected_ratio = (1.35 - 1.0) / 0.5
38 | nearest_ratio = count[-1.5] / sum(count.values())
39 | std_x3 = 3 * (expected_ratio * (1 - expected_ratio) / n) ** 0.5
40 | assert expected_ratio - std_x3 < nearest_ratio < expected_ratio + std_x3
41 |
42 |
43 | def test_fp_format_bwd() -> None:
44 | fmt = FPFormat(2, 1)
45 | x = torch.randn(100, requires_grad=True)
46 | y = fmt.quantise_bwd(x * 1)
47 | y.backward(torch.linspace(-4, 4, steps=100)) # type: ignore[no-untyped-call]
48 | assert x.grad is not None
49 | assert set(x.grad.tolist()) == {
50 | sx for x in [0, 0.25, 0.5, 0.75, 1, 1.5, 2, 3] for sx in [x, -x]
51 | }
52 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import pytest
4 | import torch
5 | from torch import randint, randn
6 |
7 | from .._modules import (
8 | GELU,
9 | MHSA,
10 | MLP,
11 | Conv1d,
12 | CrossEntropyLoss,
13 | DepthModuleList,
14 | DepthSequential,
15 | Dropout,
16 | Embedding,
17 | LayerNorm,
18 | Linear,
19 | LinearReadout,
20 | RMSNorm,
21 | SiLU,
22 | Softmax,
23 | TransformerDecoder,
24 | TransformerLayer,
25 | )
26 | from ..optim import SGD
27 | from ..parameter import has_parameter_data
28 | from .helper import (
29 | assert_non_zeros,
30 | assert_not_unit_scaled,
31 | assert_unit_scaled,
32 | assert_zeros,
33 | unit_backward,
34 | )
35 |
36 |
37 | def test_gelu() -> None:
38 | input = randn(2**10)
39 | model = GELU()
40 | output = model(input)
41 |
42 | assert float(output.std()) == pytest.approx(1, abs=0.1)
43 |
44 |
45 | def test_silu() -> None:
46 | input = randn(2**10)
47 | model = SiLU()
48 | output = model(input)
49 |
50 | assert float(output.std()) == pytest.approx(1, abs=0.1)
51 |
52 |
53 | def test_softmax() -> None:
54 | input = randn(2**14)
55 | model = Softmax(dim=0)
56 | output = model(input)
57 |
58 | # The approximation is quite rough at mult=1
59 | assert 0.1 < float(output.std()) < 10
60 |
61 |
62 | def test_dropout() -> None:
63 | input = randn(2**12, requires_grad=True)
64 | model = Dropout()
65 | output = model(input)
66 |
67 | assert float(output.std()) == pytest.approx(1, abs=0.1)
68 |
69 | with pytest.raises(ValueError):
70 | Dropout(0.5, inplace=True)
71 |
72 |
73 | def test_linear() -> None:
74 | input = randn(2**8, 2**10, requires_grad=True)
75 | model = Linear(2**10, 2**12, bias=True)
76 | output = model(input)
77 |
78 | assert_unit_scaled(model.weight)
79 | assert_zeros(model.bias)
80 | assert output.shape == torch.Size([2**8, 2**12])
81 |
82 | unit_backward(output)
83 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step()
84 |
85 | assert float(output.std()) == pytest.approx(1, abs=0.1)
86 |
87 | assert_not_unit_scaled(model.weight)
88 | assert_non_zeros(model.bias)
89 |
90 |
91 | def test_conv1d() -> None:
92 | batch_size = 2**6
93 | d_in = 2**6 * 3
94 | d_out = 2**6 * 5
95 | kernel_size = 11
96 | seq_len = 2**6 * 7
97 | input = randn(batch_size, d_in, seq_len, requires_grad=True)
98 | model = Conv1d(d_in, d_out, kernel_size, bias=True)
99 | output = model(input)
100 |
101 | assert_unit_scaled(model.weight)
102 | assert_zeros(model.bias)
103 |
104 | unit_backward(output)
105 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step()
106 |
107 | assert float(output.std()) == pytest.approx(1, abs=0.1)
108 |
109 | assert_not_unit_scaled(model.weight)
110 | assert_non_zeros(model.bias)
111 |
112 |
113 | def test_linear_readout() -> None:
114 | input = randn(2**8, 2**10, requires_grad=True)
115 | model = LinearReadout(2**10, 2**12)
116 | output = model(input)
117 |
118 | assert model.weight.mup_type == "output" # type:ignore[attr-defined]
119 | assert_unit_scaled(model.weight)
120 | assert output.shape == torch.Size([2**8, 2**12])
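    # Readout output is expected at scale 1 / sqrt(fan_in) = 2**-5, not unit scale.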
121 | assert float(output.std()) == pytest.approx(2**-5, rel=0.1)
122 |
123 | unit_backward(output)
124 | SGD(model.parameters(), lr=1).step()
125 | assert_not_unit_scaled(model.weight)
126 |
127 |
128 | def test_layer_norm() -> None:
129 | input = randn(2**8, 2**10, requires_grad=True)
130 | model = LayerNorm(2**10, elementwise_affine=True)
131 | output = model(input)
132 |
133 | assert output.shape == torch.Size([2**8, 2**10])
134 |
135 | unit_backward(output)
136 | SGD(model.parameters(), lr=1).step()
137 |
138 | assert_unit_scaled(output, input.grad, model.weight.grad, model.bias.grad)
139 |
140 |
141 | def test_rms_norm() -> None:
142 | input = randn(2**8, 2**10, requires_grad=True)
143 | model = RMSNorm(2**10, elementwise_affine=True)
144 | output = model(input)
145 |
146 | assert output.shape == torch.Size([2**8, 2**10])
147 | assert model.weight is not None
148 |
149 | unit_backward(output)
150 | SGD(model.parameters(), lr=1).step()
151 |
152 | assert_unit_scaled(output, input.grad, model.weight.grad)
153 |
154 |
155 | def test_embedding() -> None:
156 | batch_sz, seq_len, embedding_dim, num_embeddings = 2**4, 2**5, 2**6, 2**12
157 | input_idxs = randint(low=0, high=2**12, size=(batch_sz, seq_len))
158 | model = Embedding(num_embeddings, embedding_dim)
159 | output = model(input_idxs)
160 |
161 | assert output.shape == torch.Size([batch_sz, seq_len, embedding_dim])
162 |
163 | unit_backward(output)
164 |
165 | assert_unit_scaled(model.weight.grad)
166 |
167 | with pytest.raises(ValueError):
168 | Embedding(num_embeddings, embedding_dim, scale_grad_by_freq=True)
169 | with pytest.raises(ValueError):
170 | Embedding(num_embeddings, embedding_dim, sparse=True)
171 |
172 |
173 | def test_cross_entropy_loss() -> None:
174 | num_tokens, vocab_sz = 2**12, 2**8
175 | input = randn(num_tokens, vocab_sz, requires_grad=True)
176 | labels = randint(low=0, high=vocab_sz, size=(num_tokens,))
177 | model = CrossEntropyLoss()
178 | loss = model(input, labels)
179 | loss.backward()
180 |
181 | assert_unit_scaled(input.grad)
182 |
183 | with pytest.raises(ValueError):
184 | CrossEntropyLoss(weight=randn(vocab_sz))
185 | with pytest.raises(ValueError):
186 | CrossEntropyLoss(label_smoothing=0.5)
187 |
188 |
189 | def test_mlp() -> None:
190 | input = randn(2**8, 2**10, requires_grad=True)
191 | model = MLP(2**10)
192 | output = model(input)
193 |
194 | assert_unit_scaled(
195 | model.linear_1.weight, model.linear_gate.weight, model.linear_2.weight
196 | )
197 | assert output.shape == torch.Size([2**8, 2**10])
198 |
199 | unit_backward(output)
200 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step()
201 |
202 | assert float(output.std()) == pytest.approx(1, abs=0.2)
203 |
204 | assert_unit_scaled(
205 | model.linear_1.weight.grad,
206 | model.linear_gate.weight.grad,
207 | model.linear_2.weight.grad,
208 | )
209 |
210 | assert_not_unit_scaled(
211 | model.linear_1.weight, model.linear_gate.weight, model.linear_2.weight
212 | )
213 |
214 |
215 | def test_mhsa() -> None:
216 | batch_sz, seq_len, hidden_dim = 2**8, 2**6, 2**6
217 | input = randn(batch_sz, seq_len, hidden_dim, requires_grad=True)
218 | model = MHSA(hidden_dim, heads=8, is_causal=False, dropout_p=0.1)
219 | output = model(input)
220 |
221 | assert_unit_scaled(model.linear_qkv.weight, model.linear_o.weight)
222 | assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim])
223 |
224 | unit_backward(output)
225 | SGD(model.parameters(), lr=1, readout_constraint="to_output_scale").step()
226 |
227 | assert float(output.std()) == pytest.approx(1, abs=0.5)
228 |
229 | assert_not_unit_scaled(model.linear_qkv.weight, model.linear_o.weight)
230 |
231 |
232 | def test_transformer_layer() -> None:
233 | batch_sz, seq_len, hidden_dim, heads = 2**8, 2**6, 2**6, 8
234 | input = randn(batch_sz, seq_len, hidden_dim, requires_grad=True)
235 | model = TransformerLayer(
236 | hidden_dim,
237 | heads=heads,
238 | is_causal=False,
239 | dropout_p=0.1,
240 | mhsa_tau=0.1,
241 | mlp_tau=1.0,
242 | )
243 | output = model(input)
244 |
245 | assert output.shape == torch.Size([batch_sz, seq_len, hidden_dim])
246 |
247 | unit_backward(output)
248 | SGD(model.parameters(), lr=1).step()
249 |
250 | assert float(output.std()) == pytest.approx(1, abs=0.1)
251 |
252 |
253 | def test_depth_module_list() -> None:
254 | layers = DepthModuleList([Linear(10, 10) for _ in range(5)])
255 | assert len(layers) == 5
256 | for layer in layers:
257 | assert layer.weight.mup_scaling_depth == 5
258 |
259 | with pytest.raises(ValueError):
260 | DepthModuleList([torch.nn.Linear(10, 10) for _ in range(5)])
261 |
262 |
263 | def test_depth_sequential() -> None:
264 | model = DepthSequential(*(Linear(2**6, 2**6) for _ in range(7)))
265 | for param in model.parameters():
266 | assert has_parameter_data(param)
267 | assert param.mup_scaling_depth == 7
268 |
269 | input = randn(2**4, 2**6, requires_grad=True)
270 | output = model(input)
271 | unit_backward(output)
272 | assert_unit_scaled(output, input.grad)
273 |
274 | with pytest.raises(ValueError):
275 | DepthSequential(*[torch.nn.Linear(2**6, 2**6) for _ in range(7)])
276 |
277 |
278 | def test_transformer_decoder() -> None:
279 | batch_size = 2**8
280 | seq_len = 2**6
281 | hidden_size = 2**6
282 | vocab_size = 2**12
283 | layers = 2
284 | heads = 4
285 |
286 | input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
287 | model = TransformerDecoder(hidden_size, vocab_size, layers, heads, dropout_p=0.1)
288 | loss = model.loss(input_ids)
289 |
290 |     expected_loss = torch.tensor(vocab_size).log()  # near-uniform prediction loss
291 | assert expected_loss / 2 < loss.item() < expected_loss * 2
292 |
293 | loss.backward() # type:ignore[no-untyped-call]
294 | SGD(model.parameters(), lr=1).step()
295 |
296 | for name, p in model.named_parameters():
297 | threshold = 5.0
298 | assert p.grad is not None
299 | assert 1 / threshold <= p.grad.std().detach() <= threshold, name
300 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_optim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | from typing import Any, Dict, List, Optional, Type, cast
4 |
5 | import pytest
6 | import torch
7 | from torch import nn, tensor, zeros
8 | from torch.testing import assert_close
9 |
10 | import unit_scaling as uu
11 | import unit_scaling.functional as U
12 |
13 | from ..optim import (
14 | SGD,
15 | Adam,
16 | AdamW,
17 | lr_scale_func_adam,
18 | lr_scale_func_sgd,
19 | scaled_parameters,
20 | )
21 |
22 |
23 | @pytest.mark.parametrize("opt_type", (Adam, AdamW, SGD))
24 | def test_optim_optimizers(opt_type: Type[torch.optim.Optimizer]) -> None:
25 | torch.manual_seed(100)
26 | inputs = torch.randn(10, 16)
27 | outputs = torch.randn(10, 25)
28 | model = uu.Linear(16, 25)
29 | opt = opt_type(
30 | model.parameters(), lr=0.01, weight_decay=1e-6 # type:ignore[call-arg]
31 | )
32 | opt.zero_grad()
33 | loss = U.mse_loss(model(inputs), outputs)
34 | loss.backward() # type:ignore[no-untyped-call]
35 | opt.step()
36 | assert U.mse_loss(model(inputs), outputs) < loss
37 |
38 |
39 | @pytest.mark.parametrize("opt", ["adam", "sgd"])
40 | @pytest.mark.parametrize("readout_constraint", [None, "to_output_scale"])
41 | def test_scaled_parameters(opt: str, readout_constraint: Optional[str]) -> None:
42 | model = nn.Sequential(
43 | uu.Embedding(2**8, 2**4),
44 | uu.DepthSequential(*(uu.Linear(2**4, 2**4, bias=True) for _ in range(3))),
45 | uu.LinearReadout(
46 | 2**4,
47 | 2**10,
48 | bias=False,
49 | weight_mup_type="output",
50 | constraint=readout_constraint,
51 | ),
52 | )
53 |
54 | base_lr = 0.1
55 | base_wd = 0.001
56 | param_groups = scaled_parameters(
57 | model.parameters(),
58 | dict(sgd=lr_scale_func_sgd(readout_constraint), adam=lr_scale_func_adam)[opt],
59 | lr=base_lr,
60 | weight_decay=base_wd,
61 | )
62 |
63 | # Match parameters based on shapes, as their names have disappeared
64 |     sqrt_d = 3**0.5  # depth factor for the 3-layer DepthSequential
65 | shape_to_expected_lr = {
66 | (2**8, 2**4): base_lr / 2**2, # embedding.weight
67 | (2**4, 2**4): base_lr / 2**2 / sqrt_d, # stack.linear.weight
68 | (2**4,): base_lr / sqrt_d, # stack.linear.bias
69 | (2**10, 2**4): base_lr, # linear.weight (output)
70 | }
71 | if opt == "sgd" and readout_constraint == "to_output_scale":
72 | shape_to_expected_lr[(2**8, 2**4)] *= 2**4
73 | shape_to_expected_lr[(2**4, 2**4)] *= 2**4
74 | shape_to_expected_lr[(2**4,)] *= 2**4
75 |
76 | for shape, expected_lr in shape_to_expected_lr.items():
77 | for g in param_groups:
78 | assert isinstance(g, dict)
79 | (param,) = g["params"]
80 | if param.shape == shape:
81 | assert g["lr"] == pytest.approx(
82 | expected_lr, rel=1e-3
83 | ), f"bad LR for param.shape={shape}"
84 | assert g["weight_decay"] == pytest.approx(
85 | base_wd / expected_lr, rel=1e-3
86 | ), f"bad WD for param.shape={shape}"
87 |
88 |
89 | def test_scaled_parameters_with_existing_groups() -> None:
90 | original_params = cast(
91 | List[Dict[str, Any]],
92 | [
93 | # Two parameters in this group, sharing a tensor LR, which must be cloned
94 | dict(
95 | params=[
96 | uu.Parameter(zeros(1, 2**4), mup_type="weight"),
97 | uu.Parameter(zeros(2, 2**6), mup_type="output"),
98 | ],
99 | lr=torch.tensor(0.3),
100 | weight_decay=0.05,
101 | ),
102 | # One parameter in this group, with no explicit LR or WD
103 | dict(
104 | params=[
105 | uu.Parameter(
106 | zeros(3, 2**6), mup_type="weight", mup_scaling_depth=5
107 | ),
108 | ],
109 | ),
110 | ],
111 | )
112 |
113 | g0, g1, g2 = scaled_parameters(
114 | original_params, lr_scale_func_adam, lr=torch.tensor(0.02)
115 | )
116 |
117 | assert isinstance(g0, dict)
118 | assert g0["params"][0].shape == (1, 2**4)
119 | assert_close(g0["lr"], tensor(0.3 / 2**2)) # also checks it's still a Tensor
120 | assert_close(g0["weight_decay"], 0.05 * 2**2 / 0.3)
121 |
122 | assert isinstance(g1, dict)
123 | assert g1["params"][0].shape == (2, 2**6)
124 | assert_close(g1["lr"], tensor(0.3))
125 | assert_close(g1["weight_decay"], 0.05 / 0.3)
126 |
127 | assert isinstance(g2, dict)
128 | assert g2["params"][0].shape == (3, 2**6)
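    # Expected Adam scaling: lr / sqrt(fan_in) and / sqrt(mup_scaling_depth).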
129 | assert_close(g2["lr"], tensor(0.02 / 2**3 / 5**0.5))
130 | assert g2["weight_decay"] == 0
131 |
132 | # ### Check error conditions ###
133 |
134 | # No lr, missing for group 1
135 | with pytest.raises(ValueError):
136 | params = scaled_parameters(original_params, lr_scale_func_adam)
137 | # No need for an lr when all groups have it explicitly
138 | params = scaled_parameters(original_params[:1], lr_scale_func_adam)
139 | assert len(params) == 2 # type:ignore[arg-type]
140 |
141 | # Non-unit-scaling parameters
142 | with pytest.raises(ValueError):
143 | params = scaled_parameters(
144 | original_params + [dict(params=[nn.Parameter(zeros(4, 2**4))])],
145 | lr_scale_func_adam,
146 | lr=0.1,
147 | )
148 | # Allow non-unit-scaling parameters
149 | params = scaled_parameters(
150 | original_params + [dict(params=[nn.Parameter(zeros(4, 2**4))])],
151 | lr_scale_func_adam,
152 | lr=0.1,
153 | allow_non_unit_scaling_params=True,
154 | )
155 | assert len(params) == 4 # type:ignore[arg-type]
156 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_parameter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2 |
3 | import copy
4 | import pickle
5 |
6 | import torch
7 | from torch import Tensor, empty, full, tensor
8 | from torch.testing import assert_close
9 |
10 | from ..parameter import Parameter, has_parameter_data
11 |
12 |
13 | def test_parameter() -> None:
14 | param = Parameter(torch.zeros(10), "weight")
15 | assert has_parameter_data(param)
16 | assert param.mup_type == "weight"
17 | assert param.mup_scaling_depth is None
18 |
19 | param.mup_scaling_depth = 8
20 |
21 | param_copy = copy.deepcopy(param)
22 | assert has_parameter_data(param_copy) # type:ignore[arg-type]
23 | assert param_copy.mup_type == "weight"
24 | assert param_copy.mup_scaling_depth == 8
25 |
26 | param_pickle = pickle.loads(pickle.dumps(param))
27 | assert has_parameter_data(param_pickle)
28 | assert param_pickle.mup_type == "weight"
29 | assert param_pickle.mup_scaling_depth == 8
30 |
31 | param_pickle_copy = copy.deepcopy(param_pickle)
32 | assert has_parameter_data(param_pickle_copy) # type:ignore[arg-type]
33 | assert param_pickle_copy.mup_type == "weight"
34 | assert param_pickle_copy.mup_scaling_depth == 8
35 |
36 |
37 | def test_parameter_compile() -> None:
38 | parameter = Parameter(empty(3), mup_type="norm")
39 |
40 | def update_parameter(mult: Tensor) -> Tensor:
41 | parameter.data.mul_(mult)
42 | return parameter
43 |
44 | parameter.data.fill_(1)
45 | assert_close(update_parameter(tensor(123.0)), full((3,), 123.0))
46 |
47 | parameter.data.fill_(1)
48 | update_parameter = torch.compile(fullgraph=True)(update_parameter)
49 | assert_close(update_parameter(tensor(0.5)), full((3,), 0.5))
50 | assert_close(update_parameter(tensor(8.0)), full((3,), 4.0))
51 | assert_close(update_parameter(parameter), full((3,), 16.0))
52 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import torch
4 | from torch import randn
5 |
6 | from ..scale import scale_bwd, scale_fwd
7 | from .helper import unit_backward
8 |
9 |
10 | def test_scale_fwd() -> None:
11 | x = randn(2**10, requires_grad=True)
12 | x_scaled = scale_fwd(x, 3.5)
13 | grad_in = unit_backward(x_scaled)
14 |
15 | assert torch.equal(x_scaled, x * 3.5)
16 | assert torch.equal(x.grad, grad_in) # type: ignore
17 |
18 |
19 | def test_scale_bwd() -> None:
20 | x = randn(2**10, requires_grad=True)
21 | x_scaled = scale_bwd(x, 3.5)
22 | grad_in = unit_backward(x_scaled)
23 |
24 | assert torch.equal(x_scaled, x)
25 | assert torch.equal(x.grad, grad_in * 3.5) # type: ignore
26 |
--------------------------------------------------------------------------------
/unit_scaling/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import re
4 |
5 | import torch
6 |
7 | from .._modules import MHSA, MLP
8 | from ..utils import analyse_module
9 |
10 |
11 | def remove_scales(code: str) -> str:
12 | """Takes a code string containing scale annotations such as `(-> 1.0, <- 0.819)` and
13 | replaces them all with `(-> _, <- _)`."""
14 | return re.sub(r"\(-> \d+(\.\d+)?, <- \d+(\.\d+)?\)", "(-> _, <- _)", code)
15 |
16 |
17 | def test_analyse_mlp() -> None:
18 | batch_size = 2**10
19 | hidden_size = 2**10
20 | input = torch.randn(batch_size, hidden_size).requires_grad_()
21 | backward = torch.randn(batch_size, hidden_size)
22 |
23 | annotated_code = analyse_module(
24 | MLP(hidden_size), input, backward, syntax_highlight=False
25 | )
26 | print(annotated_code)
27 |
28 | expected_code = """
29 | def forward(self, input : Tensor) -> Tensor:
30 | input_1 = input; (-> 1.0, <- 1.44)
31 | linear_1_weight = self.linear_1.weight; (-> 1.0, <- 0.503)
32 | linear = U.linear(input_1, linear_1_weight, None, None); (-> 1.0, <- 0.502)
33 | linear_gate_weight = self.linear_gate.weight; (-> 1.0, <- 0.519)
34 | linear_1 = U.linear(input_1, linear_gate_weight, None, None); (-> 1.0, <- 0.518)
35 | silu_glu = U.silu_glu(linear, linear_1); (-> 1.0, <- 0.5)
36 | linear_2_weight = self.linear_2.weight; (-> 1.0, <- 1.0)
37 | linear_2 = U.linear(silu_glu, linear_2_weight, None, None); (-> 1.0, <- 1.0)
38 | return linear_2
39 | """.strip() # noqa: E501
40 |
41 | assert remove_scales(annotated_code) == remove_scales(expected_code)
42 |
43 |
44 | def test_analyse_mhsa() -> None:
45 | batch_size = 2**8
46 | seq_len = 2**6
47 | hidden_size = 2**6
48 | heads = 4
49 | input = torch.randn(batch_size, seq_len, hidden_size).requires_grad_()
50 | backward = torch.randn(batch_size, seq_len, hidden_size)
51 |
52 | annotated_code = analyse_module(
53 | MHSA(hidden_size, heads, is_causal=False, dropout_p=0.1),
54 | input,
55 | backward,
56 | syntax_highlight=False,
57 | )
58 | print(annotated_code)
59 |
60 | expected_code = """
61 | def forward(self, input : Tensor) -> Tensor:
62 | input_1 = input; (-> 1.0, <- 1.13)
63 | linear_qkv_weight = self.linear_qkv.weight; (-> 1.01, <- 0.662)
64 | linear = U.linear(input_1, linear_qkv_weight, None, 'to_output_scale'); (-> 1.01, <- 0.633)
65 | rearrange = einops_einops_rearrange(linear, 'b s (z h d) -> z b h s d', h = 4, z = 3); (-> 1.01, <- 0.633)
66 | getitem = rearrange[0]; (-> 1.0, <- 0.344)
67 | getitem_1 = rearrange[1]; (-> 1.0, <- 0.257)
68 | getitem_2 = rearrange[2]; (-> 1.02, <- 1.01)
69 | scaled_dot_product_attention = U.scaled_dot_product_attention(getitem, getitem_1, getitem_2, dropout_p = 0.1, is_causal = False, mult = 1.0); (-> 1.04, <- 1.0)
70 | rearrange_1 = einops_einops_rearrange(scaled_dot_product_attention, 'b h s d -> b s (h d)'); (-> 1.04, <- 1.0)
71 | linear_o_weight = self.linear_o.weight; (-> 1.0, <- 1.03)
72 | linear_1 = U.linear(rearrange_1, linear_o_weight, None, 'to_output_scale'); (-> 1.06, <- 1.0)
73 | return linear_1
74 | """.strip() # noqa: E501
75 |
76 | assert remove_scales(annotated_code) == remove_scales(expected_code)
77 |
--------------------------------------------------------------------------------
/unit_scaling/tests/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
--------------------------------------------------------------------------------
/unit_scaling/tests/transforms/test_compile.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | from typing import Tuple
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import Tensor, nn
8 |
9 | from ...transforms import compile, unit_scale
10 |
11 |
12 | def test_compile() -> None:
13 | class Module(nn.Module): # pragma: no cover
14 | def __init__(
15 | self,
16 | hidden_size: int,
17 | ) -> None:
18 | super().__init__()
19 | self.layer_norm = nn.LayerNorm(hidden_size)
20 | self.l1 = nn.Linear(hidden_size, 4 * hidden_size)
21 | self.l2 = nn.Linear(4 * hidden_size, hidden_size)
22 |
23 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
24 | input = self.layer_norm(input)
25 | input = self.l1(input)
26 | input = F.gelu(input)
27 | input = self.l2(input)
28 | input = F.dropout(input, 0.2)
29 | return input, input.sum()
30 |
31 | mod = Module(2**6)
32 | x = torch.randn(2**3, 2**6)
33 |
34 | compile(mod)(x)
35 | compile(unit_scale(mod))(x)
36 |
--------------------------------------------------------------------------------
/unit_scaling/tests/transforms/test_simulate_format.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import logging
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from pytest import LogCaptureFixture
8 | from torch import Tensor, nn, randn
9 |
10 | from ... import _modules as uu
11 | from ... import functional as U
12 | from ...transforms import simulate_fp8
13 |
14 |
15 | def test_simulate_fp8_linear(caplog: LogCaptureFixture) -> None:
16 | caplog.set_level(logging.INFO)
17 |
18 | class Model(nn.Module):
19 | def __init__(self, d_in: int, d_out: int) -> None:
20 | super().__init__()
21 | self.linear = nn.Linear(d_in, d_out)
22 |
23 | def forward(self, t: Tensor) -> Tensor:
24 | return self.linear(t).sum() # type: ignore[no-any-return]
25 |
26 | input = randn(2**8, 2**9, requires_grad=True)
27 | model = Model(2**9, 2**10)
28 | output = model(input)
29 | output.backward()
30 |
31 | fp8_input = input.clone().detach().requires_grad_()
32 | fp8_model = simulate_fp8(model)
33 | fp8_output = fp8_model(fp8_input)
34 | fp8_output.backward()
35 |
36 | assert not torch.all(fp8_output == output)
37 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore
38 | assert not torch.all(
39 | fp8_model.linear.weight.grad == model.linear.weight.grad # type: ignore
40 | )
41 | assert "quantising function" in caplog.text
42 |
43 |
44 | def test_simulate_fp8_unit_scaled_linear() -> None:
45 | class Model(nn.Module):
46 | def __init__(self, d_in: int, d_out: int) -> None:
47 | super().__init__()
48 | self.linear = uu.Linear(d_in, d_out)
49 |
50 | def forward(self, t: Tensor) -> Tensor:
51 | return self.linear(t).sum() # type: ignore[no-any-return]
52 |
53 | input = randn(2**8, 2**9, requires_grad=True)
54 | model = Model(2**9, 2**10)
55 | output = model(input)
56 | output.backward()
57 |
58 | fp8_input = input.clone().detach().requires_grad_()
59 | fp8_model = simulate_fp8(model)
60 | fp8_output = fp8_model(fp8_input)
61 | fp8_output.backward()
62 |
63 | assert not torch.all(fp8_output == output)
64 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore
65 | assert not torch.all(
66 | fp8_model.linear.weight.grad == model.linear.weight.grad # type: ignore
67 | )
68 |
69 |
70 | def test_simulate_fp8_attention() -> None:
71 | class Model(nn.Module):
72 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
73 | return F.scaled_dot_product_attention(q, k, v).sum()
74 |
75 | inputs = list(randn(2**8, 2**8, requires_grad=True) for _ in range(3))
76 | model = Model()
77 | output = model(*inputs)
78 | output.backward()
79 |
80 | fp8_inputs = list(t.clone().detach().requires_grad_() for t in inputs)
81 | fp8_model = simulate_fp8(model)
82 |
83 | fp8_output = fp8_model(*fp8_inputs)
84 | fp8_output.backward()
85 |
86 | assert not torch.all(fp8_output == output)
87 | for fp8_input, input in zip(fp8_inputs, inputs):
88 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore
89 |
90 |
91 | def test_simulate_fp8_unit_scaled_attention() -> None:
92 | class Model(nn.Module):
93 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
94 | return U.scaled_dot_product_attention(q, k, v).sum()
95 |
96 | inputs = list(randn(2**8, 2**8, requires_grad=True) for _ in range(3))
97 | model = Model()
98 | output = model(*inputs)
99 | output.backward()
100 |
101 | fp8_inputs = list(t.clone().detach().requires_grad_() for t in inputs)
102 | fp8_model = simulate_fp8(model)
103 |
104 | fp8_output = fp8_model(*fp8_inputs)
105 | fp8_output.backward()
106 |
107 | assert not torch.all(fp8_output == output)
108 | for fp8_input, input in zip(fp8_inputs, inputs):
109 | assert not torch.all(fp8_input.grad == input.grad) # type: ignore
110 |
--------------------------------------------------------------------------------
/unit_scaling/tests/transforms/test_track_scales.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import operator
4 | from math import pi, sqrt
5 | from typing import Any, Callable, Dict, Set, Tuple, Union
6 |
7 | import pytest
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import Tensor, nn, randint, randn, randn_like
11 | from torch.fx.graph import Graph
12 | from torch.fx.node import Node
13 |
14 | from ...transforms import (
15 | prune_non_float_tensors,
16 | prune_same_scale_tensors,
17 | prune_selected_nodes,
18 | track_scales,
19 | )
20 |
21 |
22 | def get_target_or_node_name(node: Node) -> Union[str, Callable[..., Any]]:
23 | return node.meta["clean_name"] if isinstance(node.target, str) else node.target
24 |
25 |
26 | def get_targets(graph: Graph) -> Set[Union[str, Callable]]: # type: ignore[type-arg]
27 | return set(get_target_or_node_name(node) for node in graph.nodes)
28 |
29 |
30 | def get_target_map(
31 | graph: Graph,
32 | ) -> Dict[Union[str, Callable], Dict[str, Any]]: # type: ignore[type-arg]
33 | return {get_target_or_node_name(node): node.meta for node in graph.nodes}
34 |
35 |
36 | def test_track_scales() -> None:
37 | class Model(nn.Module):
38 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover
39 | x = F.relu(x)
40 | y = torch.ones_like(x, dtype=x.dtype)
41 | z = x + y
42 | return z.sum()
43 |
44 | model = Model()
45 | model = track_scales(model)
46 | assert len(model.scales_graph().nodes) == 0
47 |
48 | input = randn(2**4, 2**10)
49 | loss = model(input)
50 |
51 | graph = model.scales_graph()
52 |
53 | assert all("outputs_float_tensor" in n.meta for n in graph.nodes)
54 | meta_map = get_target_map(graph)
55 |
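    # For x ~ N(0, 1): E|x| = sqrt(2 / pi), while relu(x) has mean 1 / sqrt(2 * pi)
    # and std sqrt((1 - 1 / pi) / 2); the assertions below check these closed forms.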
56 | assert "metrics" in meta_map["x"]
57 | assert meta_map["x"]["metrics"].bwd is None
58 | assert meta_map["x"]["metrics"].fwd.mean_abs == pytest.approx(
59 | sqrt(2 / pi), abs=0.01
60 | )
61 | assert meta_map["x"]["metrics"].fwd.abs_mean == pytest.approx(0, abs=0.01)
62 | assert meta_map["x"]["metrics"].fwd.std == pytest.approx(1, abs=0.01)
63 | assert meta_map["x"]["metrics"].fwd.numel == 2**14
64 |
65 | assert "metrics" in meta_map[F.relu]
66 | assert meta_map[F.relu]["metrics"].bwd is None
67 | assert (
68 | meta_map[F.relu]["metrics"].fwd.mean_abs
69 | == meta_map[F.relu]["metrics"].fwd.abs_mean
70 | == pytest.approx(sqrt(1 / (2 * pi)), abs=0.01)
71 | )
72 | assert meta_map[F.relu]["metrics"].fwd.std == pytest.approx(
73 | sqrt((1 - 1 / pi) / 2), abs=0.01
74 | )
75 | assert meta_map[F.relu]["metrics"].fwd.numel == 2**14
76 |
77 | assert "metrics" in meta_map[torch.ones_like]
78 | assert meta_map[torch.ones_like]["metrics"].bwd is None
79 | assert meta_map[torch.ones_like]["metrics"].fwd.mean_abs == 1.0
80 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_mean == 1.0
81 | assert meta_map[torch.ones_like]["metrics"].fwd.std == 0.0
82 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_max == 1.0
83 | assert meta_map[torch.ones_like]["metrics"].fwd.abs_min == 1.0
84 | assert meta_map[torch.ones_like]["metrics"].fwd.numel == 2**14
85 |
86 | assert "metrics" in meta_map[operator.add]
87 | assert "metrics" in meta_map["sum_1"]
88 | assert meta_map["sum_1"]["metrics"].fwd.numel == 1
89 |
90 | loss.backward()
91 | graph = model.scales_graph()
92 |
93 | meta_map = get_target_map(graph)
94 |
95 | assert meta_map["sum_1"]["metrics"].bwd is not None
96 | assert meta_map["sum_1"]["metrics"].bwd.numel == 1
97 |
98 | assert meta_map[operator.add]["metrics"].bwd is not None
99 | assert meta_map[operator.add]["metrics"].bwd.mean_abs == 1.0
100 | assert meta_map[operator.add]["metrics"].bwd.abs_mean == 1.0
101 | assert meta_map[operator.add]["metrics"].bwd.std == 0.0
102 | assert meta_map[operator.add]["metrics"].bwd.abs_max == 1.0
103 | assert meta_map[operator.add]["metrics"].bwd.abs_min == 1.0
104 | assert meta_map[operator.add]["metrics"].bwd.numel == 2**14
105 |
106 | assert meta_map["x"]["metrics"].bwd is not None
107 | assert meta_map["x"]["metrics"].bwd.std == pytest.approx(
108 | 0.5, abs=0.01 # same as fwd pass except 0s are now 1s
109 | )
110 |
111 |
112 | def test_prune_non_float_tensors() -> None:
113 | class Model(nn.Module):
114 | def __init__(self, emb_size: int, dim: int) -> None:
115 | super().__init__()
116 | self.emb = nn.Embedding(emb_size, dim)
117 | self.linear = nn.Linear(dim, dim)
118 |
119 | def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
120 | x = self.emb(idxs)
121 | scores = F.softmax(self.linear(x), dim=-1)
122 | top_idx = torch.argmax(scores, dim=-1)
123 | top_idx = torch.unsqueeze(top_idx, -1)
124 | top_score_x = torch.gather(x, -1, top_idx)
125 | top_score_x -= x.mean()
126 | return top_score_x, top_idx
127 |
128 | idxs = randint(0, 2**10, (2**3, 2**5))
129 | model = Model(2**10, 2**6)
130 | model = track_scales(model)
131 | model(idxs)
132 |
133 | graph = model.scales_graph()
134 | expected_targets = {
135 | "idxs",
136 | "self_modules_emb_parameters_weight",
137 | F.embedding,
138 | F.linear,
139 | "self_modules_linear_parameters_weight",
140 | "self_modules_linear_parameters_bias",
141 | F.softmax,
142 | torch.argmax,
143 | torch.unsqueeze,
144 | torch.gather,
145 | operator.isub,
146 | "mean",
147 | "output",
148 | }
149 | graph_targets = get_targets(graph)
150 | assert graph_targets == expected_targets
151 |
152 | graph = prune_non_float_tensors(graph)
153 | graph_targets = get_targets(graph)
154 | expected_targets -= {"idxs", torch.argmax, torch.unsqueeze}
155 | assert graph_targets == expected_targets
156 |
157 |
158 | def test_prune_same_scale_tensors() -> None:
159 | class Model(nn.Module):
160 | def __init__(self, emb_size: int, dim: int) -> None:
161 | super().__init__()
162 | self.emb = nn.Embedding(emb_size, dim)
163 | self.linear = nn.Linear(dim, dim // 2)
164 |
165 | def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
166 | # idxs has 0 args -> shouldn't be pruned
167 | x = self.emb(idxs) # emb has 1 float arg (weights) -> depends on tol
168 | _x = x.flatten(start_dim=0, end_dim=-1) # 1 float, same scale -> prune
169 | x = _x.view(x.shape) # 1 float arg, same scale -> prune
170 | y = self.linear(x) # scale changes -> shouldn't be pruned
171 | scores = F.softmax(y, dim=-1) # scale changes -> shouldn't be pruned
172 | top_idx = torch.argmax(scores, dim=-1) # not float -> shouldn't be pruned
173 | top_idx = torch.unsqueeze(top_idx, -1) # not float -> shouldn't be pruned
174 | top_score_x = torch.gather(x, -1, top_idx) # small change -> depends on tol
175 | top_score_x += randn_like(top_score_x) # 2 floats, same scale -> no prune
176 | return top_score_x, top_idx
177 |
178 | idxs = randint(0, 2**10, (2**3, 2**5))
179 | model = Model(2**10, 2**6)
180 | model = track_scales(model)
181 | model(idxs)
182 |
183 | graph = model.scales_graph()
184 |
185 | # Version-dependent, see https://github.com/graphcore-research/unit-scaling/pull/52
186 | var_lhs_flatten = "x"
187 | var_lhs_view = "x_1"
188 | expected_targets = {
189 | "idxs",
190 | "self_modules_emb_parameters_weight",
191 | F.embedding,
192 | var_lhs_flatten,
193 | var_lhs_view,
194 | "self_modules_linear_parameters_weight",
195 | "self_modules_linear_parameters_bias",
196 | F.linear,
197 | F.softmax,
198 | torch.argmax,
199 | torch.unsqueeze,
200 | torch.gather,
201 | randn_like,
202 | operator.iadd,
203 | "output",
204 | }
205 | graph_targets = get_targets(graph)
206 | assert graph_targets == expected_targets
207 |
208 | graph = prune_same_scale_tensors(graph)
209 | graph_targets = get_targets(graph)
210 | expected_targets -= {var_lhs_flatten, var_lhs_view}
211 | assert graph_targets == expected_targets
212 |
213 | graph = prune_same_scale_tensors(graph, rtol=2**-4)
214 | graph_targets = get_targets(graph)
215 | expected_targets -= {torch.gather, F.embedding}
216 | assert graph_targets == expected_targets
217 |
218 |
219 | def test_prune_same_scale_tensors_with_grad() -> None:
220 | class Model(nn.Module):
221 | def forward(self, a: Tensor) -> Tensor: # pragma: no cover
222 | b = a / 1.0 # same scale fwd & bwd
223 | c = b * 1.0 # same scale fwd, as b sums grads -> different scale bwd
224 | d = F.relu(c) # different scale fwd & bwd
225 | e = b - d # different scale fwd, same bwd
226 | f = e.sum() # different scale fwd & bwd
227 | return f
228 |
229 | input = randn(2**6, 2**8)
230 | model = Model()
231 | model = track_scales(model)
232 | loss = model(input)
233 |
234 | graph = model.scales_graph()
235 | expected_targets = {
236 | "a",
237 | operator.truediv,
238 | operator.mul,
239 | F.relu,
240 | operator.sub,
241 | "f",
242 | "output",
243 | }
244 | graph_targets = get_targets(graph)
245 | assert graph_targets == expected_targets
246 |
247 | graph = prune_same_scale_tensors(graph)
248 | graph_targets = get_targets(graph)
249 | expected_targets -= {operator.truediv, operator.mul}
250 | assert graph_targets == expected_targets
251 |
252 | # The mul still has the same scale before & after in the fwd pass, but the same is
253 | # not true for its grads. It should no longer be pruned after `loss.backward`.
254 | loss.backward()
255 | graph = model.scales_graph()
256 | graph = prune_same_scale_tensors(graph)
257 | graph_targets = get_targets(graph)
258 | expected_targets.add(operator.mul)
259 | assert graph_targets == expected_targets
260 |
261 |
262 | def test_prune_selected_nodes() -> None:
263 | class Model(nn.Module):
264 | def forward(self, x: Tensor) -> Tensor: # pragma: no cover
265 | x = x + 1
266 | x = F.relu(x)
267 | x = torch.abs(x)
268 | return x.sum()
269 |
270 | input = randn(2**6, 2**8)
271 | model = Model()
272 | model = track_scales(model)
273 | model(input)
274 |
275 | graph = model.scales_graph()
276 | expected_targets = {
277 | "x",
278 | operator.add,
279 | F.relu,
280 | torch.abs,
281 | "sum_1",
282 | "output",
283 | }
284 | graph_targets = get_targets(graph)
285 | assert graph_targets == expected_targets
286 |
287 | graph = prune_selected_nodes(graph, targets=[torch.abs, F.relu])
288 | graph_targets = get_targets(graph)
289 | expected_targets -= {torch.abs, F.relu}
290 | assert graph_targets == expected_targets
291 |
--------------------------------------------------------------------------------
/unit_scaling/tests/transforms/test_unit_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import logging
4 | import math
5 | import re
6 | from typing import Tuple
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from pytest import LogCaptureFixture
11 | from torch import Tensor, nn, randn
12 |
13 | from ...transforms import simulate_fp8, unit_scale
14 | from ..helper import assert_unit_scaled
15 |
16 |
17 | def test_unit_scale(caplog: LogCaptureFixture) -> None:
18 | caplog.set_level(logging.INFO)
19 |
20 | def custom_gelu(x: Tensor) -> Tensor:
21 | inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
22 | return 0.5 * x * (1.0 + torch.tanh(inner))
23 |
24 | class MLPLayer(nn.Module): # pragma: no cover
25 | def __init__(
26 | self,
27 | hidden_size: int,
28 | ) -> None:
29 | super().__init__()
30 | self.layer_norm = nn.LayerNorm(hidden_size)
31 | self.l1 = nn.Linear(hidden_size, 4 * hidden_size)
32 | self.l2 = nn.Linear(4 * hidden_size, hidden_size)
33 |
34 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
35 | input = self.layer_norm(input)
36 | input = self.l1(input)
37 | input = custom_gelu(input)
38 | input = self.l2(input)
39 | input = F.dropout(input, 0.2)
40 | return input, input.sum()
41 |
42 | input = randn(2**6, 2**10, requires_grad=True)
43 | model = nn.Sequential(MLPLayer(2**10))
44 | model = unit_scale(model, replace={custom_gelu: F.gelu})
45 | output, loss = model(input)
46 | loss.backward()
47 |
48 | assert_unit_scaled(
49 | output,
50 | input.grad,
51 | model[0].layer_norm.weight.grad,
52 | model[0].l1.weight.grad,
53 | model[0].l2.weight.grad,
54 | abs=0.2,
55 | )
56 |
57 | expected_logs = [
58 | "unit scaling weight",
59 | "setting bias to zero",
60 | "unit scaling function",
61 | "replacing function",
62 | "unconstraining node",
63 | ]
64 | for log_msg in expected_logs:
65 | assert log_msg in caplog.text
66 |
67 |
68 | def test_unit_scale_residual_add(caplog: LogCaptureFixture) -> None:
69 | caplog.set_level(logging.INFO)
70 |
71 | class MLPLayer(nn.Module):
72 | def __init__(
73 | self,
74 | hidden_size: int,
75 | ) -> None:
76 | super().__init__()
77 | self.l1 = nn.Linear(hidden_size, hidden_size)
78 | self.l2 = nn.Linear(hidden_size, hidden_size)
79 |
80 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
81 | skip = input
82 | input = input + 1
83 | input = self.l1(input)
84 | input = input + skip
85 | skip = input
86 | input += 1
87 | input = self.l2(input)
88 | input += skip
89 | return input, input.sum()
90 |
91 | input = randn(2**6, 2**10, requires_grad=True)
92 | model = MLPLayer(2**10)
93 | us_model = unit_scale(model)
94 | output, loss = us_model(input)
95 | loss.backward()
96 |
97 | expected_logs = [
98 | r"unit scaling function: (input_2)\n",
99 | r"unit scaling function: (input_4)\n",
100 | r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)",
101 | r"unit scaling function: (add_1|input_6) \(residual-add, tau=0\.5\)",
102 | ]
103 |
104 | for log_msg in expected_logs:
105 | assert re.search(log_msg, caplog.text)
106 |
107 |
108 | def test_fp8_unit_scaling(caplog: LogCaptureFixture) -> None:
109 | caplog.set_level(logging.INFO)
110 |
111 | class Model(nn.Module):
112 | def __init__(self, d_in: int, d_out: int) -> None:
113 | super().__init__()
114 | self.linear = nn.Linear(d_in, d_out)
115 |
116 | def forward(self, t: Tensor) -> Tensor: # pragma: no cover
117 | return self.linear(t) # type: ignore[no-any-return]
118 |
119 | input = randn(2**8, 2**9)
120 | model = Model(2**9, 2**10)
121 | model = simulate_fp8(model)
122 | model = unit_scale(model)
123 | model(input)
124 |
125 | expected_logs = [
126 | "moving unit scaling backend to precede quantisation backend",
127 | "running unit scaling backend",
128 | "running quantisation backend",
129 | ]
130 | for log_msg in expected_logs:
131 | assert log_msg in caplog.text
132 |
--------------------------------------------------------------------------------
/unit_scaling/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Useful torch dynamo transforms of modules for numerics and unit
4 | scaling."""
5 |
6 | # This all has to be done manually to keep mypy happy.
7 | # Removing the `--no-implicit-reexport` option ought to fix this, but doesn't appear to.
8 |
9 | from ._compile import compile
10 | from ._simulate_format import simulate_format, simulate_fp8
11 | from ._track_scales import (
12 | Metrics,
13 | prune_non_float_tensors,
14 | prune_same_scale_tensors,
15 | prune_selected_nodes,
16 | track_scales,
17 | )
18 | from ._unit_scale import unit_scale
19 |
20 | __all__ = [
21 | "Metrics",
22 | "compile",
23 | "prune_non_float_tensors",
24 | "prune_same_scale_tensors",
25 | "prune_selected_nodes",
26 | "simulate_format",
27 | "simulate_fp8",
28 | "track_scales",
29 | "unit_scale",
30 | ]
31 |
--------------------------------------------------------------------------------
/unit_scaling/transforms/_compile.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | from typing import TypeVar
4 |
5 | from torch import _TorchCompileInductorWrapper, nn
6 |
7 | from .utils import apply_transform
8 |
9 | M = TypeVar("M", bound=nn.Module)
10 |
11 |
12 | def compile(module: M) -> M:
13 | """A transform that applies torch.compile to a module.
14 |
15 | Note that this is slightly different to calling :code:`torch.compile(module)`.
16 |
17 | The current version of :func:`torch.compile` doesn't allow for nested transforms, so
18 | the following is not supported:
19 |
20 | .. code-block:: python
21 |
22 | import torch
23 |
24 | from unit_scaling.transforms import unit_scale
25 |
26 | module = torch.compile(unit_scale(module))
27 |
28 | :mod:`unit_scaling.transforms` addresses this by introducing a range of composable
29 |     transforms. This works by moving the call to
30 |     :func:`torch._dynamo.optimize` inside the module's :code:`forward()` method and
31 |     only executing it on the first call to the module (or whenever a new transform
32 |     is applied), with the optimised call cached thereafter.
33 |
34 | The :func:`unit_scaling.transforms.compile` function is one such composable
35 | transform. This means that the following can be written:
36 |
37 | .. code-block:: python
38 |
39 | from unit_scaling.transforms import compile, unit_scale
40 |
41 | module = compile(unit_scale(module))
42 |
43 | This will successfully combine the two transforms in a single module. Note that
44 | the call to compile must still come last, as its underlying backend returns a
45 | standard :class:`torch.nn.Module` rather than a :class:`torch.fx.GraphModule`.
46 |
47 | Currently :func:`unit_scaling.transforms.compile` does not support the ops needed
48 | for the :func:`unit_scaling.transforms.simulate_fp8` transform, though this may
49 | change in future PyTorch releases.
50 |
51 | Modules implemented manually with unit-scaled layers (i.e. without the global
52 | :code:`unit_scale(module)` transform) can still use :func:`torch.compile` in the
53 | standard way.
54 |
55 | Args:
56 | module (M): the module to be compiled.
57 |
58 | Returns:
59 | M: the compiled module.
60 | """
61 | return apply_transform( # type: ignore[no-any-return]
62 | module, _TorchCompileInductorWrapper("default", None, None) # type: ignore
63 | )
64 |
--------------------------------------------------------------------------------
/unit_scaling/transforms/_simulate_format.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import logging
3 | from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
4 |
5 | import torch.nn.functional as F
6 | from torch import Tensor, nn
7 | from torch.fx.graph import Graph
8 | from torch.fx.graph_module import GraphModule
9 | from torch.fx.node import Node
10 |
11 | from .. import functional as U
12 | from .._internal_utils import generate__all__
13 | from ..formats import FPFormat, format_to_tuple, tuple_to_format
14 | from .utils import Backend, apply_transform, replace_node_with_function
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | # These functions currently have to be defined explicitly to make PyTorch happy
20 | # Creating temporary wrapped functions doesn't work...
21 | def _quantised_linear(
22 | input: Tensor,
23 | weight: Tensor,
24 | bias: Optional[Tensor],
25 | fwd_format_tuple: Tuple[int, int],
26 | bwd_format_tuple: Tuple[int, int],
27 | ) -> Tensor:
28 | fwd_format = tuple_to_format(fwd_format_tuple)
29 | bwd_format = tuple_to_format(bwd_format_tuple)
30 | input = fwd_format.quantise_fwd(input)
31 | weight = fwd_format.quantise_fwd(weight)
32 | output = F.linear(input, weight, bias)
33 | return bwd_format.quantise_bwd(output)
34 |
35 |
36 | def _quantised_u_linear(
37 | input: Tensor,
38 | weight: Tensor,
39 | bias: Optional[Tensor],
40 | fwd_format_tuple: Tuple[int, int],
41 | bwd_format_tuple: Tuple[int, int],
42 | constraint: Optional[str] = "to_output_scale",
43 | ) -> Tensor:
44 | fwd_format = tuple_to_format(fwd_format_tuple)
45 | bwd_format = tuple_to_format(bwd_format_tuple)
46 | input, weight = (fwd_format.quantise_fwd(t) for t in (input, weight))
47 | output = U.linear(input, weight, bias, constraint)
48 | return bwd_format.quantise_bwd(output)
49 |
50 |
51 | def _quantised_scaled_dot_product_attention(
52 | query: Tensor,
53 | key: Tensor,
54 | value: Tensor,
55 | fwd_format_tuple: Tuple[int, int],
56 | bwd_format_tuple: Tuple[int, int],
57 | **kwargs: Any,
58 | ) -> Tensor:
59 | fwd_format = tuple_to_format(fwd_format_tuple)
60 | bwd_format = tuple_to_format(bwd_format_tuple)
61 | query, key, value = (fwd_format.quantise_fwd(t) for t in (query, key, value))
62 | output = F.scaled_dot_product_attention(query, key, value, **kwargs)
63 | return bwd_format.quantise_bwd(output)
64 |
65 |
66 | def _quantised_u_scaled_dot_product_attention(
67 | query: Tensor,
68 | key: Tensor,
69 | value: Tensor,
70 | fwd_format_tuple: Tuple[int, int],
71 | bwd_format_tuple: Tuple[int, int],
72 | **kwargs: Any,
73 | ) -> Tensor:
74 | fwd_format = tuple_to_format(fwd_format_tuple)
75 | bwd_format = tuple_to_format(bwd_format_tuple)
76 | query, key, value = (fwd_format.quantise_fwd(t) for t in (query, key, value))
77 | output = U.scaled_dot_product_attention(query, key, value, **kwargs)
78 | return bwd_format.quantise_bwd(output)
79 |
80 |
81 | _replacement_map: Dict[Callable[..., Any], Callable[..., Any]] = {
82 | F.linear: _quantised_linear,
83 | U.linear: _quantised_u_linear,
84 | F.scaled_dot_product_attention: _quantised_scaled_dot_product_attention,
85 | U.scaled_dot_product_attention: _quantised_u_scaled_dot_product_attention,
86 | }
87 |
88 |
89 | def _replace_with_quantised(
90 | graph: Graph,
91 | node: Node,
92 | fwd_format: FPFormat,
93 | bwd_format: FPFormat,
94 | ) -> None:
95 | # Ideally we'd pass the formats as kwargs, but it currently causes a torch fx bug.
96 | # This workaround will suffice for now...
97 | args = [*node.args]
98 | if len(node.args) == 2: # pragma: no cover
99 | args.append(None)
100 | # Breaks when I pass in FPFormat objects, so convert to tuple and back
101 | args = (
102 | args[:3] + [format_to_tuple(fwd_format), format_to_tuple(bwd_format)] + args[3:]
103 | )
104 |
105 | assert callable(node.target)
106 | quantised_fn = _replacement_map[node.target]
107 | logger.info("quantising function: %s", node)
108 | replace_node_with_function(graph, node, quantised_fn, args=tuple(args))
109 |
110 |
111 | def _quantisation_backend(fwd_format: FPFormat, bwd_format: FPFormat) -> Backend:
112 | def backend_fn(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule:
113 | logger.info("running quantisation backend")
114 | graph = gm.graph
115 | for node in graph.nodes:
116 | if node.op == "call_function" and node.target in _replacement_map:
117 | _replace_with_quantised(graph, node, fwd_format, bwd_format)
118 | graph.lint() # type: ignore[no-untyped-call]
119 | return GraphModule(gm, graph)
120 |
121 | return backend_fn
122 |
123 |
124 | M = TypeVar("M", bound=nn.Module)
125 |
126 |
127 | def simulate_format(module: M, fwd_format: FPFormat, bwd_format: FPFormat) -> M:
128 | """**[Experimental]** Given a module, uses TorchDynamo to return a new module which
129 | simulates the effect of using the supplied formats for matmuls.
130 |
131 | Specifically, before each :func:`torch.nn.functional.linear` and
132 | :func:`torch.nn.functional.scaled_dot_product_attention` call, a quantisation op
133 | is inserted which simulates the effect of using the supplied `fwd_format`. This op
134 | reduces the range of values to that of the given format, and (stochastically) rounds
135 | values to only those representable by the format.
136 |
137 | The same is true for the backward pass, where an op is inserted to quantise to the
138 | `bwd_format`. Models which use modules that contain these functions internally
139 | (such as :class:`torch.nn.Linear`) will be inspected by TorchDynamo and have the
140 | correct quantisation ops inserted.
141 |
142 | If the equivalent unit-scaled functions from :mod:`unit_scaling.functional` are
143 | used in the module, these too will be quantised.
144 |
145 | Simulation of formats is run in FP32. Users should not expect speedups from using
146 | this method. The purpose is to simulate the numerical effects of running matmuls
147 | in various formats.
148 |
149 | Args:
150 | module (nn.Module): the module to be quantised
151 | fwd_format (FPFormat): the quantisation format to be used in the forward pass
152 | (activations and weights)
153 | bwd_format (FPFormat): the quantisation format to be used in the backward pass
154 | (gradients of activations and weights)
155 |
156 | Returns:
157 | nn.Module: a new module which when used, will run using the simulated formats.
158 | """
159 | return apply_transform( # type: ignore[no-any-return]
160 | module, _quantisation_backend(fwd_format, bwd_format)
161 | )
162 |
163 |
164 | def simulate_fp8(module: M) -> M:
165 | """**[Experimental]** Given a module, uses TorchDynamo to return a new module which
166 | simulates the effect of running matmuls in FP8. As is standard in the literature
167 | (Noune et al., 2022; Micikevicius et al., 2022), we use the FP8 E4 format in the
168 |     forward pass, and FP8 E5 in the backward pass.
169 |
170 | Specifically, before each :func:`torch.nn.functional.linear` and
171 | :func:`torch.nn.functional.scaled_dot_product_attention` call, a quantisation op
172 | is inserted which simulates the effect of using FP8. This op
173 | reduces the range of values to that of the format, and (stochastically) rounds
174 | values to only those representable by the format.
175 |
176 | The same is true for the backward pass.
177 | Models which use modules that contain these functions internally
178 | (such as :class:`torch.nn.Linear`) will be inspected by TorchDynamo and have the
179 | correct quantisation ops inserted.
180 |
181 | If the equivalent unit-scaled functions from :mod:`unit_scaling.functional` are
182 | used in the module, these too will be quantised.
183 |
184 | Simulation of formats is run in FP32. Users should not expect speedups from using
185 | this method. The purpose is to simulate the numerical effects of running matmuls
186 | in FP8.
187 |
188 | Args:
189 | module (nn.Module): the module to be quantised
190 |
191 | Returns:
192 | nn.Module: a new module which when used, will run with matmul inputs in FP8.
193 | """
194 | return simulate_format(module, fwd_format=FPFormat(4, 3), bwd_format=FPFormat(5, 2))
195 |
196 |
197 | __all__ = generate__all__(__name__)
198 |
--------------------------------------------------------------------------------
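A minimal usage sketch for the two entry points above (``FPFormat`` is imported from
:mod:`unit_scaling.formats`, as in this module; the toy model is illustrative)::

    import torch
    from torch import nn

    from unit_scaling.formats import FPFormat
    from unit_scaling.transforms import simulate_format, simulate_fp8

    model = nn.Linear(128, 128)

    # The standard FP8 recipe used by simulate_fp8: E4 forward, E5 backward
    fp8_model = simulate_fp8(model)

    # Or choose the formats explicitly
    custom_model = simulate_format(
        model, fwd_format=FPFormat(4, 3), bwd_format=FPFormat(5, 2)
    )

    x = torch.randn(16, 128, requires_grad=True)
    custom_model(x).sum().backward()  # runs in FP32; only the numerics are simulated
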
/unit_scaling/transforms/_unit_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | import logging
4 | from inspect import signature
5 | from operator import getitem
6 | from types import BuiltinFunctionType
7 | from typing import Any, Callable, Dict, List, Set, TypeVar
8 |
9 | import torch
10 | import torch._dynamo
11 | import torch.nn.functional as F
12 | from torch import Tensor, nn
13 | from torch.fx.graph import Graph
14 | from torch.fx.graph_module import GraphModule
15 | from torch.fx.node import Node
16 |
17 | from .. import functional as U
18 | from .._internal_utils import generate__all__
19 | from .utils import Backend, apply_transform, replace_node_with_function
20 |
21 | T = TypeVar("T")
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def _add_dependency_meta(graph: Graph, recalculate: bool = False) -> None:
26 | def recurse(n: Node) -> Set[Node]:
27 | if "dependencies" in n.meta:
28 | return n.meta["dependencies"] # type: ignore[no-any-return]
29 | deps = set(n.all_input_nodes)
30 | for parent in n.all_input_nodes:
31 | deps.update(recurse(parent))
32 | n.meta["dependencies"] = deps
33 | return deps
34 |
35 | if recalculate:
36 | for n in graph.nodes:
37 | if "dependencies" in n.meta:
38 | del n.meta["dependencies"]
39 |
40 | for n in graph.nodes:
41 | if n.op == "output":
42 | recurse(n)
43 |
44 |
45 | def _is_add(n: Node) -> bool:
46 | return (
47 | n.op == "call_function"
48 | and isinstance(n.target, BuiltinFunctionType)
49 | and n.target.__name__ in ["add", "iadd"]
50 | )
51 |
52 |
53 | def _is_self_attention(skip_node: Node, residual_node: Node) -> bool:
54 | # Identify all operations on the residual branch
55 | residual_fns = {residual_node.target}
56 | parents = [*residual_node.all_input_nodes]
57 | while parents:
58 | p = parents.pop()
59 | if p == skip_node:
60 | continue
61 | residual_fns.add(p.target)
62 | parents += p.all_input_nodes
63 | # Check if any of them are self-attention ops (softmax or fused self-attention)
64 | self_attention_fns = {
65 | F.scaled_dot_product_attention,
66 | U.scaled_dot_product_attention,
67 | F.softmax,
68 | U.softmax,
69 | }
70 | return bool(residual_fns.intersection(self_attention_fns))
71 |
72 |
73 | def _unit_scale_residual(
74 | graph: Graph,
75 | add: Node,
76 | residual_arg_idx: int,
77 | is_self_attention: bool,
78 | ) -> None:
79 | residual, skip = add.args[residual_arg_idx], add.args[1 - residual_arg_idx]
80 | tau = 0.01 if is_self_attention else 0.5
81 | logger.info("unit scaling function: %s (residual-add, tau=%s)", add, tau)
82 | old_start_residuals = [
83 | u for u in skip.users if u is not add # type: ignore[union-attr]
84 | ]
85 | with graph.inserting_after(skip): # type: ignore[arg-type]
86 | split = graph.call_function(
87 | U.residual_split,
88 | args=(skip, tau),
89 | type_expr=getattr(skip, "type", None),
90 | )
91 | with graph.inserting_after(split):
92 | new_start_residual = graph.call_function(getitem, args=(split, 0))
93 | for old_start_residual in old_start_residuals:
94 | old_start_residual.replace_input_with(skip, new_start_residual) # type: ignore
95 | with graph.inserting_after(split):
96 | new_skip = graph.call_function(getitem, args=(split, 1))
97 | replace_node_with_function(
98 | graph, add, U.residual_add, args=(residual, new_skip, tau)
99 | )
100 |
101 |
102 | def _unconstrain_node(node: Node) -> None:
103 | if (
104 | node.op == "call_function"
105 | and callable(node.target)
106 | and not isinstance(node.target, BuiltinFunctionType)
107 | and "constraint" in signature(node.target).parameters
108 | ):
109 | logger.info("unconstraining node: %s", node)
110 | node.kwargs = dict(node.kwargs, constraint=None)
111 |
112 |
113 | def unit_scaling_backend(
114 | replacement_map: Dict[Callable[..., Any], Callable[..., Any]] = dict()
115 | ) -> Backend:
116 | def inner_backend(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule:
117 | logger.info("running unit scaling backend")
118 | graph = gm.graph
119 | # Replace function nodes with those in `replacement_map` or with their
120 | # unit scaled equivalents
121 | for node in graph.nodes:
122 | if node.op == "call_function":
123 | if node.target in replacement_map:
124 | target_fn = replacement_map[node.target]
125 | logger.info(
126 | "replacing function: %s with %s", node, target_fn.__name__
127 | )
128 | replace_node_with_function(graph, node, target_fn)
129 | elif node.target in U.torch_map:
130 | target_fn = U.torch_map[node.target]
131 | logger.info("unit scaling function: %s", node)
132 | replace_node_with_function(graph, node, target_fn)
133 |
134 | # Add metadata denoting the dependencies of every node in the graph
135 | _add_dependency_meta(graph)
136 |
137 | # Go through and mark nodes which represent residual-adds
138 | residual_layer_number = 1
139 | for node in graph.nodes:
140 | if _is_add(node):
141 | is_residual_add = False
142 | if len(node.args) == 2:
143 | l, r = node.args
144 | if isinstance(l, Node) and isinstance(r, Node):
145 | l_deps = l.meta.get("dependencies", list())
146 | r_deps = r.meta.get("dependencies", list())
147 | if l in r_deps or r in l_deps:
148 | node.meta["residual_add"] = {
149 | "residual_arg_idx": 1 if l in r_deps else 0,
150 | }
151 | residual_layer_number += 1
152 | is_residual_add = True
153 | skip_node, residual_node = (l, r) if l in r_deps else (r, l)
154 | is_sa = _is_self_attention(skip_node, residual_node)
155 | node.meta["residual_add"]["is_self_attention"] = is_sa
156 | # Regular adds are not picked up by the unit scaling sweep above as
157 |                 # the inbuilt + operation is handled differently when traced. It is
158 |                 # instead replaced with its unit-scaled equivalent here.
159 | if not is_residual_add:
160 | logger.info("unit scaling function: %s", node)
161 | args = (*node.args, None) # None denotes unconstrained
162 | replace_node_with_function(graph, node, U.add, args=args)
163 |
164 | # Replace nodes marked as residual-adds with unit scaled equivalent
165 | for node in graph.nodes:
166 | residual_add = node.meta.get("residual_add", None)
167 | if residual_add is not None:
168 | _unit_scale_residual(graph, node, **residual_add)
169 |
170 | _add_dependency_meta(graph, recalculate=True)
171 | for node in graph.nodes:
172 | if node.target == U.residual_add:
173 | node.meta["has_residual_successor"] = True
174 | dependencies = node.meta.get("dependencies", [])
175 | for d in dependencies:
176 | d.meta["has_residual_successor"] = True
177 |
178 | for node in graph.nodes:
179 | if "has_residual_successor" not in node.meta:
180 | _unconstrain_node(node)
181 |
182 | graph.lint() # type: ignore[no-untyped-call]
183 | return GraphModule(gm, graph)
184 |
185 | return inner_backend
186 |
187 |
188 | def _unit_init_weights(m: nn.Module) -> None:
189 | for name, mod in m.named_modules():
190 | if isinstance(mod, (nn.Linear, nn.Embedding)):
191 | with torch.no_grad():
192 | if isinstance(mod.weight, Tensor):
193 | logger.info("unit scaling weight: %s.weight", name)
194 | mod.weight /= mod.weight.std()
195 |
196 |
197 | def _zero_init_biases(m: nn.Module) -> None:
198 | for name, mod in m.named_modules():
199 | if isinstance(mod, (nn.Linear, nn.Embedding)):
200 | with torch.no_grad():
201 | if hasattr(mod, "bias") and isinstance(mod.bias, Tensor):
202 | logger.info("setting bias to zero: %s.bias", name)
203 | mod.bias -= mod.bias
204 |
205 |
206 | # Unit scaling should come before quantisation
207 | def _order_backends(backends: List[Backend]) -> None:
208 | unit_scaling_backend_idx = -1
209 | quantisation_backend_idx = float("inf")
210 | for i, b in enumerate(backends):
211 | if "unit_scaling_backend" in b.__qualname__:
212 | unit_scaling_backend_idx = i
213 | if "quantisation_backend" in b.__qualname__:
214 | quantisation_backend_idx = i
215 | if unit_scaling_backend_idx > quantisation_backend_idx:
216 | logger.info("moving unit scaling backend to precede quantisation backend")
217 | u = backends.pop(unit_scaling_backend_idx)
218 | backends.insert(quantisation_backend_idx, u) # type: ignore[arg-type]
219 |
220 |
221 | M = TypeVar("M", bound=nn.Module)
222 |
223 |
224 | def unit_scale(
225 | module: M, replace: Dict[Callable[..., Any], Callable[..., Any]] = dict()
226 | ) -> M:
227 | """**[Experimental]** Returns a unit-scaled version of the input model.
228 |
229 | Uses TorchDynamo to trace and transform the user-supplied module.
230 | This transformation identifies all :class:`torch.nn.functional` uses in the input
231 | module, and replaces them with their unit-scaled equivalents, should they exist.
232 |
233 | The tracing procedure automatically recurses into modules
234 | (whether defined in libraries, or by the user), identifying inner calls to any
235 | :class:`torch.nn.functional` operations, to build a graph of fundamental operations.
236 | Unit scaling is then applied as a transformation on top of this graph.
237 |
238 | This transformation proceeds in five stages:
239 |
240 | #. **Replacement of user-defined functions** according to the supplied `replace`
241 | dictionary.
242 | #. **Replacement of all functions with unit-scaled equivalents** defined in
243 | :mod:`unit_scaling.functional`.
244 | #. Identification & **replacement of all add operations that represent
245 | residual-adds**. The identification of residual connections is done via a
246 | dependency analysis on the graph. Residual-adds require special scaling compared
247 | with regular adds (see paper / User Guide for details).
248 | #. **Unconstraining of all operations after the final residual layer**. By default
249 | all unit scaled operations have their scaling factors constrained in the forward
250 | and backward pass to give valid gradients. This is not required in these final
251 | layers (see paper for proof), and hence we can unconstrain the operations to give
252 | better scaling.
253 | #. **Unit-scaling of all weights** and zero-initialisation of all biases.
254 |
255 | Note that by using TorchDynamo, `unit_scale()` is able to trace a much larger set of
256 | modules / operations than with previous PyTorch tracing approaches. This enables
257 | the process of unit scaling to be expressed as a generic graph transform that can be
258 | applied to arbitrary modules.
259 |
260 | Note that the current version of TorchDynamo (or :func:`torch.compile`, which is a
261 | wrapper around TorchDynamo) doesn't support nested transforms, so we implement our
262 | own system here. This makes it easy to nest transforms:
263 |
264 | .. code-block:: python
265 |
266 | from unit_scaling.transforms import compile, simulate_fp8, unit_scale
267 |
268 | module = compile(simulate_fp8(unit_scale(module)))
269 |
270 | However, these transforms are not interoperable with the standard
271 | :func:`torch.compile` interface.
272 |
273 |     In some cases users may have a model definition that uses a custom implementation of
274 |     a basic operation. In this case, `unit_scale()` can be told explicitly to replace
275 |     the function with an equivalent, using the `replace` dictionary:
276 |
277 | .. code-block:: python
278 |
279 | import unit_scaling.functional as U
280 | from unit_scaling.transforms import unit_scale
281 |
282 | def new_gelu(x):
283 | ...
284 |
285 | class Model(nn.Module):
286 |             def forward(self, x):
287 | ...
288 | x = new_gelu(x)
289 | ...
290 |
291 | model = unit_scale(Model(), replace={new_gelu: U.gelu})
292 |
293 |     This can also be used to replace a particular function with a user-defined
294 |     unit-scaled function not provided by :mod:`unit_scaling.functional`.
295 |
296 | **Note:** `unit_scale()` is experimental and has not yet been widely tested on a
297 | range of models. The standard approach to unit scaling a model is still to
298 | manually substitute the layers/operations in a model with their unit-scaled
299 | equivalents. Having said this, `unit_scale()` is implemented in a sufficiently
300 | generic way that we anticipate many users will ultimately be able to rely on this
301 | graph transform alone.
302 |
303 | Args:
304 | module (nn.Module): the input module to be unit scaled.
305 | replace (Dict[Callable, Callable], optional): a dictionary where keys represent
306 | functions to be replaced by the corresponding value-functions. Note that
307 | these substitutions take priority over the standard unit scaling
308 | substitutions. Defaults to dict().
309 |
310 | Returns:
311 | nn.Module: the unit scaled module (with an independent copy of parameters)
312 | """
313 | unit_scaled_module = apply_transform(
314 | module,
315 | unit_scaling_backend(replace),
316 | non_recurse_functions=list(replace.keys()),
317 | )
318 | _order_backends(unit_scaled_module.backends)
319 | _unit_init_weights(unit_scaled_module)
320 | _zero_init_biases(unit_scaled_module)
321 | return unit_scaled_module # type: ignore[no-any-return]
322 |
323 |
324 | __all__ = generate__all__(__name__)
325 |
--------------------------------------------------------------------------------
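Putting the above together, a minimal runnable sketch (``Block`` and ``my_gelu`` are
illustrative stand-ins, not part of the library)::

    import torch
    from torch import nn

    import unit_scaling.functional as U
    from unit_scaling.transforms import unit_scale

    def my_gelu(x):  # hypothetical custom op the transform wouldn't otherwise recognise
        return 0.5 * x * (1 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

    class Block(nn.Module):
        def __init__(self, d):
            super().__init__()
            self.fc1 = nn.Linear(d, 4 * d)
            self.fc2 = nn.Linear(4 * d, d)

        def forward(self, x):
            # Residual add, detected automatically by the dependency analysis
            return x + self.fc2(my_gelu(self.fc1(x)))

    model = unit_scale(Block(256), replace={my_gelu: U.gelu})
    x = torch.randn(8, 256, requires_grad=True)
    model(x).sum().backward()
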
/unit_scaling/transforms/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Utilities for working with transforms."""
4 |
5 | import copy
6 | import functools
7 | from typing import (
8 | Any,
9 | Callable,
10 | Dict,
11 | Iterable,
12 | List,
13 | Optional,
14 | Tuple,
15 | TypeVar,
16 | no_type_check,
17 | )
18 | from unittest.mock import patch
19 |
20 | import torch
21 | import torch._dynamo
22 | from torch import Tensor, nn
23 | from torch.fx.graph import Graph
24 | from torch.fx.graph_module import GraphModule
25 | from torch.fx.node import Node
26 |
27 | from .. import functional as U
28 | from .._internal_utils import generate__all__
29 |
30 | T = TypeVar("T")
31 |
32 | Backend = Callable[[GraphModule, List[Tensor]], Callable[..., Any]]
33 |
34 | _unit_scaled_functions = [getattr(U, f) for f in U.__all__]
35 |
36 |
37 | def torch_nn_modules_to_user_modules(mod: nn.Module) -> None:
38 | """
39 |     Convert :mod:`torch.nn` modules to `trivial_subclass` versions.
40 |
41 | By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
42 | :mod:`torch.nn.functional` functions when capturing the FX graph.
43 |
44 | This function makes `torch.nn` modules into user modules.
45 |
46 |     To use this with a :class:`torch.nn.Module` the typical use case
47 |     is to call `torch_nn_modules_to_user_modules(module)`; the module is modified in place.
48 | """
49 |
50 | for n, submod in mod.named_children():
51 | torch_nn_modules_to_user_modules(submod)
52 |
53 | # Mirroring the check at https://github.com/pytorch/pytorch/blob/34bce27f0d12bf7226b37dfe365660aad456701a/torch/_dynamo/variables/nn_module.py#L307 # noqa: E501
54 | if submod.__module__.startswith(("torch.nn.", "torch.ao.")):
55 | # Generate a new name, so e.g. torch.nn.modules.sparse.Embedding
56 | # becomes trivial_subclass_modules_sparse_Embedding
57 | modulename = submod.__module__
58 | modulename = modulename.replace("torch.nn.", "", 1)
59 | modulename = modulename.replace(".", "_")
60 | newtypename = "trivial_subclass_" + modulename + "_" + type(submod).__name__
61 |
62 | # Create a new type object deriving from type(submod)
63 | newmodtype = type(newtypename, (type(submod),), {})
64 |
65 | # Initialize and copy state using pickle
66 | newsubmod = newmodtype.__new__(newmodtype) # type: ignore [call-overload]
67 | state = submod.__getstate__() # type: ignore [no-untyped-call]
68 | newsubmod.__setstate__(state)
69 |
70 | # Update module in mod
71 | setattr(mod, n, newsubmod)
72 |
73 |
74 | def patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]:
75 | """By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
76 | :mod:`torch.nn.functional` functions when capturing the FX graph.
77 | Any function which is wrapped in
78 |     :func:`torch._dynamo.optimize` (or :func:`torch.compile`) and is then passed to
79 | this function as `fn` will now automatically recurse into
80 | :mod:`torch.nn` modules or :mod:`torch.nn.functional` functions.
81 |
82 | In practice, to use this with a :class:`torch.nn.Module` the typical use case
83 | is to call `module = torch._dynamo.optimize(backend)(module)`, followed by
84 | `module.forward = patch_to_expand_modules(module.forward)`.
85 |
86 | This should be used in conjunction with :func:`torch_nn_modules_to_user_modules`
87 |
88 | Args:
89 | fn (Callable[..., T]): the function to be patched.
90 |
91 | Returns:
92 | Callable[..., T]: the new version of `fn` with patching applied.
93 | """
94 |
95 | def _patched_call_function( # type: ignore[no-untyped-def]
96 | self,
97 | tx,
98 | args,
99 | kwargs,
100 | ): # pragma: no cover
101 | # Removing the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
102 | return super(
103 | torch._dynamo.variables.functions.UserMethodVariable, self
104 | ).call_function(tx, args, kwargs)
105 |
106 | @functools.wraps(fn)
107 | def new_fn(*args: Any, **kwargs: Any) -> Any:
108 | with patch(
109 | "torch._dynamo.variables.functions.UserMethodVariable.call_function",
110 | new=_patched_call_function,
111 | ):
112 | return fn(*args, **kwargs)
113 |
114 | return new_fn
115 |
116 |
117 | def replace_node_with_function(
118 | graph: Graph,
119 | source: Node,
120 | target_fn: Callable[..., Any],
121 | args: Optional[Tuple[Any, ...]] = None,
122 | kwargs: Optional[Dict[Any, Any]] = None,
123 | keep_type_expr: bool = True,
124 | ) -> None:
125 | """Given a source node and its accompanying graph, remove the node and replace it
126 | with a new node that represents calling the target function.
127 |
128 | Args:
129 | graph (Graph): the graph in which the node is present.
130 | source (Node): the node to be replaced.
131 | target_fn (Callable[..., Any]): the function to be contained in the new node.
132 | args (Optional[Tuple[Any, ...]], optional): args of the new node.
133 | Defaults to None.
134 | kwargs (Optional[Dict[Any, Any]], optional): kwargs of the new node.
135 | Defaults to None.
136 | keep_type_expr (bool, optional): retain the type expression of the removed node.
137 | Defaults to True.
138 | """
139 | if args is None:
140 | args = source.args
141 | if kwargs is None:
142 | kwargs = source.kwargs
143 | type_expr = getattr(source, "type", None) if keep_type_expr else None
144 | with graph.inserting_after(source):
145 | new_node = graph.call_function(target_fn, args, kwargs, type_expr)
146 | source.replace_all_uses_with(new_node)
147 | graph.erase_node(source)
148 |
149 |
150 | def _compose_backends(backends: Iterable[Backend]) -> Backend:
151 | def composite_backend(
152 | gm: GraphModule, example_inputs: List[Tensor]
153 | ) -> Callable[..., Any]:
154 | for b in backends:
155 | new_gm = b(gm, example_inputs)
156 | new_gm._param_name_to_source = getattr( # type: ignore
157 | gm,
158 | "_param_name_to_source",
159 | None,
160 | )
161 | gm = new_gm # type: ignore[assignment]
162 | return gm
163 |
164 | return composite_backend
165 |
166 |
167 | M = TypeVar("M", bound=nn.Module)
168 |
169 |
170 | @no_type_check
171 | def apply_transform(
172 | module: M,
173 | backend: Backend,
174 | non_recurse_functions: List[Callable[..., Any]] = list(),
175 | ) -> M:
176 | """Applies a graph transformation to a module.
177 |
178 | The user-supplied :code:`backend` represents a transformation of a
179 | :class:`torch.fx.graph_module.GraphModule`. :code:`apply_transform()` uses
180 | :func:`torch._dynamo.optimize` to apply this transformation to the module,
181 | returning a new transformed module.
182 |
183 | Note that the current version of TorchDynamo (or :func:`torch.compile`, which is a
184 | wrapper around TorchDynamo) doesn't support nested transforms, so we implement our
185 | own system here. This makes it easy to nest transforms:
186 |
187 | .. code-block:: python
188 |
189 | module = apply_transform(apply_transform(module, backend_1), backend_2)
190 |
191 | However, it should be noted these transforms are not interoperable with the standard
192 | :func:`torch.compile` interface.
193 |
194 |     This nesting system is implemented by moving the call to
195 |     :func:`torch._dynamo.optimize` into the :code:`forward()` method of the module,
196 |     executing it only on the first call to the module (or whenever a new transform
197 |     is applied), with the optimised call cached thereafter. This differs from the
198 | standard approach used with :func:`torch._dynamo.optimize`, but enables this
199 | convenient nesting functionality.
200 |
201 | Args:
202 |         module (nn.Module): the module to be transformed.
203 | backend (Backend): the graph transformation to be applied.
204 |         non_recurse_functions (List[Callable[..., Any]], optional): functions which
205 | the user does not wish to be recursed into. Defaults to list().
206 |
207 | Returns:
208 | nn.Module: the transformed module.
209 | """
210 | module = copy.deepcopy(module)
211 |
212 | torch_nn_modules_to_user_modules(module)
213 |
214 | if not hasattr(module, "backends"):
215 | module.backends = []
216 | module.backends.append(backend)
217 |
218 | for v in non_recurse_functions:
219 | torch._dynamo.allow_in_graph(v)
220 |
221 | backend = _compose_backends(module.backends)
222 |
223 | def new_forward(*args: Any, **kwargs: Any) -> Any:
224 | if module.rerun_transform:
225 | torch._dynamo.reset()
226 | dynamo_module = torch._dynamo.optimize(backend)(module)
227 | module.dynamo_forward = patch_to_expand_modules(dynamo_module.forward)
228 | module.rerun_transform = False
229 | with patch.object(module, "forward", module.base_forward):
230 | return module.dynamo_forward(*args, **kwargs)
231 |
232 | module.rerun_transform = True
233 | module.base_forward = getattr(module, "base_forward", module.forward)
234 | module.forward = new_forward
235 | return module
236 |
237 |
238 | __all__ = generate__all__(__name__)
239 |
--------------------------------------------------------------------------------
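A minimal sketch of a user-defined backend passed to ``apply_transform`` (this backend only
logs node targets and returns the graph unchanged; a real backend would rewrite nodes,
e.g. via ``replace_node_with_function``)::

    from typing import List

    import torch
    from torch import Tensor, nn
    from torch.fx.graph_module import GraphModule

    from unit_scaling.transforms.utils import apply_transform

    def logging_backend(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule:
        for node in gm.graph.nodes:
            print(node.op, node.target)
        return gm  # returning the module unchanged is a valid (identity) transform

    model = apply_transform(nn.Linear(64, 64), logging_backend)
    model(torch.randn(2, 64))  # the backend runs on this first forward call
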
/unit_scaling/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 |
3 | """Utility functions for developing unit-scaled models."""
4 |
5 | import ast
6 | import math
7 | import re
8 | import typing
9 | from collections import OrderedDict
10 | from dataclasses import dataclass
11 | from types import FunctionType, ModuleType
12 | from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union, cast
13 |
14 | import einops
15 | import torch
16 | from pygments import highlight
17 | from pygments.formatters import TerminalFormatter
18 | from pygments.lexers import PythonLexer
19 | from torch import Tensor, fx, nn
20 |
21 | from . import functional
22 | from ._internal_utils import generate__all__
23 |
24 |
25 | @dataclass
26 | class ScalePair:
27 | """Dataclass containing a pair of scalars, intended to represent the standard
28 | deviation of an arbitrary tensor in the forward and backward passes."""
29 |
30 | forward: Optional[float] = None
31 | backward: Optional[float] = None
32 |
33 | def __str__(self) -> str:
34 | fwd = f"{self.forward:.3}" if self.forward is not None else "n/a"
35 | bwd = f"{self.backward:.3}" if self.backward is not None else "n/a"
36 | return f"(-> {fwd}, <- {bwd})"
37 |
38 |
39 | ScaleDict = typing.OrderedDict[str, ScalePair]
40 |
41 |
42 | class ScaleTracker(torch.autograd.Function):
43 |     """Given a `torch.Tensor`, records its standard deviation in the forward and
44 | backward pass in the supplied `ScalePair`."""
45 |
46 | @staticmethod
47 | def forward(
48 | ctx: torch.autograd.function.FunctionCtx,
49 | t: Tensor,
50 | scale_tracker: ScalePair,
51 | ) -> Tensor:
52 | scale_tracker.forward = float(t.std())
53 | ctx.scale_tracker = scale_tracker # type: ignore
54 | return t
55 |
56 | @staticmethod
57 | def backward( # type:ignore[override]
58 | ctx: torch.autograd.function.FunctionCtx, t: Tensor
59 | ) -> Tuple[Tensor, None, None]:
60 | ctx.scale_tracker.backward = float(t.std()) # type: ignore
61 | return t, None, None
62 |
63 | @staticmethod
64 | def track(t: Tensor, scale_tracker: ScalePair) -> Tensor:
65 | # Add typing information to `apply()` method from `torch.autograd.Function`
66 | apply = cast(Callable[[Tensor, ScalePair], Tensor], ScaleTracker.apply)
67 | return apply(t, scale_tracker)
68 |
69 |
70 | class ScaleTrackingInterpreter(fx.Interpreter):
71 |     """Wraps an `fx.GraphModule` such that when executed it records the standard
72 |     deviation of every intermediate `torch.Tensor` in the forward and backward pass.
73 |
74 | Args:
75 | module (fx.GraphModule): the module to be instrumented.
76 | """
77 |
78 | def __init__(self, module: fx.GraphModule):
79 | super().__init__(module)
80 | self.scales: typing.OrderedDict[str, ScalePair] = OrderedDict()
81 |
82 | def run_node(self, n: fx.Node) -> Any:
83 | out = super().run_node(n)
84 | if isinstance(out, Tensor) and out.is_floating_point():
85 | scale_pair = ScalePair()
86 | out = ScaleTracker.track(out, scale_pair)
87 | self.scales[n.name] = scale_pair
88 | return out
89 |
90 | def call_function(
91 | self, target: fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]
92 | ) -> Any:
93 | return super().call_function(target, args, kwargs)
94 |
95 | def placeholder(
96 | self,
97 | target: fx.node.Target,
98 | args: Tuple[fx.node.Argument, ...],
99 | kwargs: Dict[str, Any],
100 | ) -> Any:
101 | """To handle functions being passed as arguments (for example constraints) the
102 | tracer represents them as placeholder nodes. This method extracts the original
103 | function from the node, as stored in the `target_to_function` dict."""
104 | if isinstance(target, str) and target.startswith("function_placeholder__"):
105 | return self.module.graph._tracer_extras["target_to_function"][
106 | target
107 | ] # pragma: no cover
108 | return super().placeholder(target, args, kwargs)
109 |
110 |
111 | def _record_scales(
112 | fx_graph_module: fx.GraphModule,
113 | inputs: Tuple[Tensor, ...],
114 | backward: Optional[Tensor] = None,
115 | ) -> ScaleDict:
116 | """Given a `torch.fx.GraphModule`, and dummy tensors to feed into the forward and
117 | backward passes, returns a dictionary of the scales (standard deviations) of every
118 | intermediate tensor in the model (forward and backward pass).
119 |
120 | Args:
121 | fx_graph_module (fx.GraphModule): the module to record.
122 |         inputs (Tuple[Tensor, ...]): fed into the forward pass for analysis.
123 | backward (Tensor, optional): fed into the output's `.backward()` method for
124 | analysis. Defaults to `None`, equivalent to calling plain `.backward()`.
125 |
126 | Returns:
127 | ScaleDict: An ordered dictionary with `ScalePair`s for each intermediate tensor.
128 | """
129 | tracking_module = ScaleTrackingInterpreter(fx_graph_module)
130 | out = tracking_module.run(*inputs)
131 | out.backward(backward)
132 | return tracking_module.scales
133 |
134 |
135 | def _annotate(code: str, scales: ScaleDict, syntax_highlight: bool) -> str:
136 |     """Given a string representation of some code and a `ScaleDict` with accompanying
137 | scales, annotates the code to include the scales on the right-hand side."""
138 |
139 | function_placeholder_regex = r"function_placeholder__(\w+)"
140 |
141 | def is_function_placeholder_line(code_line: str) -> bool:
142 | return bool(re.search(f" = {function_placeholder_regex}$", code_line))
143 |
144 | def cleanup_function_signature(code_line: str) -> str:
145 | code_line = re.sub(f", {function_placeholder_regex}", "", code_line)
146 | inner_code_line = code_line.split("(", 1)[1]
147 | replacement = re.sub(r"_([a-zA-Z0-9_]+)_", r"\1", inner_code_line)
148 | return code_line.replace(inner_code_line, replacement)
149 |
150 | def annotate_line(code_line: str) -> str:
151 | if code_line.startswith("torch.fx._symbolic_trace.wrap"):
152 | return ""
153 | code_line = code_line.split(";")[0]
154 | if is_function_placeholder_line(code_line): # pragma: no cover
155 | return ""
156 | words = code_line.strip().split(" ")
157 | if words:
158 | if words[0] in scales:
159 | return f"{code_line}; {scales[words[0]]}"
160 | elif words[0] == "def":
161 | parsed = ast.parse(code_line + "\n\t...").body[0]
162 | assert isinstance(parsed, ast.FunctionDef)
163 | arg_names = [arg.arg for arg in parsed.args.args]
164 | scale_strs = [str(scales[a]) for a in arg_names if a in scales]
165 | code_line = cleanup_function_signature(code_line)
166 | if scale_strs:
167 | return f"{code_line} {', '.join(scale_strs)}" # pragma: no cover
168 | else:
169 | return code_line
170 | return code_line
171 |
172 | def remove_empty_lines(code_lines: Iterator[str]) -> Iterator[str]:
173 | return (line for line in code_lines if line.strip())
174 |
175 | code_lines = map(annotate_line, code.splitlines())
176 | code = "\n".join(remove_empty_lines(code_lines)).strip()
177 | code = code.replace("unit_scaling_functional_", "U.")
178 | if syntax_highlight:
179 | return highlight(code, PythonLexer(), TerminalFormatter()) # pragma: no cover
180 | return code
181 |
182 |
183 | class _DeepTracer(fx.Tracer):
184 | """Version of `torch.fx.Tracer` which recurses into all sub-modules (if specified).
185 |
186 | Args:
187 |         recurse_modules (bool): toggles recursive behaviour. Defaults to True.
188 | autowrap_modules (Tuple[ModuleType]): defaults to
189 | `(math, einops, U.functional)`,
190 | Python modules whose functions should be wrapped automatically
191 | without needing to use fx.wrap().
192 |         autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
193 | Python functions that should be wrapped automatically without
194 | needing to use fx.wrap().
195 | """
196 |
197 | def __init__(
198 | self,
199 | recurse_modules: bool = True,
200 | autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
201 | autowrap_functions: Tuple[Callable[..., Any], ...] = (),
202 | ) -> None:
203 | super().__init__(
204 | autowrap_modules=autowrap_modules, # type: ignore[arg-type]
205 | autowrap_functions=autowrap_functions,
206 | )
207 | self.recurse_modules = recurse_modules
208 | self.target_to_function: Dict[str, FunctionType] = {}
209 | self.function_to_node: Dict[FunctionType, fx.Node] = {}
210 | # Fixes: `TypeError: __annotations__ must be set to a dict object`
211 | if id(FunctionType) in self._autowrap_function_ids:
212 | self._autowrap_function_ids.remove(id(FunctionType))
213 |
214 | def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
215 | return not self.recurse_modules
216 |
217 | def create_arg(self, a: Any) -> fx.node.Argument:
218 | """Replaces callable arguments with strings for tracing."""
219 | if isinstance(a, FunctionType): # pragma: no cover
220 | node = self.function_to_node.get(a)
221 | if node is None:
222 | assert hasattr(
223 | a, "__name__"
224 | ), f"can't create arg for unnamed function: {a}"
225 | name = getattr(a, "__name__")
226 | target = f"function_placeholder__{name}"
227 | node = self.create_node("placeholder", target, (), {}, name)
228 | self.target_to_function[target] = a
229 | self.function_to_node[a] = node
230 | return node
231 | return super().create_arg(a)
232 |
233 | def trace(
234 | self,
235 | root: Union[torch.nn.Module, Callable[..., Any]],
236 | concrete_args: Optional[Dict[str, Any]] = None,
237 | ) -> fx.Graph:
238 | """Adds the `target_to_function` dict to the graph so the interpreter can use it
239 | downstream."""
240 | graph = super().trace(root, concrete_args)
241 | if not hasattr(graph, "_tracer_extras") or graph._tracer_extras is None:
242 | graph._tracer_extras = {}
243 | graph._tracer_extras["target_to_function"] = self.target_to_function
244 | return graph
245 |
246 |
247 | def analyse_module(
248 | module: nn.Module,
249 | inputs: Union[Tensor, Tuple[Tensor, ...]],
250 | backward: Optional[Tensor] = None,
251 | recurse_modules: bool = True,
252 | syntax_highlight: bool = True,
253 | autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
254 | autowrap_functions: Tuple[Callable[..., Any], ...] = (),
255 | ) -> str:
256 | """Given a `nn.Module` and dummy forward and backward tensors, generates code
257 | representing the module annotated with the scales (standard deviation) of each
258 | tensor in both forward and backward passes. Implemented using `torch.fx`.
259 |
260 | Args:
261 | module (nn.Module): the module to analyse.
262 | inputs (Union[Tensor, Tuple[Tensor, ...]]): fed into the forward pass for
263 | analysis.
264 | backward (Tensor, optional): fed into the output's `.backward()` method for
265 | analysis. Defaults to `None`, equivalent to calling plain `.backward()`.
266 |         recurse_modules (bool, optional): toggles recursive behaviour. Defaults to True.
267 | syntax_highlight (bool, optional): Defaults to True.
268 | autowrap_modules (Tuple[ModuleType]): defaults to
269 | `(math, einops, U.functional)`,
270 | Python modules whose functions should be wrapped automatically
271 | without needing to use fx.wrap().
272 |         autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
273 | Python functions that should be wrapped automatically without
274 | needing to use fx.wrap().
275 |
276 | Returns:
277 | str:
278 | a code string representing the operations in the module with scale
279 | annotations for each tensor, reflecting their standard deviations in the
280 | forward and backward passes.
281 |
282 | Examples::
283 |
284 | >>> class MLP(nn.Module):
285 | >>> def __init__(self, d):
286 | >>> super().__init__()
287 | >>> self.fc1 = nn.Linear(d, d * 4)
288 | >>> self.relu = nn.ReLU()
289 | >>> self.fc2 = nn.Linear(d * 4, d)
290 |
291 | >>> def forward(self, x):
292 | >>> x = self.fc1(x)
293 | >>> x = self.relu(x)
294 | >>> x = self.fc2(x)
295 | >>> return x
296 |
297 |
298 | >>> hidden_size = 2**10
299 | >>> x = torch.randn(hidden_size, hidden_size).requires_grad_()
300 | >>> bwd = torch.randn(hidden_size, hidden_size)
301 |
302 | >>> code = analyse_module(MLP(hidden_size), x, bwd)
303 | >>> print(code)
304 | def forward(self, x): (-> 1.0, <- 0.236)
305 | fc1_weight = self.fc1.weight; (-> 0.018, <- 6.54)
306 | fc1_bias = self.fc1.bias; (-> 0.0182, <- 6.51)
307 | linear = torch._C._nn.linear(x, fc1_weight, fc1_bias); (-> 0.578, <- 0.204)
308 | relu = torch.nn.functional.relu(linear, inplace = False); (-> 0.337, <- 0.288)
309 | fc2_weight = self.fc2.weight; (-> 0.00902, <- 13.0)
310 | fc2_bias = self.fc2.bias; (-> 0.00904, <- 31.6)
311 | linear_1 = torch._C._nn.linear(relu, fc2_weight, fc2_bias); (-> 0.235, <- 0.999)
312 | return linear_1
313 | """ # noqa: E501
314 | tracer = _DeepTracer(recurse_modules, autowrap_modules, autowrap_functions)
315 | fx_graph = tracer.trace(module)
316 | fx_graph_module = fx.GraphModule(tracer.root, fx_graph)
317 |
318 | if not isinstance(inputs, tuple):
319 | inputs = (inputs,)
320 | scales = _record_scales(fx_graph_module, inputs, backward)
321 | return _annotate(fx_graph_module.code, scales, syntax_highlight=syntax_highlight)
322 |
323 |
324 | __all__ = generate__all__(__name__)
325 |
--------------------------------------------------------------------------------