├── .github ├── FUNDING.yml └── workflows │ ├── create-release.yml │ ├── publish-package.yml │ └── run-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── docs ├── blog.md ├── images │ └── stateful-transforms.png └── tiny_nnx.ipynb ├── examples ├── 00_demo.ipynb ├── 01_functional_api.py ├── 02_lifted_transforms.py ├── 03_train_state.py ├── 04_pure.py ├── 05_vae.py ├── 06_scan_over_layers.py ├── 07_transformer.py ├── 08_save_load_checkpoints.py └── 09_parameter_surgery.py ├── ideas ├── nnx_example.py ├── pure │ ├── __init__.py │ ├── full │ │ ├── partitioning_full.py │ │ └── state_full.py │ ├── module.py │ ├── partitioning.py │ ├── rngs.py │ └── state.py ├── pure_example.py ├── pure_nnx_example.py ├── pure_pytree │ ├── __init__.py │ ├── dataclass.py │ ├── full │ │ ├── partitioning_full.py │ │ └── state_full.py │ ├── module.py │ ├── partitioning.py │ └── rngs.py ├── pure_pytree_example.py └── shape_inference.py ├── nnx ├── __init__.py ├── containers.py ├── contextlib.py ├── dataclasses.py ├── errors.py ├── helpers.py ├── ids.py ├── module.py ├── nn │ ├── __init__.py │ ├── activations.py │ ├── dtypes.py │ ├── initializers.py │ ├── linear.py │ ├── normalization.py │ └── stochastic.py ├── nodes.py ├── partitioning.py ├── pytreelib.py ├── reprlib.py ├── spmd.py ├── state.py ├── tracers.py └── transforms.py ├── poetry.lock ├── pyproject.toml ├── scripts ├── deploy-docs.sh ├── run-all-examples.bash └── update_version.py └── tests ├── __init__.py ├── test_containers.py ├── test_context.py ├── test_helpers.py ├── test_ids.py ├── test_integration.py ├── test_module.py ├── test_partitioning.py ├── test_pytree.py ├── test_spmd.py ├── test_transforms.py └── test_variable.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [cgarciae] 4 | # patreon: # Replace with a single Patreon username 5 | # open_collective: # Replace with a single Open Collective username 6 | # ko_fi: # Replace with a single Ko-fi username 7 | # tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | # community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | # liberapay: # Replace with a single Liberapay username 10 | # issuehunt: # Replace with a single IssueHunt username 11 | # otechie: # Replace with a single Otechie username 12 | # lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | # custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/workflows/create-release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | 3 | on: 4 | create 5 | 6 | jobs: 7 | create-release: 8 | if: startsWith(github.ref_name, 'version-') && endsWith(github.ref_name, '-create-release') 9 | name: Create Release 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out the code 13 | uses: actions/checkout@v3 14 | 15 | - name: Set up Python 3.9 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: 3.9 19 | 20 | - name: Setup 21 | id: setup 22 | run: | 23 | # install python dependencies 24 | pip install typer==0.4.0 click==8.0.3 25 | 26 | # variables 27 | RELEASE_VERSION='${{ github.ref_name }}' 28 | RELEASE_VERSION=${RELEASE_VERSION//version-/} 29 | 
RELEASE_VERSION=${RELEASE_VERSION//-create-release/} 30 | echo "::set-output name=RELEASE_VERSION::${RELEASE_VERSION}" 31 | 32 | 33 | - name: Update version 34 | run: | 35 | RELEASE_VERSION='${{ steps.setup.outputs.RELEASE_VERSION }}' 36 | 37 | # setup git 38 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 39 | git config --local user.name "github-actions[bot]" 40 | 41 | # switch to main 42 | git pull origin main 43 | git checkout main 44 | 45 | # update version 46 | python scripts/update_version.py $RELEASE_VERSION 47 | git commit -am "Update version to $RELEASE_VERSION" 48 | 49 | # create tag 50 | git fetch --tags 51 | git tag $RELEASE_VERSION 52 | 53 | # push to main 54 | git push 55 | git push --tags 56 | 57 | # delete branch 58 | git push -d origin ${{ github.ref_name }} 59 | 60 | - name: Create Release 61 | uses: actions/create-release@v1 62 | with: 63 | tag_name: ${{ steps.setup.outputs.RELEASE_VERSION }} 64 | release_name: ${{ steps.setup.outputs.RELEASE_VERSION }} 65 | draft: true 66 | env: 67 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/publish-package.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | publish-docs-and-package: 7 | name: Publish Docs and Package 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Check out the code 11 | uses: actions/checkout@v3 12 | 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.9 17 | 18 | - name: Install Poetry 19 | uses: snok/install-poetry@v1.3.3 20 | with: 21 | version: 1.4.0 22 | 23 | - name: Setup Poetry 24 | run: | 25 | poetry config virtualenvs.in-project true 26 | 27 | - name: Cache 28 | id: cache 29 | uses: actions/cache@v3.2.2 30 | with: 31 | path: '.venv' 32 | key: publish-package-${{ hashFiles('poetry.lock') }} 33 | 34 | - name: Install Dependencies 35 | if: steps.cache.outputs.cache-hit != 'true' 36 | run: | 37 | poetry install --without dev 38 | 39 | - name: Install Package 40 | run: | 41 | poetry install --without dev 42 | 43 | # ---------------------------------------- 44 | # No docs for now 45 | # ---------------------------------------- 46 | # - name: Build Docs 🔨 47 | # run: | 48 | # cp README.md docs/index.md 49 | # poetry run mkdocs build 50 | 51 | # - name: Deploy Page 🚀 52 | # uses: JamesIves/github-pages-deploy-action@4.1.6 53 | # with: 54 | # branch: gh-pages 55 | # folder: site 56 | 57 | - name: Publish to PyPI 58 | run: | 59 | poetry build 60 | poetry publish \ 61 | --username ${{ secrets.PYPI_USERNAME }} \ 62 | --password ${{ secrets.PYPI_PASSWORD }} 63 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | # Checks that we can build and validate the Unittest 2 | name: Run Tests 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | - uses: pre-commit/action@v3.0.0 15 | test: 16 | name: Run Tests 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: ['3.9', '3.10', '3.11'] 21 | steps: 22 | - name: Check out the code 23 | uses: actions/checkout@v3 24 | with: 25 | fetch-depth: 1 26 | - name: Set up Python ${{ 
matrix.python-version }} 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install Poetry 32 | uses: snok/install-poetry@v1.3.3 33 | with: 34 | version: 1.4.0 35 | 36 | - name: Setup Poetry 37 | run: | 38 | poetry config virtualenvs.in-project true 39 | 40 | - name: Cache 41 | id: cache 42 | uses: actions/cache@v3.2.2 43 | with: 44 | path: '.venv' 45 | key: run-tests-${{ hashFiles('poetry.lock') }} 46 | 47 | - name: Install Dependencies 48 | if: steps.cache.outputs.cache-hit != 'true' 49 | run: | 50 | if [ -d ".venv" ]; then rm -rf .venv; fi 51 | poetry install 52 | 53 | - name: Install Package 54 | run: | 55 | poetry install 56 | 57 | - name: Run Tests 58 | run: poetry run pytest --cov=nnx --cov-report=term-missing --cov-report=xml 59 | 60 | - name: Upload coverage 61 | uses: codecov/codecov-action@v3 62 | 63 | test-import: 64 | name: Test Import without Dev Dependencies 65 | if: ${{ !contains(github.event.pull_request.title, 'WIP') }} 66 | runs-on: ubuntu-latest 67 | strategy: 68 | matrix: 69 | python-version: ['3.9', '3.10', '3.11'] 70 | steps: 71 | - name: Check out the code 72 | uses: actions/checkout@v3 73 | with: 74 | fetch-depth: 1 75 | - name: Set up Python ${{ matrix.python-version }} 76 | uses: actions/setup-python@v4 77 | with: 78 | python-version: ${{ matrix.python-version }} 79 | 80 | - name: Install Poetry 81 | uses: snok/install-poetry@v1.3.3 82 | with: 83 | version: 1.4.0 84 | 85 | - name: Setup Poetry 86 | run: | 87 | poetry config virtualenvs.in-project true 88 | 89 | - name: Cache 90 | id: cache 91 | uses: actions/cache@v3.2.2 92 | with: 93 | path: '.venv' 94 | key: test-import-${{ hashFiles('poetry.lock') }} 95 | 96 | - name: Install Dependencies 97 | if: steps.cache.outputs.cache-hit != 'true' 98 | run: | 99 | poetry install --only main 100 | 101 | - name: Install Package 102 | run: | 103 | poetry install --only main 104 | 105 | - name: Test Import 106 | run: | 107 | poetry run python -c "import nnx" 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # project specific 132 | .vscode 133 | /tmp -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | 2 | repos: 3 | - repo: https://github.com/google/pyink 4 | rev: 23.3.0 5 | hooks: 6 | - id: pyink 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.12.0 9 | hooks: 10 | - id: isort 11 | args: ["--profile", "black"] 12 | 13 | 14 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true, 7 | "python.formatting.blackPath": "pyink", 8 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Cristian Garcia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Moved to [flax/nnx](https://github.com/google/flax/tree/main/flax/nnx)
2 | 
--------------------------------------------------------------------------------
/docs/blog.md:
--------------------------------------------------------------------------------
1 | ### Do we need another JAX NN library?
2 | 
3 | Hello! Today I want to talk to you about a new JAX library that I have been working on, but before I do that, I want to discuss a question: do we need another JAX NN library?
4 | 
5 | ### JAX Libraries
6 | 
7 | JAX NN libraries come in a wide variety, ranging from functional ones like Flax and Haiku to Pytree-based ones like Equinox.
--------------------------------------------------------------------------------
/docs/images/stateful-transforms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgarciae/nnx/3e4c750791ea39d26f5667ed26066fa6c13ff46b/docs/images/stateful-transforms.png
--------------------------------------------------------------------------------
/examples/00_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "No GPU/TPU found, falling back to CPU.
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 13 | ] 14 | }, 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "Linear(\n", 20 | " din=2,\n", 21 | " dout=2\n", 22 | ")\n", 23 | "[[0.63114893 1.2928092 ]\n", 24 | " [0.63114893 1.2928092 ]]\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "import numpy as np\n", 30 | "import nnx\n", 31 | "import jax\n", 32 | "import jax.numpy as jnp\n", 33 | "\n", 34 | "\n", 35 | "class Linear(nnx.Module):\n", 36 | "\n", 37 | " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", 38 | " # static attributes\n", 39 | " self.din = din\n", 40 | " self.dout = dout\n", 41 | " # variables\n", 42 | " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", 43 | " self.b = nnx.Param(jnp.zeros((dout,)))\n", 44 | " # other state\n", 45 | " self.jax_array = jnp.array(1)\n", 46 | " self.numpy_array = np.array(1)\n", 47 | "\n", 48 | " def __call__(self, x):\n", 49 | " return x @ self.w + self.b\n", 50 | "\n", 51 | "\n", 52 | "linear = Linear(2, 2, ctx=nnx.context(0))\n", 53 | "\n", 54 | "y = linear(jnp.ones((2, 2)))\n", 55 | "\n", 56 | "print(linear)\n", 57 | "print(y)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "State({\n", 70 | " 'b': Param(\n", 71 | " sharding=None,\n", 72 | " value=Array([0., 0.], dtype=float32)\n", 73 | " ),\n", 74 | " 'jax_array': Array(1, dtype=int32, weak_type=True),\n", 75 | " 'numpy_array': array(1),\n", 76 | " 'w': Param(\n", 77 | " sharding=None,\n", 78 | " value=Array([[0.31696808, 0.55285215],\n", 79 | " [0.31418085, 0.7399571 ]], dtype=float32)\n", 80 | " )\n", 81 | "})\n", 82 | "ModuleDef(\n", 83 | " type=Linear,\n", 84 | " index=0,\n", 85 | " submodules=(),\n", 86 | " static_fields=(('din', 2), ('dout', 2))\n", 87 | ")\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "state, moduledef = linear.partition()\n", 93 | "\n", 94 | "print(state)\n", 95 | "print(moduledef)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Linear(\n", 108 | " din=2,\n", 109 | " dout=2,\n", 110 | " submodule=Linear(...)\n", 111 | ")\n", 112 | "[[0.63114893 1.2928092 ]\n", 113 | " [0.63114893 1.2928092 ]]\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "class Linear(nnx.Module):\n", 119 | "\n", 120 | " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", 121 | " self.din = din\n", 122 | " self.dout = dout\n", 123 | " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", 124 | " self.b = nnx.Param(jnp.zeros((dout,)))\n", 125 | " # introduce a self-reference\n", 126 | " self.submodule = self\n", 127 | "\n", 128 | " def __call__(self, x):\n", 129 | " return x @ self.submodule.w + self.submodule.b\n", 130 | "\n", 131 | "\n", 132 | "linear = Linear(2, 2, ctx=nnx.context(0))\n", 133 | "\n", 134 | "y = linear(jnp.ones((2, 2)))\n", 135 | "\n", 136 | "print(linear)\n", 137 | "print(y)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "State({\n", 150 | " 'b': Param(\n", 151 | " sharding=None,\n", 152 | " value=Array([0., 0.], dtype=float32)\n", 153 | " ),\n", 154 | " 'w': Param(\n", 155 | " 
sharding=None,\n", 156 | " value=Array([[0.31696808, 0.55285215],\n", 157 | " [0.31418085, 0.7399571 ]], dtype=float32)\n", 158 | " )\n", 159 | "})\n", 160 | "ModuleDef(\n", 161 | " type=Linear,\n", 162 | " index=0,\n", 163 | " submodules=(\n", 164 | " ('submodule', 0)\n", 165 | " ),\n", 166 | " static_fields=(('din', 2), ('dout', 2))\n", 167 | ")\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "state, moduledef = linear.partition()\n", 173 | "\n", 174 | "print(state)\n", 175 | "print(moduledef)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 5, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "True" 187 | ] 188 | }, 189 | "execution_count": 5, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "linear2 = moduledef.merge(state)\n", 196 | "\n", 197 | "linear2.submodule is linear2" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 6, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "Linear(\n", 210 | " din=2,\n", 211 | " dout=2\n", 212 | ")\n", 213 | "[[0.63114893 1.2928092 ]\n", 214 | " [0.63114893 1.2928092 ]]\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "class Linear(nnx.Module):\n", 220 | "\n", 221 | " def __init__(self, din: int, dout: int, *, ctx: nnx.Context):\n", 222 | " # static attributes\n", 223 | " self.din = din\n", 224 | " self.dout = dout\n", 225 | " # variables\n", 226 | " self.w = nnx.Param(jax.random.uniform(ctx.make_rng(\"params\"), (din, dout)))\n", 227 | " self.b = nnx.Param(jnp.zeros((dout,)))\n", 228 | "\n", 229 | " def __call__(self, x):\n", 230 | " y = x @ self.w + self.b\n", 231 | " self.y = nnx.Intermediate(y)\n", 232 | " return y\n", 233 | "\n", 234 | "\n", 235 | "linear = Linear(2, 2, ctx=nnx.context(0))\n", 236 | "\n", 237 | "y = linear(jnp.ones((2, 2)))\n", 238 | "\n", 239 | "print(linear)\n", 240 | "print(y)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 8, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "State({\n", 253 | " 'y': Intermediate(\n", 254 | " sharding=None,\n", 255 | " value=Array([[0.63114893, 1.2928092 ],\n", 256 | " [0.63114893, 1.2928092 ]], dtype=float32)\n", 257 | " )\n", 258 | "})\n", 259 | "State({\n", 260 | " 'b': Param(\n", 261 | " sharding=None,\n", 262 | " value=Array([0., 0.], dtype=float32)\n", 263 | " ),\n", 264 | " 'w': Param(\n", 265 | " sharding=None,\n", 266 | " value=Array([[0.31696808, 0.55285215],\n", 267 | " [0.31418085, 0.7399571 ]], dtype=float32)\n", 268 | " )\n", 269 | "})\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "intermediates = linear.pop_state(nnx.Intermediate)\n", 275 | "state, moduledef = linear.partition()\n", 276 | "\n", 277 | "print(intermediates)\n", 278 | "print(state)" 279 | ] 280 | } 281 | ], 282 | "metadata": { 283 | "kernelspec": { 284 | "display_name": ".venv", 285 | "language": "python", 286 | "name": "python3" 287 | }, 288 | "language_info": { 289 | "codemirror_mode": { 290 | "name": "ipython", 291 | "version": 3 292 | }, 293 | "file_extension": ".py", 294 | "mimetype": "text/x-python", 295 | "name": "python", 296 | "nbconvert_exporter": "python", 297 | "pygments_lexer": "ipython3", 298 | "version": "3.9.16" 299 | }, 300 | "orig_nbformat": 4 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 2 304 | } 305 | 
-------------------------------------------------------------------------------- /examples/01_functional_api.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import jax 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import nnx 8 | 9 | X = np.linspace(0, 1, 100)[:, None] 10 | Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) 11 | 12 | 13 | def dataset(batch_size): 14 | while True: 15 | idx = np.random.choice(len(X), size=batch_size) 16 | yield X[idx], Y[idx] 17 | 18 | 19 | class Linear(nnx.Module): 20 | 21 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 22 | self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) 23 | self.b = nnx.Param(jnp.zeros((dout,))) 24 | 25 | def __call__(self, x): 26 | return x @ self.w + self.b 27 | 28 | 29 | class MLP(nnx.Module): 30 | 31 | def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): 32 | self.count = jnp.array(0) 33 | self.linear1 = Linear(din, dhidden, ctx=ctx) 34 | self.linear2 = Linear(dhidden, dout, ctx=ctx) 35 | 36 | def __call__(self, x): 37 | self.count += 1 38 | x = self.linear1(x) 39 | x = jax.nn.relu(x) 40 | x = self.linear2(x) 41 | return x 42 | 43 | 44 | (params, buffers), modeldef = MLP( 45 | din=1, dhidden=32, dout=1, ctx=nnx.context(0) 46 | ).partition(nnx.Param, ...) 47 | 48 | 49 | @jax.jit 50 | def train_step(params, buffers, batch): 51 | x, y = batch 52 | 53 | def loss_fn(params): 54 | y_pred, (updates, _) = modeldef.apply(params, buffers)(x) 55 | _state = updates.filter(nnx.buffers) 56 | loss = jnp.mean((y - y_pred) ** 2) 57 | return loss, _state 58 | 59 | grad, buffers = jax.grad(loss_fn, has_aux=True)(params) 60 | # |-------- sgd ---------| 61 | params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad) 62 | 63 | return params, buffers 64 | 65 | 66 | @jax.jit 67 | def test_step(params: nnx.State, buffers: nnx.State, batch): 68 | x, y = batch 69 | y_pred, _ = modeldef.apply(params, buffers)(x) 70 | loss = jnp.mean((y - y_pred) ** 2) 71 | return {"loss": loss} 72 | 73 | 74 | total_steps = 10_000 75 | for step, batch in enumerate(dataset(32)): 76 | params, buffers = train_step(params, buffers, batch) 77 | 78 | if step % 1000 == 0: 79 | logs = test_step(params, buffers, (X, Y)) 80 | print(f"step: {step}, loss: {logs['loss']}") 81 | 82 | if step >= total_steps - 1: 83 | break 84 | 85 | model = modeldef.merge(params, buffers) 86 | print("times called:", model.count) 87 | 88 | y_pred = model(X) 89 | 90 | plt.scatter(X, Y, color="blue") 91 | plt.plot(X, y_pred, color="black") 92 | plt.show() 93 | -------------------------------------------------------------------------------- /examples/02_lifted_transforms.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import jax 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import nnx 8 | 9 | X = np.linspace(0, 1, 100)[:, None] 10 | Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) 11 | 12 | 13 | def dataset(batch_size): 14 | while True: 15 | idx = np.random.choice(len(X), size=batch_size) 16 | yield X[idx], Y[idx] 17 | 18 | 19 | class Linear(nnx.Module): 20 | 21 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 22 | self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) 23 | self.b = nnx.Param(jnp.zeros((dout,))) 24 | 25 | def __call__(self, x): 26 | return x @ self.w + self.b 27 | 28 | 29 | class 
Count(nnx.Variable): 30 | pass 31 | 32 | 33 | class MLP(nnx.Module): 34 | 35 | def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): 36 | self.count = Count(jnp.array(0)) 37 | self.linear1 = Linear(din, dhidden, ctx=ctx) 38 | self.linear2 = Linear(dhidden, dout, ctx=ctx) 39 | 40 | def __call__(self, x): 41 | self.count += 1 42 | x = self.linear1(x) 43 | x = jax.nn.relu(x) 44 | x = self.linear2(x) 45 | return x 46 | 47 | 48 | model = MLP(din=1, dhidden=32, dout=1, ctx=nnx.context(0)) 49 | 50 | 51 | @nnx.jit 52 | def train_step(model: MLP, batch): 53 | x, y = batch 54 | 55 | def loss_fn(model: MLP): 56 | y_pred = model(x) 57 | return jnp.mean((y - y_pred) ** 2) 58 | 59 | # |--default--| 60 | grad: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) 61 | # sdg update 62 | model.update_state( 63 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grad) 64 | ) 65 | 66 | # no return!!! 67 | 68 | 69 | @nnx.jit 70 | def test_step(model: MLP, batch): 71 | x, y = batch 72 | y_pred = model(x) 73 | loss = jnp.mean((y - y_pred) ** 2) 74 | return {"loss": loss} 75 | 76 | 77 | total_steps = 10_000 78 | for step, batch in enumerate(dataset(32)): 79 | train_step(model, batch) 80 | 81 | if step % 1000 == 0: 82 | logs = test_step(model, (X, Y)) 83 | print(f"step: {step}, loss: {logs['loss']}") 84 | 85 | if step >= total_steps - 1: 86 | break 87 | 88 | print("times called:", model.count) 89 | 90 | y_pred = model(X) 91 | 92 | plt.scatter(X, Y, color="blue") 93 | plt.plot(X, y_pred, color="black") 94 | plt.show() 95 | -------------------------------------------------------------------------------- /examples/03_train_state.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import jax 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import optax 7 | from flax.training import train_state 8 | 9 | import nnx 10 | 11 | X = np.linspace(0, 1, 100)[:, None] 12 | Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) 13 | 14 | 15 | def dataset(batch_size): 16 | while True: 17 | idx = np.random.choice(len(X), size=batch_size) 18 | yield X[idx], Y[idx] 19 | 20 | 21 | class Linear(nnx.Module): 22 | 23 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 24 | self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) 25 | self.b = nnx.Param(jnp.zeros((dout,))) 26 | 27 | def __call__(self, x): 28 | return x @ self.w + self.b 29 | 30 | 31 | class MLP(nnx.Module): 32 | 33 | def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): 34 | self.count = jnp.array(0) 35 | self.linear1 = Linear(din, dhidden, ctx=ctx) 36 | self.linear2 = Linear(dhidden, dout, ctx=ctx) 37 | 38 | def __call__(self, x): 39 | self.count += 1 40 | x = self.linear1(x) 41 | x = jax.nn.relu(x) 42 | x = self.linear2(x) 43 | return x 44 | 45 | 46 | (params, buffers), moduledef = MLP( 47 | din=1, dhidden=32, dout=1, ctx=nnx.context(0) 48 | ).partition(nnx.Param, ...) 
49 | 50 | state = nnx.TrainState( 51 | moduledef, 52 | params=params, 53 | tx=optax.sgd(0.1), 54 | buffers=buffers, 55 | ) 56 | del params, buffers 57 | 58 | 59 | @jax.jit 60 | def train_step(state: nnx.TrainState, batch): 61 | x, y = batch 62 | 63 | def loss_fn(params): 64 | y_pred, (updates, _) = state.apply(params, "buffers")(x) 65 | buffers = updates.filter(nnx.buffers) 66 | loss = jnp.mean((y - y_pred) ** 2) 67 | return loss, buffers 68 | 69 | grads, buffers = jax.grad(loss_fn, has_aux=True)(state.params) 70 | # sdg update 71 | state = state.apply_gradients(grads=grads, buffers=buffers) 72 | 73 | return state 74 | 75 | 76 | @jax.jit 77 | def test_step(state: nnx.TrainState, batch): 78 | x, y = batch 79 | y_pred, _ = state.apply("params", "buffers")(x) 80 | loss = jnp.mean((y - y_pred) ** 2) 81 | return {"loss": loss} 82 | 83 | 84 | total_steps = 10_000 85 | for step, batch in enumerate(dataset(32)): 86 | state = train_step(state, batch) 87 | 88 | if step % 1000 == 0: 89 | logs = test_step(state, (X, Y)) 90 | print(f"step: {step}, loss: {logs['loss']}") 91 | 92 | if step >= total_steps - 1: 93 | break 94 | 95 | model = moduledef.merge(state.params, state.buffers) 96 | print("times called:", model.count) 97 | 98 | y_pred = model(X) 99 | 100 | plt.scatter(X, Y, color="blue") 101 | plt.plot(X, y_pred, color="black") 102 | plt.show() 103 | -------------------------------------------------------------------------------- /examples/04_pure.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import jax 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import nnx 8 | 9 | X = np.linspace(0, 1, 100)[:, None] 10 | Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) 11 | 12 | 13 | def dataset(batch_size): 14 | while True: 15 | idx = np.random.choice(len(X), size=batch_size) 16 | yield X[idx], Y[idx] 17 | 18 | 19 | class Linear(nnx.Module): 20 | 21 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 22 | self.w = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) 23 | self.b = nnx.Param(jnp.zeros((dout,))) 24 | 25 | def __call__(self, x): 26 | return x @ self.w + self.b 27 | 28 | 29 | class Count(nnx.Variable): 30 | pass 31 | 32 | 33 | class MLP(nnx.Module): 34 | 35 | def __init__(self, din, dhidden, dout, *, ctx: nnx.Context): 36 | self.count = Count(jnp.array(0)) 37 | self.linear1 = Linear(din, dhidden, ctx=ctx) 38 | self.linear2 = Linear(dhidden, dout, ctx=ctx) 39 | 40 | def __call__(self, x) -> jax.Array: 41 | self.count += 1 42 | x = self.linear1(x) 43 | x = jax.nn.relu(x) 44 | x = self.linear2(x) 45 | return x 46 | 47 | 48 | pure_model = MLP(din=1, dhidden=32, dout=1, ctx=nnx.context(0)).partition() 49 | 50 | 51 | @jax.jit 52 | def train_step(pure_model: nnx.PureModule[MLP], batch): 53 | x, y = batch 54 | model = pure_model.merge() 55 | 56 | def loss_fn(model: MLP): 57 | y_pred = model(x) 58 | return jnp.mean((y - y_pred) ** 2) 59 | 60 | grad: nnx.State = nnx.grad(loss_fn)(model) 61 | # sdg update 62 | model.update_state( 63 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grad) 64 | ) 65 | 66 | return model.partition() 67 | 68 | 69 | @jax.jit 70 | def test_step(pure_model: nnx.PureModule[MLP], batch): 71 | x, y = batch 72 | y_pred = pure_model.call(x) 73 | loss = jnp.mean((y - y_pred) ** 2) 74 | return {"loss": loss} 75 | 76 | 77 | total_steps = 10_000 78 | for step, batch in enumerate(dataset(32)): 79 | pure_model = train_step(pure_model, batch) 80 
| 81 | if step % 1000 == 0: 82 | logs = test_step(pure_model, (X, Y)) 83 | print(f"step: {step}, loss: {logs['loss']}") 84 | 85 | if step >= total_steps - 1: 86 | break 87 | 88 | model = pure_model.merge() 89 | print("times called:", model.count) 90 | 91 | y_pred = model(X) 92 | 93 | plt.scatter(X, Y, color="blue") 94 | plt.plot(X, y_pred, color="black") 95 | plt.show() 96 | -------------------------------------------------------------------------------- /examples/05_vae.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import typing as tp 3 | from functools import partial 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import optax 10 | from datasets import load_dataset 11 | 12 | import nnx 13 | 14 | np.random.seed(42) 15 | latent_size = 32 16 | image_shape: tp.Sequence[int] = (28, 28) 17 | steps_per_epoch: int = 200 18 | batch_size: int = 64 19 | epochs: int = 20 20 | 21 | 22 | dataset = load_dataset("mnist") 23 | X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8) 24 | X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8) 25 | # Now binarize data 26 | X_train = (X_train > 0).astype(jnp.float32) 27 | X_test = (X_test > 0).astype(jnp.float32) 28 | 29 | print("X_train:", X_train.shape, X_train.dtype) 30 | print("X_test:", X_test.shape, X_test.dtype) 31 | 32 | 33 | class Loss(nnx.Variable): 34 | pass 35 | 36 | 37 | # %% 38 | class Encoder(nnx.Module): 39 | 40 | def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context): 41 | self.linear1 = nnx.Linear(din, dmid, ctx=ctx) 42 | self.linear_mean = nnx.Linear(dmid, dout, ctx=ctx) 43 | self.linear_std = nnx.Linear(dmid, dout, ctx=ctx) 44 | 45 | def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: 46 | x = x.reshape((x.shape[0], -1)) # flatten 47 | x = self.linear1(x) 48 | x = jax.nn.relu(x) 49 | 50 | mean = self.linear_mean(x) 51 | std = jnp.exp(self.linear_std(x)) 52 | 53 | self.kl_loss = Loss( 54 | jnp.mean( 55 | 0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1) 56 | ) 57 | ) 58 | key = ctx.make_rng("noise") 59 | z = mean + std * jax.random.normal(key, mean.shape) 60 | return z 61 | 62 | 63 | class Decoder(nnx.Module): 64 | 65 | def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context): 66 | self.linear1 = nnx.Linear(din, dmid, ctx=ctx) 67 | self.linear2 = nnx.Linear(dmid, dout, ctx=ctx) 68 | 69 | def __call__(self, z: jax.Array) -> jax.Array: 70 | z = self.linear1(z) 71 | z = jax.nn.relu(z) 72 | logits = self.linear2(z) 73 | return logits 74 | 75 | 76 | class VAE(nnx.Module): 77 | 78 | def __init__( 79 | self, 80 | din: int, 81 | hidden_size: int, 82 | latent_size: int, 83 | output_shape: tp.Sequence[int], 84 | *, 85 | ctx: nnx.Context, 86 | ): 87 | self.output_shape = output_shape 88 | self.encoder = Encoder(din, hidden_size, latent_size, ctx=ctx) 89 | self.decoder = Decoder( 90 | latent_size, hidden_size, int(np.prod(output_shape)), ctx=ctx 91 | ) 92 | 93 | def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: 94 | z = self.encoder(x, ctx=ctx) 95 | logits = self.decoder(z) 96 | logits = jnp.reshape(logits, (-1, *self.output_shape)) 97 | return logits 98 | 99 | def generate(self, z): 100 | logits = self.decoder(z) 101 | logits = jnp.reshape(logits, (-1, *self.output_shape)) 102 | return nnx.sigmoid(logits) 103 | 104 | 105 | params, moduledef = VAE( 106 | din=int(np.prod(image_shape)), 107 | hidden_size=256, 108 | 
latent_size=latent_size, 109 | output_shape=image_shape, 110 | ctx=nnx.context(0), 111 | ).partition(nnx.Param) 112 | 113 | state = nnx.TrainState( 114 | moduledef, 115 | params=params, 116 | tx=optax.adam(1e-3), 117 | ) 118 | 119 | 120 | # %% 121 | @jax.jit 122 | def train_step(state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array): 123 | def loss_fn(params: nnx.State): 124 | ctx = nnx.context(noise=jax.random.fold_in(key, state.step)) 125 | logits, (updates, _) = state.apply(params)(x, ctx=ctx) 126 | 127 | losses = updates.filter(Loss) 128 | kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0) 129 | reconstruction_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, x)) 130 | 131 | loss = reconstruction_loss + 0.1 * kl_loss 132 | return loss, loss 133 | 134 | grad_fn = jax.grad(loss_fn, has_aux=True) 135 | grads, loss = grad_fn(state.params) 136 | state.apply_gradients(grads=grads) 137 | 138 | return state, loss 139 | 140 | 141 | @partial(jax.jit, donate_argnums=(0,)) 142 | def forward(state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array) -> jax.Array: 143 | ctx = nnx.context(noise=key) 144 | y_pred = state.apply("params")(x, ctx=ctx)[0] 145 | return jax.nn.sigmoid(y_pred) 146 | 147 | 148 | @jax.jit 149 | def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: 150 | return state.apply("params").generate(z)[0] 151 | 152 | 153 | # %% 154 | key = jax.random.PRNGKey(0) 155 | 156 | for epoch in range(epochs): 157 | losses = [] 158 | for step in range(steps_per_epoch): 159 | idxs = np.random.randint(0, len(X_train), size=(batch_size,)) 160 | x_batch = X_train[idxs] 161 | 162 | state, loss = train_step(state, x_batch, key) 163 | losses.append(np.asarray(loss)) 164 | 165 | print(f"Epoch {epoch} loss: {np.mean(losses)}") 166 | 167 | exit() 168 | # %% 169 | # get random samples 170 | idxs = np.random.randint(0, len(X_test), size=(5,)) 171 | x_sample = X_test[idxs] 172 | 173 | # get predictions 174 | y_pred = forward(state, x_sample, key) 175 | 176 | # plot reconstruction 177 | figure = plt.figure(figsize=(3 * 5, 3 * 2)) 178 | plt.title("Reconstruction Samples") 179 | for i in range(5): 180 | plt.subplot(2, 5, i + 1) 181 | plt.imshow(x_sample[i], cmap="gray") 182 | plt.subplot(2, 5, 5 + i + 1) 183 | plt.imshow(y_pred[i], cmap="gray") 184 | # # tbwriter.add_figure("VAE Example", figure, epochs) 185 | 186 | plt.show() 187 | 188 | # %% 189 | # plot generative samples 190 | z_samples = np.random.normal(scale=1.5, size=(12, latent_size)) 191 | samples = sample(state, z_samples) 192 | 193 | figure = plt.figure(figsize=(3 * 5, 3 * 2)) 194 | plt.title("Generative Samples") 195 | for i in range(5): 196 | plt.subplot(2, 5, 2 * i + 1) 197 | plt.imshow(samples[i], cmap="gray") 198 | plt.subplot(2, 5, 2 * i + 2) 199 | plt.imshow(samples[i + 1], cmap="gray") 200 | 201 | plt.show() 202 | 203 | # %% 204 | -------------------------------------------------------------------------------- /examples/06_scan_over_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import nnx 7 | 8 | 9 | class Block(nnx.Module): 10 | 11 | def __init__(self, dim: int, *, ctx: nnx.Context): 12 | self.linear = nnx.Linear(dim, dim, ctx=ctx) 13 | self.dropout = nnx.Dropout(0.5) 14 | 15 | def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: 16 | x = self.linear(x) 17 | x = self.dropout(x, ctx=ctx) 18 | x = jax.nn.gelu(x) 19 | return x 20 | 21 | 22 | class ScanMLP(nnx.Module): 23 | 
""" 24 | An MLP that uses `vmap` during `__init__` to create a Block instance 25 | with an additional `layer` axis, and `scan` during `__call__` to apply 26 | the sequence of layers iteratively over the input / output `x`. 27 | """ 28 | 29 | def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context): 30 | self.n_layers = n_layers 31 | # partition Context and split the `params` key 32 | keys, ctxdef = ctx.partition() 33 | params_key = jax.random.split(keys["params"], n_layers) 34 | 35 | def create_block(params_key): 36 | # merge back Context using the sliced `params` key 37 | ctx = ctxdef.merge({"params": params_key}) 38 | # create Block instance and return its partition 39 | return Block(dim, ctx=ctx).partition() 40 | 41 | # call vmap over create_block, passing the split `params` key 42 | # and immediately merge to get a Block instance 43 | self.layers = jax.vmap(create_block)(params_key).merge() 44 | 45 | def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array: 46 | # partition Context and split the `dropout` key 47 | keys, ctxdef = ctx.partition() 48 | dropout_key = jax.random.split(keys["dropout"], self.n_layers) 49 | # partition Module to get params 50 | params, moduledef = self.layers.partition(nnx.Param) 51 | 52 | def scan_fn( 53 | x: jax.Array, inputs: Tuple[nnx.State, jax.Array] 54 | ) -> Tuple[jax.Array, nnx.State]: 55 | params, dropout_key = inputs 56 | # merge back Module and Context 57 | ctx = ctxdef.merge({"dropout": dropout_key}) 58 | module = moduledef.merge(params) 59 | # forward pass 60 | x = module(x, ctx=ctx) 61 | # partition state and return 62 | params, _ = module.partition(nnx.Param) 63 | return x, params 64 | 65 | # call scan passing x as the carry, and params + dropout_key as the input 66 | x, params = jax.lax.scan(scan_fn, x, (params, dropout_key)) 67 | # update layers state and return 68 | self.layers.update_state(params) 69 | return x 70 | 71 | 72 | model = ScanMLP(10, n_layers=5, ctx=nnx.context(0)) 73 | 74 | x = jnp.ones((3, 10)) 75 | y = model(x, ctx=nnx.context(dropout=1, flags=dict(deterministic=False))) 76 | 77 | print(jax.tree_map(jnp.shape, model.get_state())) 78 | print(y.shape) 79 | -------------------------------------------------------------------------------- /examples/07_transformer.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as tp 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from jax.sharding import PartitionSpec as P 8 | 9 | import nnx 10 | 11 | ShardSpec = tp.Union[str, tp.Tuple[str, ...], None] 12 | 13 | 14 | # Sharding 15 | @dataclasses.dataclass 16 | class Sharding: 17 | batch: ShardSpec = "data" 18 | sequence: ShardSpec = None 19 | layers: ShardSpec = None 20 | vocab: ShardSpec = "model" 21 | embed: ShardSpec = None 22 | heads: ShardSpec = "model" 23 | depth: ShardSpec = None 24 | hidden: ShardSpec = "model" 25 | 26 | 27 | # Config 28 | @dataclasses.dataclass 29 | class Config: 30 | # mode 31 | decode: bool = False 32 | # shapes 33 | batch: int = 16 34 | layers: int = 2 35 | vocab: int = 1024 36 | embed: int = 64 37 | heads: int = 12 38 | depth: int = 64 39 | hidden: int = 256 40 | max_length: int = 256 41 | # dtypes 42 | param_dtype: tp.Any = jnp.float32 43 | dtype: tp.Any = jnp.float32 44 | # sharding 45 | sharding: Sharding = Sharding() 46 | scanned: bool = False 47 | # layer params 48 | epsilon: float = 1e-6 49 | dropout_rate: float = 0.0 50 | rp_num_buckets: int = 32 51 | rp_max_distance: int = 128 52 | 53 | 54 
| cfg = Config() 55 | 56 | 57 | def nd_dense_init(scale, mode, distribution): 58 | """Initializer with in_axis, out_axis set at call time.""" 59 | 60 | def init_fn(key, shape, dtype, in_axis, out_axis) -> jax.Array: 61 | fn = jax.nn.initializers.variance_scaling( 62 | scale, mode, distribution, in_axis, out_axis 63 | ) 64 | return fn(key, shape, dtype) 65 | 66 | return init_fn 67 | 68 | 69 | dense_init = nd_dense_init(1.0, "fan_in", "truncated_normal") 70 | embed_init = nd_dense_init(1.0, "fan_in", "normal") 71 | 72 | 73 | def make_attention_mask( 74 | query_input: tp.Any, 75 | key_input: tp.Any, 76 | pairwise_fn: tp.Callable = jnp.multiply, 77 | dtype: tp.Any = jnp.float32, 78 | ): 79 | mask = pairwise_fn( 80 | jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) 81 | ) 82 | return jnp.expand_dims(mask, axis=-3).astype(dtype) 83 | 84 | 85 | def make_causal_mask(x, dtype=jnp.float32): 86 | idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) 87 | return make_attention_mask(idxs, idxs, jnp.greater_equal, dtype=dtype) 88 | 89 | 90 | # padding mask 91 | # make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype) 92 | # packing mask 93 | # make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype) 94 | 95 | 96 | def sine_table(features, length, min_timescale=1.0, max_timescale=10000.0): 97 | fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features 98 | timescale = min_timescale * (max_timescale / min_timescale) ** fraction 99 | rotational_frequency = 1.0 / timescale 100 | # Must use high precision einsum here, bfloat16 rounding is catastrophic. 101 | sinusoid_inp = jnp.einsum( 102 | "i,j->ij", 103 | jnp.arange(length), 104 | rotational_frequency, 105 | precision=jax.lax.Precision.HIGHEST, 106 | ) 107 | sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1) 108 | return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) 109 | 110 | 111 | def rotate_half(x): 112 | x1, x2 = jnp.split(x, 2, axis=-1) 113 | x = jnp.concatenate([-x2, x1], axis=-1) 114 | return x 115 | 116 | 117 | def apply_rotary_embedding(q, k, cos, sin, index=None): 118 | """Helper function to apply Rotary Embeddings.""" 119 | batch, qlen, qheads, d = q.shape 120 | kbatch, klen, kheads, kd = k.shape 121 | if index is not None: 122 | qcos = jax.lax.broadcast_in_dim(cos[index, :], (batch, qlen, qheads, d), (3,)) 123 | qsin = jax.lax.broadcast_in_dim(sin[index, :], (batch, qlen, qheads, d), (3,)) 124 | else: 125 | qcos = jax.lax.broadcast_in_dim(cos[:qlen, :], (batch, qlen, qheads, d), (1, 3)) 126 | qsin = jax.lax.broadcast_in_dim(sin[:qlen, :], (batch, qlen, qheads, d), (1, 3)) 127 | kcos = jax.lax.broadcast_in_dim(cos[:klen, :], (batch, klen, kheads, d), (1, 3)) 128 | ksin = jax.lax.broadcast_in_dim(sin[:klen, :], (batch, klen, kheads, d), (1, 3)) 129 | out_q = (q * qcos) + (rotate_half(q) * qsin) 130 | out_k = (k * kcos) + (rotate_half(k) * ksin) 131 | return out_q, out_k 132 | 133 | 134 | def rms_norm(cfg, scale, x): 135 | x = jnp.asarray(x, jnp.float32) 136 | mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) 137 | y = jnp.asarray(x * jax.lax.rsqrt(mean2 + cfg.epsilon), cfg.dtype) 138 | return y * jnp.asarray(scale, cfg.dtype) 139 | 140 | 141 | def dropout(cfg: Config, x, broadcast_dims=(-2,), *, ctx: nnx.Context): 142 | if cfg.dropout_rate == 0.0: 143 | return x 144 | broadcast_shape = list(x.shape) 145 | for dim in broadcast_dims: 146 | broadcast_shape[dim] = 1 147 | keep_rate = 1.0 - cfg.dropout_rate 
148 |   key = ctx.make_rng("dropout")
149 |   mask = jax.random.bernoulli(key, p=keep_rate, shape=broadcast_shape)
150 |   return jax.lax.select(
151 |       jnp.broadcast_to(mask, x.shape), x / keep_rate, jnp.zeros_like(x)
152 |   )
153 | 
154 | 
155 | class Attention(nnx.Module):
156 | 
157 |   def __init__(self, cfg: Config, *, ctx: nnx.Context):
158 |     sharding = cfg.sharding
159 | 
160 |     key = ctx.make_rng("params")
161 |     self.WQ = nnx.Param(
162 |       dense_init(key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)),
163 |       P(sharding.embed, sharding.heads, sharding.depth),
164 |     )
165 |     key = ctx.make_rng("params")
166 |     self.WK = nnx.Param(
167 |       dense_init(key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)),
168 |       P(sharding.embed, sharding.heads, sharding.depth),
169 |     )
170 |     key = ctx.make_rng("params")
171 |     self.WV = nnx.Param(
172 |       dense_init(key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2)),
173 |       P(sharding.embed, sharding.heads, sharding.depth),
174 |     )
175 |     key = ctx.make_rng("params")
176 |     self.WO = nnx.Param(
177 |       dense_init(key, (cfg.heads, cfg.depth, cfg.embed), cfg.param_dtype, (0, 1), 2),
178 |       P(sharding.heads, sharding.depth, sharding.embed),
179 |     )
180 |     # cache
181 |     self.index = nnx.variable("cache", jnp.array(0, dtype=jnp.int32), P())
182 |     self.key = nnx.variable(
183 |       "cache",
184 |       jnp.zeros(
185 |         (cfg.batch, cfg.heads, cfg.depth, cfg.max_length),
186 |         jnp.bfloat16,
187 |       ),
188 |       P(sharding.batch, sharding.heads, sharding.depth, None),
189 |     )
190 |     self.value = nnx.variable(
191 |       "cache",
192 |       jnp.zeros(
193 |         (cfg.batch, cfg.heads, cfg.depth, cfg.max_length),
194 |         jnp.bfloat16,
195 |       ),
196 |       P(sharding.batch, sharding.heads, sharding.depth, None),
197 |     )
198 | 
199 |   # We combine the cache and params into "vs", but it would be no harder at all
200 |   # to thread through a separate "cache" argument storing cache entries.
201 |   def __call__(self, cfg: Config, x_q, x_kv, mask=None, *, ctx: nnx.Context):
202 |     q = jnp.einsum("bse,enh->bsnh", x_q, self.WQ.astype(cfg.dtype)).astype(jnp.float32)
203 |     k = jnp.einsum("bte,enh->btnh", x_kv, self.WK.astype(cfg.dtype)).astype(jnp.float32)
204 |     v = jnp.einsum("bte,enh->btnh", x_kv, self.WV.astype(cfg.dtype))
205 | 
206 |     index = None
207 |     if cfg.decode:
208 |       index = self.index
209 |       one_hot_indices = jax.nn.one_hot(self.index, cfg.max_length, dtype=cfg.dtype)
210 |       self.key = self.key + jnp.moveaxis(k, -3, -1) * one_hot_indices
211 |       self.value = self.value + jnp.moveaxis(v, -3, -1) * one_hot_indices
212 |       k = jnp.moveaxis(self.key, -1, -3)
213 |       v = jnp.moveaxis(self.value, -1, -3)
214 |       cache_mask = jnp.broadcast_to(
215 |         jnp.arange(cfg.max_length) <= self.index,
216 |         (cfg.batch, 1, 1, cfg.max_length),
217 |       )
218 |       mask = jnp.logical_and(cache_mask if mask is None else mask, cache_mask).astype(
219 |         cfg.dtype
220 |       )
221 |       self.index = self.index + 1
222 | 
223 |     attention_bias = 0.0
224 |     if mask is None: # Hack in lieu of general mask routing.
225 | mask = make_causal_mask(x, jnp.float32) 226 | if mask is not None: 227 | attention_bias = jax.lax.select( 228 | mask > 0, 229 | jnp.full(mask.shape, 0.0, cfg.dtype), 230 | jnp.full(mask.shape, -1e10, cfg.dtype), 231 | ) 232 | 233 | sin, cos = sine_table(q.shape[-1], max(q.shape[1], k.shape[1])) 234 | q, k = apply_rotary_embedding(q, k, cos, sin, index=index) 235 | 236 | l = jnp.einsum("bsnh,btnh->bnst", q, k) / np.sqrt(cfg.depth) + attention_bias 237 | s = jax.nn.softmax(l).astype(cfg.dtype) 238 | s = dropout(cfg, s, ctx=ctx) 239 | a = jnp.einsum("bnst,btnh->bsnh", s, v) 240 | o = jnp.einsum("bsnh,nhe->bse", a, self.WO.astype(cfg.dtype)) 241 | 242 | return o 243 | 244 | 245 | class MLP(nnx.Module): 246 | 247 | def __init__(self, cfg: Config, *, ctx: nnx.Context): 248 | sharding = cfg.sharding 249 | self.Win1 = nnx.Param( 250 | dense_init( 251 | ctx.make_rng("params"), (cfg.embed, cfg.hidden), cfg.param_dtype, 0, 1 252 | ), 253 | P(sharding.embed, sharding.hidden), 254 | ) 255 | self.Win2 = nnx.Param( 256 | dense_init( 257 | ctx.make_rng("params"), (cfg.embed, cfg.hidden), cfg.param_dtype, 0, 1 258 | ), 259 | P(sharding.embed, sharding.hidden), 260 | ) 261 | self.Wout = nnx.Param( 262 | dense_init( 263 | ctx.make_rng("params"), (cfg.hidden, cfg.embed), cfg.param_dtype, 0, 1 264 | ), 265 | P(sharding.hidden, sharding.embed), 266 | ) 267 | 268 | def __call__(self, cfg: Config, x, *, ctx: nnx.Context): 269 | h1 = jnp.einsum("bse,eh->bsh", x, self.Win1.astype(cfg.dtype)) 270 | h2 = jnp.einsum("bse,eh->bsh", x, self.Win2.astype(cfg.dtype)) 271 | h = jax.nn.gelu(h1) * h2 272 | h = dropout(cfg, h, ctx=ctx) 273 | o = jnp.einsum("bsh,he->bse", h, self.Wout.astype(cfg.dtype)) 274 | return o 275 | 276 | 277 | class DecoderBlock(nnx.Module): 278 | 279 | def __init__(self, cfg: Config, *, ctx: nnx.Context): 280 | sharding = cfg.sharding 281 | self.attn = Attention(cfg, ctx=ctx) 282 | self.mlp = MLP(cfg, ctx=ctx) 283 | self.scale1 = nnx.Param(jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed)) 284 | self.scale2 = nnx.Param(jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed)) 285 | 286 | def __call__(self, cfg: Config, input, *, ctx: nnx.Context): 287 | x = rms_norm(cfg, self.scale1, input) 288 | x = self.attn(cfg, x, x, mask=None, ctx=ctx) 289 | x = dropout(cfg, x, ctx=ctx) 290 | x = x + input 291 | y = rms_norm(cfg, self.scale2, x) 292 | y = self.mlp(cfg, y, ctx=ctx) 293 | y = dropout(cfg, y, ctx=ctx) 294 | return y + x 295 | 296 | 297 | class Decoder(nnx.Module): 298 | 299 | def __init__(self, cfg: Config, *, ctx: nnx.Context): 300 | sharding = cfg.sharding 301 | self.embed = nnx.Param( 302 | embed_init( 303 | ctx.make_rng("params"), (cfg.vocab, cfg.embed), cfg.param_dtype, 1, 0 304 | ), 305 | P(sharding.vocab, sharding.embed), 306 | ) 307 | self.unembed = nnx.Param( 308 | dense_init(ctx.make_rng("params"), (cfg.embed, cfg.vocab), jnp.float32, 0, 1), 309 | P(sharding.embed, sharding.vocab), 310 | ) 311 | self.scale1 = nnx.Param(jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed)) 312 | 313 | if cfg.scanned: 314 | self.layers = jax.vmap( 315 | lambda key: DecoderBlock(cfg, ctx=nnx.context(key)).partition() 316 | )(jax.random.split(ctx.make_rng("params"), cfg.layers)).merge() 317 | else: 318 | self.layers = nnx.Sequence(DecoderBlock(cfg, ctx=ctx) for _ in range(cfg.layers)) 319 | 320 | def __call__(self, cfg: Config, x, *, ctx: nnx.Context): 321 | # TODO: handle right-shifting for training: here or in train loop. 322 | # TODO: handle general mask routing. 
323 | x = self.embed.astype(cfg.dtype)[x] 324 | 325 | if cfg.scanned: 326 | assert isinstance(self.layers, DecoderBlock) 327 | 328 | state, moduledef = self.layers.partition() 329 | rngs, ctxdef = ctx.partition() 330 | dropout_key = jax.random.split(rngs["dropout"], cfg.layers) 331 | 332 | def scan_fn(x, s: tp.Tuple[jax.random.KeyArray, nnx.State]): 333 | dropout_key, state = s 334 | ctx = ctxdef.merge({"dropout": dropout_key}) 335 | y, (state, _) = moduledef.apply(state)(cfg, x, ctx=ctx) 336 | return y, state 337 | 338 | x, state = jax.lax.scan( 339 | scan_fn, 340 | x, 341 | (dropout_key, state), 342 | ) 343 | self.layers.update_state(state) 344 | else: 345 | assert isinstance(self.layers, nnx.Sequence) 346 | for decoder_block in self.layers: 347 | x = decoder_block(cfg, x, ctx=ctx) 348 | 349 | x = jnp.einsum("bse,ev->bsv", x, self.unembed) 350 | return x 351 | -------------------------------------------------------------------------------- /examples/08_save_load_checkpoints.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import orbax.checkpoint as orbax 6 | 7 | import nnx 8 | 9 | 10 | class MLP(nnx.Module): 11 | 12 | def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context): 13 | self.dense1 = nnx.Linear(din, dmid, ctx=ctx) 14 | self.dense2 = nnx.Linear(dmid, dout, ctx=ctx) 15 | 16 | def __call__(self, x: jax.Array) -> jax.Array: 17 | x = self.dense1(x) 18 | x = jax.nn.relu(x) 19 | x = self.dense2(x) 20 | return x 21 | 22 | 23 | def create_model(seed: int): 24 | return MLP(10, 20, 30, ctx=nnx.context(seed)) 25 | 26 | 27 | def create_and_save(seed: int, path: str): 28 | model = create_model(seed) 29 | state = model.get_state() 30 | # Save the parameters 31 | checkpointer = orbax.PyTreeCheckpointer() 32 | checkpointer.save(f"{path}/state", state) 33 | 34 | 35 | def load_model(path: str) -> MLP: 36 | # create that model with abstract shapes 37 | state, moduledef = jax.eval_shape(lambda: create_model(0).partition()) 38 | # Load the parameters 39 | checkpointer = orbax.PyTreeCheckpointer() 40 | state = checkpointer.restore(f"{path}/state", item=state) 41 | # Merge the parameters into the model 42 | model = moduledef.merge(state) 43 | return model 44 | 45 | 46 | with TemporaryDirectory() as tmpdir: 47 | # create a checkpoint 48 | create_and_save(42, tmpdir) 49 | # load model from checkpoint 50 | model = load_model(tmpdir) 51 | # run the model 52 | y = model(jnp.ones((1, 10))) 53 | print(model) 54 | print(y) 55 | -------------------------------------------------------------------------------- /examples/09_parameter_surgery.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | 5 | import nnx 6 | 7 | 8 | # lets pretend this function loads a pretrained model from a checkpoint 9 | def load_backbone(): 10 | return nnx.Linear(784, 128, ctx=nnx.context(0)) 11 | 12 | 13 | # create a simple linear classifier using a pretrained backbone 14 | class Classifier(nnx.Module): 15 | 16 | def __init__(self, backbone: Callable[[jax.Array], jax.Array], *, ctx: nnx.Context): 17 | self.backbone = backbone 18 | self.head = nnx.Linear(128, 10, ctx=ctx) 19 | 20 | def __call__(self, x): 21 | x = self.backbone(x) 22 | x = nnx.relu(x) 23 | x = self.head(x) 24 | return x 25 | 26 | 27 | backbone = load_backbone() 28 | 29 | # create the classifier using the pretrained backbone, here we are 
technically 30 | # doing "parameter surgery"; unlike Haiku/Flax, where you must manually 31 | # construct the parameter structure, NNX does this automatically 32 | model = Classifier(backbone, ctx=nnx.context(42)) 33 | 34 | # create a filter to select all the parameters that are not part of the 35 | # backbone, i.e. the classifier parameters 36 | is_trainable = nnx.All(nnx.Param, lambda path, node: not path.startswith("backbone")) 37 | 38 | # partition the parameters into trainable and non-trainable parameters 39 | (trainable_params, non_trainable), moduledef = model.partition(is_trainable, ...) 40 | 41 | print("trainable_params =", jax.tree_map(jax.numpy.shape, trainable_params)) 42 | print("non_trainable = ", jax.tree_map(jax.numpy.shape, non_trainable)) 43 | -------------------------------------------------------------------------------- /ideas/nnx_example.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | import nnx 8 | 9 | 10 | class Linear(nnx.Module): 11 | kernel: jax.Array = nnx.Param() 12 | bias: jax.Array = nnx.Param() 13 | 14 | def __init__(self, din: int, dout: int): 15 | self.kernel = jax.random.uniform(nnx.make_rng("params"), (din, dout)) 16 | self.bias = jax.numpy.zeros((dout,)) 17 | 18 | def __call__(self, x): 19 | return x @ self.kernel + self.bias 20 | 21 | 22 | class BatchNorm(nnx.Module): 23 | scale: jax.Array = nnx.Param() 24 | bias: jax.Array = nnx.Param() 25 | mean: jax.Array = nnx.variable("batch_stats") 26 | var: jax.Array = nnx.variable("batch_stats") 27 | mu: float = nnx.static_field() 28 | 29 | def __init__(self, din: int, mu: float = 0.95): 30 | self.scale = jax.random.uniform(nnx.make_rng("params"), (din,)) 31 | self.bias = jax.numpy.zeros((din,)) 32 | self.mean = jax.numpy.zeros((din,)) 33 | self.var = jax.numpy.ones((din,)) 34 | self.mu = mu 35 | 36 | def __call__(self, x, *, use_running_averages: bool) -> jax.Array: 37 | scale, bias = self.scale, self.bias 38 | if use_running_averages: 39 | mean, var = self.mean, self.var 40 | else: 41 | axis = tuple(range(0, x.ndim - 1)) 42 | mean = jax.numpy.mean(x, axis=axis) 43 | var = jax.numpy.var(x, axis=axis) 44 | # ema update 45 | self.mean = self.mu * self.mean + (1 - self.mu) * mean 46 | self.var = self.mu * self.var + (1 - self.mu) * var 47 | 48 | x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias 49 | 50 | return x 51 | 52 | 53 | @nnx.dataclasses 54 | class Dropout(nnx.Module): 55 | rate: float 56 | 57 | def __call__(self, inputs, *, deterministic: bool): 58 | if (self.rate == 0.0) or deterministic: 59 | return inputs 60 | rng = nnx.make_rng("dropout") 61 | keep_prob = 1.0 - self.rate 62 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape) 63 | return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) 64 | 65 | 66 | class MLP(nnx.Module): 67 | 68 | def __init__(self, din: int, dmid: int, dout: int): 69 | self.linear1 = Linear(din, dmid) 70 | self.bn1 = BatchNorm(dmid) 71 | self.dropout = Dropout(0.5) 72 | self.linear2 = Linear(dmid, dout) 73 | 74 | def __call__(self, x: jax.Array, *, train: bool) -> jax.Array: 75 | x = self.linear1(x) 76 | x = self.bn1(x, use_running_averages=not train) 77 | x = self.dropout(x, deterministic=not train) 78 | x = jax.nn.relu(x) 79 | x = self.linear2(x) 80 | return x 81 | 82 | 83 | rngs = nnx.Context(jax.random.PRNGKey(0)) 84 | model = MLP.init(rngs)(10, 20, 30) 85 | 86 | 87 |
@nnx.jit 88 | def train_step(model: MLP, key, batch): 89 | x, y = batch 90 | 91 | def loss(model: MLP): 92 | rngs = nnx.Context(dropout=key) 93 | y_pred = model.apply(rngs=rngs)(x, train=True) 94 | loss = jax.numpy.mean((y_pred - y) ** 2) 95 | return loss 96 | 97 | grads = nnx.grad(loss, wrt=nnx.Param)(model) 98 | model[:] = jax.tree_map(lambda w, g: w - 0.1 * g, model["params"], grads) 99 | 100 | 101 | # ---------------------------------------- 102 | # scan over layers + shared batchnorm 103 | # ---------------------------------------- 104 | 105 | n_layers = 10 106 | params_keys = jax.random.PRNGKey(0) 107 | params_keys = jax.random.split(params_keys, n_layers) 108 | 109 | 110 | @partial(jax.vmap, in_axes=0, out_axes=(0, None, None)) 111 | def create_state(params_key: jax.random.KeyArray): 112 | rngs = nnx.Context(params=params_key) 113 | model = MLP.init(rngs)(10, 20, 10) 114 | (params, batch_stats), modeldef = model.partition(nnx.Param, "batch_stats") 115 | return params, batch_stats, modeldef 116 | 117 | 118 | params, batch_stats, modeldef = create_state(params_keys) 119 | x = jax.numpy.zeros((32, 10)) 120 | dropout_key = jax.random.PRNGKey(1) 121 | dropout_stream = nnx.RngStream(jax.random.split(dropout_key, n_layers)) 122 | 123 | 124 | def scan_fn( 125 | carry: Tuple[jax.Array, nnx.State], 126 | inputs: Tuple[nnx.State, nnx.RngStream], 127 | ): 128 | # extract args 129 | x, batch_stats = carry 130 | params, dropout_stream = inputs 131 | 132 | # create state and rngs 133 | model = modeldef.merge([params, batch_stats]) 134 | rngs = nnx.Context(dropout=dropout_stream) 135 | 136 | # forward pass 137 | x = model.apply(rngs=rngs)(x, train=True) 138 | 139 | # partition state 140 | params, batch_stats = model.partition(nnx.Param, "batch_stats")[0] 141 | 142 | return (x, batch_stats), params 143 | 144 | 145 | (y, batch_stats), params = jax.lax.scan( 146 | scan_fn, (x, batch_stats), (params, dropout_stream) 147 | ) 148 | model = modeldef.merge([params, batch_stats]) 149 | -------------------------------------------------------------------------------- /ideas/pure/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import Initializer, Module 2 | from .partitioning import ( 3 | NOTHING, 4 | Partition, 5 | get_partition, 6 | merge_partitions, 7 | tree_partition, 8 | ) 9 | from .rngs import Rngs, RngStream 10 | from .state import State, Variable, merge 11 | -------------------------------------------------------------------------------- /ideas/pure/full/partitioning_full.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.tree_util as jtu 5 | 6 | A = tp.TypeVar("A") 7 | CollectionPredicate = tp.Callable[[str], bool] 8 | Leaf = tp.Any 9 | Leaves = tp.List[Leaf] 10 | KeyPath = tp.Tuple[tp.Hashable, ...] 
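# note: a CollectionPredicate receives a collection name (e.g. "params") and
# decides whether a Variable from that collection belongs to a given Partition;
# tree_partition below applies these predicates in order, first match wins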
11 | LeafPredicate = tp.Callable[[tp.Any], bool] 12 | 13 | 14 | class Variable: 15 | __slots__ = ("_value", "_collection") 16 | 17 | def __init__(self, value: tp.Any, collection: str = "params"): 18 | self._value = value 19 | self._collection = collection 20 | 21 | @property 22 | def value(self) -> tp.Any: 23 | return self._value 24 | 25 | @property 26 | def collection(self) -> str: 27 | return self._collection 28 | 29 | @classmethod 30 | def from_value(cls, value: tp.Any) -> "Variable": 31 | return value if isinstance(value, Variable) else Variable(value) 32 | 33 | def copy(self) -> "Variable": 34 | return Variable(self.value, self.collection) 35 | 36 | def update(self, value: tp.Any) -> "Variable": 37 | if isinstance(value, Variable): 38 | if value.collection != self.collection: 39 | raise ValueError( 40 | f"Cannot update variable with value from a different collection. " 41 | f"Expected collection {self.collection}, got {value.collection}" 42 | ) 43 | value = value.value 44 | return Variable(value, self.collection) 45 | 46 | def __repr__(self) -> str: 47 | return f"Variable({self.value}, collection={self.collection})" 48 | 49 | 50 | def _flatten_variable_with_keys(variable: Variable): 51 | node = (jtu.GetAttrKey("value"), variable.value) 52 | return (node,), variable.collection 53 | 54 | 55 | def _flatten_variable(variable: Variable): 56 | return (variable.value,), variable.collection 57 | 58 | 59 | def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): 60 | return Variable(nodes[0], collection) 61 | 62 | 63 | jax.tree_util.register_pytree_with_keys( 64 | Variable, 65 | _flatten_variable_with_keys, 66 | _unflatten_variable, 67 | flatten_func=_flatten_variable, 68 | ) 69 | 70 | 71 | class Nothing: 72 | 73 | def __repr__(self) -> str: 74 | return "Nothing" # pragma: no cover 75 | 76 | 77 | def _nothing_flatten(x): 78 | return (), None 79 | 80 | 81 | def _nothing_unflatten(aux_data, children): 82 | return NOTHING 83 | 84 | 85 | NOTHING = Nothing() 86 | 87 | jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) 88 | 89 | 90 | class StrPath(tp.Tuple[str, ...]): 91 | pass 92 | 93 | 94 | class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): 95 | 96 | def __setitem__(self, key, value): 97 | raise TypeError("Partition is immutable") 98 | 99 | 100 | def _partition_flatten_with_keys( 101 | x: Partition, 102 | ) -> tp.Tuple[ 103 | tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] 
104 | ]: 105 | children = tuple((StrPath(key), value) for key, value in x.items()) 106 | return children, tuple(x.keys()) 107 | 108 | 109 | def _partition_unflatten(keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...]): 110 | return Partition(zip(keys, leaves)) 111 | 112 | 113 | jax.tree_util.register_pytree_with_keys( 114 | Partition, _partition_flatten_with_keys, _partition_unflatten 115 | ) 116 | 117 | 118 | def _key_path_to_str_gen(key_path: KeyPath) -> tp.Generator[str, None, None]: 119 | for key_entry in key_path: 120 | if isinstance(key_entry, StrPath): 121 | yield from key_entry 122 | elif isinstance(key_entry, jtu.SequenceKey): 123 | yield str(key_entry.idx) 124 | elif isinstance(key_entry, jtu.DictKey): # "['a']" 125 | yield str(key_entry.key) 126 | elif isinstance(key_entry, jtu.GetAttrKey): 127 | yield str(key_entry.name) 128 | elif isinstance(key_entry, jtu.FlattenedIndexKey): 129 | yield str(key_entry.key) 130 | elif hasattr(key_entry, "__dict__") and len(key_entry.__dict__) == 1: 131 | yield str(next(iter(key_entry.__dict__.values()))) 132 | else: 133 | yield str(key_entry) 134 | 135 | 136 | def _key_path_to_str_path(key_path: KeyPath) -> StrPath: 137 | return StrPath(_key_path_to_str_gen(key_path)) 138 | 139 | 140 | class StateDef(tp.Generic[A]): 141 | __slots__ = ("treedef",) 142 | 143 | def __init__(self, treedef: jtu.PyTreeDef): 144 | self.treedef = treedef 145 | 146 | def merge(self, *partitions: Partition) -> A: 147 | raise NotImplementedError 148 | 149 | 150 | def statedef_flatten(x: StateDef): 151 | return (), x.treedef 152 | 153 | 154 | def statedef_unflatten(treedef, children): 155 | return StateDef(treedef) 156 | 157 | 158 | jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) 159 | 160 | 161 | def tree_partition( 162 | pytree: A, 163 | *predicates: CollectionPredicate, 164 | is_leaf: tp.Optional[LeafPredicate] = None, 165 | ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: 166 | paths_leaves: tp.List[tp.Tuple[KeyPath, Leaf]] 167 | paths_leaves, treedef = jax.tree_util.tree_flatten_with_path( 168 | pytree, 169 | is_leaf=lambda x: (isinstance(x, Variable) or x is NOTHING) 170 | or (False if is_leaf is None else is_leaf(x)), 171 | ) 172 | 173 | leaves: tp.Tuple[Leaf, ...] 174 | paths, leaves = zip(*paths_leaves) 175 | paths = tuple(map(_key_path_to_str_path, paths)) 176 | 177 | # we have n + 1 partitions, where n is the number of predicates 178 | # the last partition is for values that don't match any predicate 179 | partition_leaves: tp.Tuple[Leaves, ...] 
= tuple( 180 | [NOTHING] * len(leaves) for _ in range(len(predicates) + 1) 181 | ) 182 | for j, leaf in enumerate(leaves): 183 | for i, predicate in enumerate(predicates): 184 | if isinstance(leaf, Variable) and predicate(leaf.collection): 185 | partition_leaves[i][j] = leaf 186 | break 187 | else: 188 | # if we didn't break, set leaf to last partition 189 | partition_leaves[-1][j] = leaf 190 | 191 | partitions = tuple(Partition(zip(paths, partition)) for partition in partition_leaves) 192 | return partitions, StateDef(treedef) 193 | 194 | 195 | def get_partition( 196 | pytree, 197 | predicate: CollectionPredicate, 198 | is_leaf: tp.Optional[LeafPredicate] = None, 199 | ) -> Partition: 200 | (partition, _rest), _treedef = tree_partition(pytree, predicate, is_leaf=is_leaf) 201 | return partition 202 | 203 | 204 | def _get_non_nothing( 205 | paths: tp.Tuple[StrPath, ...], 206 | leaves: tp.Tuple[tp.Union[Leaf, Nothing], ...], 207 | position: int, 208 | ): 209 | # check that all paths are the same 210 | paths_set = set(paths) 211 | if len(paths_set) != 1: 212 | raise ValueError( 213 | "All partitions must have the same paths, " 214 | f" at position [{position}] got " 215 | "".join(f"\n- {path}" for path in paths_set) 216 | ) 217 | non_null = [option for option in leaves if option is not NOTHING] 218 | if len(non_null) == 0: 219 | raise ValueError(f"Expected at least one non-null value for position [{position}]") 220 | elif len(non_null) > 1: 221 | raise ValueError(f"Expected at most one non-null value for position [{position}]") 222 | return non_null[0] 223 | 224 | 225 | def merge_partitions( 226 | partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef 227 | ): 228 | lenghts = [len(partition) for partition in partitions] 229 | if not all(length == lenghts[0] for length in lenghts): 230 | raise ValueError( 231 | "All partitions must have the same length, got " 232 | f"{', '.join(str(length) for length in lenghts)}" 233 | ) 234 | 235 | partition_paths = (list(partition.keys()) for partition in partitions) 236 | partition_leaves = (list(partition.values()) for partition in partitions) 237 | 238 | merged_leaves = [ 239 | _get_non_nothing(paths, leaves, i) 240 | for i, (paths, leaves) in enumerate( 241 | zip(zip(*partition_paths), zip(*partition_leaves)) 242 | ) 243 | ] 244 | 245 | return jax.tree_util.tree_unflatten(treedef, merged_leaves) 246 | -------------------------------------------------------------------------------- /ideas/pure/full/state_full.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from types import MappingProxyType 3 | 4 | import jax 5 | import jax.tree_util as jtu 6 | from pure.partitioning import Partition, StateDef, Variable 7 | 8 | Node = tp.Union[Variable, "State"] 9 | S = tp.TypeVar("S", bound="State") 10 | 11 | 12 | class State(tp.Mapping[str, Node]): 13 | __slots__ = ("_variables",) 14 | 15 | def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): 16 | self._variables = { 17 | k: self._create_node_field(v) for k, v in dict(*args, **kwargs).items() 18 | } 19 | 20 | @staticmethod 21 | def _create_node_field(value: tp.Any) -> Node: 22 | if isinstance(value, State): 23 | return value 24 | else: 25 | return Variable.from_value(value) 26 | 27 | @staticmethod 28 | def _update_node_field(node: Node, value: tp.Any) -> Node: 29 | if isinstance(node, State) and isinstance(value, State): 30 | return value 31 | elif isinstance(node, Variable) and isinstance(value, Variable): 32 | return 
node.update(value) 33 | else: 34 | raise ValueError( 35 | f"Cannot update node of type {type(node).__name__} with " 36 | f"value of type {type(value).__name__}" 37 | ) 38 | 39 | def __getitem__(self, name: str) -> tp.Any: 40 | return self._variables[name].value 41 | 42 | __getattr__ = __getitem__ 43 | 44 | def __iter__(self) -> tp.Iterator[str]: 45 | return iter(self._variables) 46 | 47 | def __len__(self) -> int: 48 | return len(self._variables) 49 | 50 | def keys(self) -> tp.KeysView[str]: 51 | return self._variables.keys() 52 | 53 | def values(self) -> tp.ValuesView[Node]: 54 | return self._variables.values() 55 | 56 | def __repr__(self) -> str: 57 | return f"State({self._variables})" 58 | 59 | def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": 60 | raise NotImplementedError 61 | 62 | @tp.overload 63 | def partition(self) -> tp.Tuple[tp.Dict[str, Partition], StateDef["State"]]: 64 | ... 65 | 66 | @tp.overload 67 | def partition(self, collection: str) -> tp.Tuple[Partition, StateDef["State"]]: 68 | ... 69 | 70 | @tp.overload 71 | def partition( 72 | self, collection: str, *collections: str 73 | ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef["State"]]: 74 | ... 75 | 76 | def partition( 77 | self, *collections: str 78 | ) -> tp.Tuple[ 79 | tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], 80 | StateDef["State"], 81 | ]: 82 | raise NotImplementedError 83 | 84 | @tp.overload 85 | def get_partition(self, collection: str) -> Partition: 86 | ... 87 | 88 | @tp.overload 89 | def get_partition( 90 | self, collection: str, *collections: str 91 | ) -> tp.Tuple[Partition, ...]: 92 | ... 93 | 94 | def get_partition( 95 | self, *collections: str 96 | ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: 97 | raise NotImplementedError 98 | 99 | def update_partition(self, partition: Partition, *partitions: Partition) -> "State": 100 | raise NotImplementedError 101 | 102 | @tp.overload 103 | def pop(self, name: str) -> Node: 104 | ... 105 | 106 | @tp.overload 107 | def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: 108 | ... 
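# pop removes the named entries from the state in place and returns them; with
# several names it returns a tuple, mirroring the overloads above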
109 | 110 | def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: 111 | if len(names) == 0: 112 | raise ValueError("pop expected at least 1 argument, got 0") 113 | elif len(names) == 1: 114 | name = names[0] 115 | return self._variables.pop(name) 116 | else: 117 | return tuple(self._variables.pop(name) for name in names) 118 | 119 | 120 | def _state_flatten_with_keys(state: State): 121 | nodes = tuple((jtu.GetAttrKey(name), variable) for name, variable in state.items()) 122 | names = tuple(state) 123 | return nodes, names 124 | 125 | 126 | def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): 127 | return State(zip(names, nodes)) 128 | 129 | 130 | def _state_flatten(state: State): 131 | return tuple(state.values()), tuple(state) 132 | 133 | 134 | jtu.register_pytree_with_keys( 135 | State, _state_flatten_with_keys, _state_unflatten, flatten_func=_state_flatten 136 | ) 137 | -------------------------------------------------------------------------------- /ideas/pure/module.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from dataclasses import dataclass 3 | from typing import Any 4 | 5 | import jax 6 | import jax.tree_util as jtu 7 | from pure.partitioning import Partition 8 | from pure.rngs import KeyArray, Rngs 9 | from pure.state import State, Variable 10 | 11 | A = tp.TypeVar("A", contravariant=True) 12 | 13 | 14 | class InitFn(tp.Protocol, tp.Generic[A]): 15 | 16 | @tp.overload 17 | def __call__(self, __key_or_stream: A) -> tp.Any: 18 | ... 19 | 20 | def __call__(self, __key_or_stream: A, *args: tp.Any) -> tp.Any: 21 | ... 22 | 23 | 24 | class Initializer: 25 | 26 | @tp.overload 27 | def __init__( 28 | self, 29 | initializer: InitFn[KeyArray], 30 | *args, 31 | collection: str = "params", 32 | ): 33 | ... 34 | 35 | @tp.overload 36 | def __init__( 37 | self, 38 | initializer: InitFn[Rngs], 39 | *args, 40 | stream: None, 41 | collection: str = "params", 42 | ): 43 | ... 44 | 45 | def __init__( 46 | self, 47 | initializer: tp.Union[InitFn[KeyArray], InitFn[Rngs]], 48 | *args, 49 | stream: tp.Optional[str] = "params", 50 | collection: str = "params", 51 | ): 52 | ... 53 | 54 | def create_variable(self, rngs: Rngs) -> Variable: 55 | ... 56 | 57 | 58 | class Module: 59 | 60 | def create_state(self, rngs: Rngs) -> State: 61 | return State( 62 | ( 63 | name, 64 | v.create_state(rngs) if isinstance(v, Module) else v.create_variable(rngs), 65 | ) 66 | for name, v in vars(self).items() 67 | if isinstance(v, (Initializer, Module)) 68 | ) 69 | -------------------------------------------------------------------------------- /ideas/pure/partitioning.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.tree_util as jtu 5 | 6 | A = tp.TypeVar("A") 7 | CollectionPredicate = tp.Callable[[str], bool] 8 | Leaf = tp.Any 9 | Leaves = tp.List[Leaf] 10 | KeyPath = tp.Tuple[tp.Hashable, ...] 11 | LeafPredicate = tp.Callable[[tp.Any], bool] 12 | 13 | 14 | class Variable: 15 | __slots__ = ("_value", "_collection") 16 | 17 | def __init__(self, value: tp.Any, collection: str = "params"): 18 | ... 19 | 20 | @property 21 | def value(self) -> tp.Any: 22 | ... 23 | 24 | @property 25 | def collection(self) -> str: 26 | ... 27 | 28 | @classmethod 29 | def from_value(cls, value: tp.Any) -> "Variable": 30 | ... 31 | 32 | def copy(self) -> "Variable": 33 | ... 34 | 35 | def update(self, value: tp.Any) -> "Variable": 36 | ... 
37 | 38 | def __repr__(self) -> str: 39 | return f"Variable({self.value}, collection={self.collection})" 40 | 41 | 42 | def _flatten_variable_with_keys(variable: Variable): 43 | ... 44 | 45 | 46 | def _flatten_variable(variable: Variable): 47 | ... 48 | 49 | 50 | def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): 51 | ... 52 | 53 | 54 | jax.tree_util.register_pytree_with_keys( 55 | Variable, 56 | _flatten_variable_with_keys, 57 | _unflatten_variable, 58 | flatten_func=_flatten_variable, 59 | ) 60 | 61 | 62 | class Nothing: 63 | 64 | def __repr__(self) -> str: 65 | ... 66 | 67 | 68 | def _nothing_flatten(x): 69 | ... 70 | 71 | 72 | def _nothing_unflatten(aux_data, children): 73 | ... 74 | 75 | 76 | NOTHING = Nothing() 77 | 78 | jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) 79 | 80 | 81 | class StrPath(tp.Tuple[str, ...]): 82 | pass 83 | 84 | 85 | class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): 86 | 87 | def __setitem__(self, key, value): 88 | raise TypeError("Partition is immutable") 89 | 90 | 91 | def _partition_flatten_with_keys( 92 | x: Partition, 93 | ) -> tp.Tuple[ 94 | tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] 95 | ]: 96 | ... 97 | 98 | 99 | def _partition_unflatten(keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...]): 100 | ... 101 | 102 | 103 | jax.tree_util.register_pytree_with_keys( 104 | Partition, _partition_flatten_with_keys, _partition_unflatten 105 | ) 106 | 107 | 108 | class StateDef(tp.Generic[A]): 109 | __slots__ = ("treedef",) 110 | 111 | def __init__(self, treedef: jtu.PyTreeDef): 112 | ... 113 | 114 | @property 115 | def treedef(self) -> jtu.PyTreeDef: 116 | ... 117 | 118 | def merge(self, *partitions: Partition) -> A: 119 | ... 120 | 121 | 122 | def statedef_flatten(x: StateDef): 123 | ... 124 | 125 | 126 | def statedef_unflatten(treedef, children): 127 | ... 128 | 129 | 130 | jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) 131 | 132 | 133 | def tree_partition( 134 | pytree: A, 135 | *predicates: CollectionPredicate, 136 | is_leaf: tp.Optional[LeafPredicate] = None, 137 | ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: 138 | ... 139 | 140 | 141 | def get_partition( 142 | pytree, 143 | predicate: CollectionPredicate, 144 | is_leaf: tp.Optional[LeafPredicate] = None, 145 | ) -> Partition: 146 | ... 147 | 148 | 149 | def merge_partitions( 150 | partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef 151 | ): 152 | ... 153 | -------------------------------------------------------------------------------- /ideas/pure/rngs.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | 5 | KeyArray = tp.Union[jax.Array, jax.random.KeyArray] 6 | 7 | 8 | class RngStream: 9 | 10 | def __init__( 11 | self, key: KeyArray, count: int = 0, count_path: tp.Tuple[int, ...] = () 12 | ): 13 | ... 14 | 15 | @property 16 | def key(self) -> jax.random.KeyArray: 17 | ... 18 | 19 | @property 20 | def count(self) -> int: 21 | ... 22 | 23 | @property 24 | def count_path(self) -> tp.Tuple[int, ...]: 25 | ... 26 | 27 | def next(self) -> jax.random.KeyArray: 28 | ... 29 | 30 | def fork(self) -> "RngStream": 31 | ... 32 | 33 | 34 | class Rngs: 35 | 36 | def __init__(self, **streams: tp.Union[KeyArray, RngStream]): 37 | ... 38 | 39 | def make_rng(self, stream: str) -> jax.Array: 40 | ... 41 | 42 | @tp.overload 43 | def fork(self, stream: str) -> RngStream: 44 | ... 
45 | 46 | @tp.overload 47 | def fork(self, stream: str, *streams: str) -> tp.Tuple[RngStream, ...]: 48 | ... 49 | 50 | def fork(self, *streams: str) -> tp.Union[RngStream, tp.Tuple[RngStream, ...]]: 51 | ... 52 | -------------------------------------------------------------------------------- /ideas/pure/state.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from types import MappingProxyType 3 | 4 | import jax 5 | import jax.tree_util as jtu 6 | from pure.partitioning import Partition, StateDef, Variable 7 | 8 | Node = tp.Union[Variable, "State"] 9 | S = tp.TypeVar("S", bound="State") 10 | 11 | 12 | class State(tp.Mapping[str, Node]): 13 | __slots__ = ("_variables",) 14 | 15 | def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): 16 | ... 17 | 18 | def __getitem__(self, name: str) -> tp.Any: 19 | ... 20 | 21 | __getattr__ = __getitem__ 22 | 23 | def __iter__(self) -> tp.Iterator[str]: 24 | ... 25 | 26 | def __len__(self) -> int: 27 | ... 28 | 29 | def keys(self) -> tp.KeysView[str]: 30 | ... 31 | 32 | def values(self) -> tp.ValuesView[Node]: 33 | ... 34 | 35 | def __repr__(self) -> str: 36 | ... 37 | 38 | def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": 39 | ... 40 | 41 | @tp.overload 42 | def partition(self) -> tp.Dict[str, Partition]: 43 | ... 44 | 45 | @tp.overload 46 | def partition(self, collection: str) -> Partition: 47 | ... 48 | 49 | @tp.overload 50 | def partition(self, collection: str, *collections: str) -> tp.Tuple[Partition, ...]: 51 | ... 52 | 53 | def partition( 54 | self, *collections: str 55 | ) -> tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition]: 56 | ... 57 | 58 | def merge(self, partition: Partition, *partitions: Partition) -> "State": 59 | ... 60 | 61 | @tp.overload 62 | def pop(self, name: str) -> Node: 63 | ... 64 | 65 | @tp.overload 66 | def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: 67 | ... 68 | 69 | def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: 70 | ... 71 | 72 | 73 | def _state_flatten_with_keys(state: State): 74 | ... 75 | 76 | 77 | def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): 78 | ... 79 | 80 | 81 | def _state_flatten(state: State): 82 | ... 83 | 84 | 85 | jtu.register_pytree_with_keys( 86 | State, _state_flatten_with_keys, _state_unflatten, flatten_func=_state_flatten 87 | ) 88 | 89 | 90 | def merge(partition: Partition, other: Partition, *rest: Partition) -> State: 91 | ... 
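# hypothetical usage of the interface sketched above (names are illustrative,
# see pure_example.py below for the full pattern):
#
#   state = model.create_state(Rngs(params=jax.random.PRNGKey(0)))
#   params = state.partition("params")
#   state = state.merge(params)   # recombine after e.g. a gradient update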
92 | -------------------------------------------------------------------------------- /ideas/pure_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Tuple 4 | 5 | import jax 6 | import pure 7 | from pure.rngs import Rngs 8 | from pure.state import State 9 | 10 | 11 | @dataclass 12 | class Linear: 13 | din: int 14 | dout: int 15 | 16 | def create_state(self, rngs: Rngs) -> State: 17 | key = rngs.make_rng("params") 18 | return State( 19 | kernel=jax.random.uniform(key, (self.din, self.dout)), 20 | bias=jax.numpy.zeros((self.dout,)), 21 | ) 22 | 23 | def __call__(self, state: pure.State, x): 24 | return x @ state.kernel + state.bias 25 | 26 | 27 | class BatchNorm(pure.Module): 28 | 29 | def __init__(self, din: int, mu: float = 0.95): 30 | self.scale = pure.Initializer(jax.random.uniform, (din,)) 31 | self.bias = pure.Initializer(lambda _: jax.numpy.zeros((din,))) 32 | self.mean = pure.Initializer( 33 | lambda _: jax.numpy.zeros((din,)), collection="batch_stats" 34 | ) 35 | self.var = pure.Initializer( 36 | lambda _: jax.numpy.ones((din,)), collection="batch_stats" 37 | ) 38 | self.mu = mu 39 | 40 | def __call__( 41 | self, state: pure.State, x, use_running_averages: bool 42 | ) -> Tuple[jax.Array, pure.State]: 43 | scale, bias = state.scale, state.bias 44 | if use_running_averages: 45 | mean, var = state.mean, state.var 46 | else: 47 | axis = tuple(range(0, x.ndim - 1)) 48 | mean = jax.numpy.mean(x, axis=axis) 49 | var = jax.numpy.var(x, axis=axis) 50 | # ema update 51 | state = state.update( 52 | mean=self.mu * state.mean + (1 - self.mu) * mean, 53 | var=self.mu * state.var + (1 - self.mu) * var, 54 | ) 55 | 56 | x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias 57 | 58 | return x, state 59 | 60 | 61 | class Dropout(pure.Module): 62 | 63 | def __init__(self, rate: float): 64 | raise NotImplementedError 65 | 66 | def __call__(self, state, rngs: Rngs, x, *, deterministic: bool) -> jax.Array: 67 | key = rngs.make_rng("dropout") 68 | raise NotImplementedError 69 | 70 | 71 | class MLP(pure.Module): 72 | 73 | def __init__(self, din: int, dmid: int, dout: int): 74 | self.linear1 = Linear(din, dmid) 75 | self.bn1 = BatchNorm(dmid) 76 | self.dropout = Dropout(0.5) 77 | self.linear2 = Linear(dmid, dout) 78 | 79 | def __call__( 80 | self, state: pure.State, rngs: pure.Rngs, x: jax.Array, *, train: bool 81 | ) -> Tuple[jax.Array, pure.State]: 82 | x = self.linear1(state.linear1, x) 83 | x, bn1 = self.bn1(state.bn1, x, use_running_averages=not train) 84 | x = self.dropout(state.dropout, rngs, x, deterministic=not train) 85 | x = jax.nn.relu(x) 86 | x = self.linear2(state.linear2, x) 87 | return x, state.update(bn1=bn1) 88 | 89 | 90 | model = MLP(10, 20, 30) 91 | rngs = pure.Rngs(params=jax.random.PRNGKey(0)) 92 | state = model.create_state(rngs) 93 | 94 | 95 | @jax.jit 96 | def train_step(state: pure.State, key, batch): 97 | x, y = batch 98 | params = state.partition(nnx.Param) 99 | rngs = pure.Rngs(dropout=key) 100 | 101 | def loss(params): 102 | _state = state.merge(params) 103 | y_pred, _state = model(_state, rngs, x, train=True) 104 | loss = jax.numpy.mean((y_pred - y) ** 2) 105 | return loss, _state 106 | 107 | grads, state = jax.grad(loss, has_aux=True)(params) 108 | params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) 109 | state = state.merge(params) 110 | 111 | return state 112 | 113 | 114 | # ---------------------------------------- 115 | # 
scan over layers + shared batch_stats 116 | # ---------------------------------------- 117 | 118 | model = MLP(10, 20, 10) 119 | n_layers = 10 120 | params_keys = jax.random.PRNGKey(0) 121 | params_keys = jax.random.split(params_keys, n_layers) 122 | 123 | 124 | @partial(jax.vmap, in_axes=0, out_axes=(0, None)) 125 | def create_state(params_key: jax.random.KeyArray): 126 | state = model.create_state(pure.Rngs(params=params_key)) 127 | params, batch_stats = state.partition(nnx.Param, "batch_stats") 128 | return params, batch_stats 129 | 130 | 131 | params, batch_stats = create_state(params_keys) 132 | x = jax.numpy.zeros((32, 10)) 133 | dropout_key = jax.random.PRNGKey(1) 134 | dropout_stream = pure.RngStream(jax.random.split(dropout_key, n_layers)) 135 | 136 | 137 | def scan_fn( 138 | carry: Tuple[jax.Array, pure.Partition], 139 | inputs: Tuple[pure.Partition, pure.RngStream], 140 | ): 141 | # extract args 142 | x, batch_stats = carry 143 | params, dropout_stream = inputs 144 | 145 | # create state and rngs 146 | state = pure.merge(params, batch_stats) 147 | rngs = pure.Rngs(dropout=dropout_stream) 148 | 149 | # forward pass 150 | x, state = model(state, rngs, x, train=True) 151 | 152 | # partition state 153 | params, batch_stats = state.partition(nnx.Param, "batch_stats") 154 | 155 | return (x, batch_stats), params 156 | 157 | 158 | (y, batch_stats), params = jax.lax.scan( 159 | scan_fn, (x, batch_stats), (params, dropout_stream) 160 | ) 161 | state = pure.merge(params, batch_stats) 162 | -------------------------------------------------------------------------------- /ideas/pure_nnx_example.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import nnx 9 | 10 | 11 | class Linear(nnx.Module): 12 | 13 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 14 | self.kernel = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din, dout))) 15 | self.bias = nnx.Param(jax.numpy.zeros((dout,))) 16 | 17 | def __call__(self, x): 18 | return x @ self.kernel + self.bias 19 | 20 | 21 | class BatchNorm(nnx.Module): 22 | 23 | def __init__(self, din: int, mu: float = 0.95, *, ctx: nnx.Context): 24 | self.scale = nnx.Param(jax.random.uniform(ctx.make_rng("params"), (din,))) 25 | self.bias = nnx.Param(jax.numpy.zeros((din,))) 26 | self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) 27 | self.var = nnx.BatchStat(jax.numpy.ones((din,))) 28 | self.mu = mu 29 | 30 | def __call__(self, x, *, use_running_averages: bool) -> jax.Array: 31 | scale, bias = self.scale, self.bias 32 | if use_running_averages: 33 | mean, var = self.mean, self.var 34 | else: 35 | axis = tuple(range(0, x.ndim - 1)) 36 | mean = jax.numpy.mean(x, axis=axis) 37 | var = jax.numpy.var(x, axis=axis) 38 | # ema update 39 | self.mean = self.mu * self.mean + (1 - self.mu) * mean 40 | self.var = self.mu * self.var + (1 - self.mu) * var 41 | 42 | x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias 43 | 44 | return x 45 | 46 | 47 | @dataclasses.dataclass 48 | class Dropout(nnx.Module): 49 | rate: float 50 | 51 | def __call__(self, inputs, *, deterministic: bool, ctx: nnx.Context): 52 | if (self.rate == 0.0) or deterministic: 53 | return inputs 54 | rng = ctx.make_rng("dropout") 55 | keep_prob = 1.0 - self.rate 56 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape) 57 | return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) 58 | 
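# standalone usage sketch (hypothetical): Dropout draws its rng from the Context
# passed at call time, e.g.
#   y = Dropout(0.5)(x, deterministic=False, ctx=nnx.Context(dropout=jax.random.PRNGKey(0)))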
59 | 60 | class MLP(nnx.Module): 61 | 62 | def __init__(self, din: int, dmid: int, dout: int, *, ctx: nnx.Context): 63 | self.linear1 = Linear(din, dmid, ctx=ctx) 64 | self.bn1 = BatchNorm(dmid, ctx=ctx) 65 | self.dropout = Dropout(0.5) 66 | self.linear2 = Linear(dmid, dout, ctx=ctx) 67 | 68 | def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array: 69 | x = self.linear1(x) 70 | x = self.bn1(x, use_running_averages=not train) 71 | x = self.dropout(x, deterministic=not train, ctx=ctx) 72 | x = jax.nn.relu(x) 73 | x = self.linear2(x) 74 | return x 75 | 76 | 77 | ctx = nnx.Context(jax.random.PRNGKey(0)) 78 | model = MLP(10, 20, 30, ctx=ctx) 79 | 80 | 81 | @nnx.jit 82 | def train_step(model: MLP, key, batch): 83 | x, y = batch 84 | 85 | def loss(model: MLP): 86 | ctx = nnx.Context(rngs=dict(dropout=key)) 87 | y_pred = model(x, train=True, ctx=ctx) 88 | loss = jax.numpy.mean((y_pred - y) ** 2) 89 | return loss 90 | 91 | grads = nnx.grad(loss, wrt=nnx.Param)(model) 92 | model.update_state( 93 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) 94 | ) 95 | 96 | 97 | # ---------------------------------------- 98 | # scan over layers + shared batchnorm 99 | # ---------------------------------------- 100 | 101 | n_layers = 10 102 | params_keys = jax.random.PRNGKey(0) 103 | params_keys = jax.random.split(params_keys, n_layers) 104 | 105 | 106 | @partial(jax.vmap, in_axes=0, out_axes=(0, None, None)) 107 | def create_state(params_key: jax.random.KeyArray): 108 | ctx = nnx.Context(rngs=dict(params=params_key)) 109 | model = MLP(10, 20, 10, ctx=ctx) 110 | (params, batch_stats), modeldef = model.partition(nnx.Param, "batch_stats") 111 | return params, batch_stats, modeldef 112 | 113 | 114 | params, batch_stats, modeldef = create_state(params_keys) 115 | x = jax.numpy.zeros((32, 10)) 116 | dropout_key = jax.random.split(jax.random.PRNGKey(1), n_layers) 117 | 118 | 119 | def scan_fn( 120 | carry: Tuple[jax.Array, nnx.State], 121 | inputs: Tuple[nnx.State, jax.random.KeyArray], 122 | ): 123 | # extract args 124 | x, batch_stats = carry 125 | params, dropout_key = inputs 126 | 127 | # create state and ctx 128 | model = modeldef.merge(params, batch_stats) 129 | ctx = nnx.Context(dropout=dropout_key) 130 | 131 | # forward pass 132 | x = model(x, train=True, ctx=ctx) 133 | 134 | # partition state 135 | (params, batch_stats), _ = model.partition(nnx.Param, "batch_stats") 136 | 137 | return (x, batch_stats), params 138 | 139 | 140 | (y, batch_stats), params = jax.lax.scan( 141 | scan_fn, (x, batch_stats), (params, dropout_key) 142 | ) 143 | model = modeldef.merge(params, batch_stats) 144 | -------------------------------------------------------------------------------- /ideas/pure_pytree/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataclass import VariableField, dataclass, field, param, static_field, variable 2 | from .module import Initializer, Module 3 | from .partitioning import NOTHING, Partition, get_partition 4 | from .partitioning import merge_partitions as merge 5 | from .partitioning import tree_partition as partition 6 | from .rngs import Rngs, RngStream 7 | -------------------------------------------------------------------------------- /ideas/pure_pytree/dataclass.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as tp 3 | from dataclasses import field 4 | 5 | import typing_extensions as tpe 6 | from simple_pytree import 
static_field 7 | 8 | A = tp.TypeVar("A") 9 | K = tp.TypeVar("K", bound=tp.Hashable) 10 | 11 | 12 | class VariableField(dataclasses.Field, tp.Generic[A]): 13 | 14 | def __init__( 15 | self, 16 | *, 17 | collection: tp.Hashable = None, 18 | default: tp.Any = dataclasses.MISSING, 19 | default_factory: tp.Any = dataclasses.MISSING, 20 | init: bool = True, 21 | repr: bool = True, 22 | hash: tp.Optional[bool] = None, 23 | compare: bool = True, 24 | metadata: tp.Optional[tp.Mapping[tp.Any, tp.Any]] = None, 25 | ): 26 | ... 27 | 28 | def __set_name__(self, cls, name): 29 | ... 30 | 31 | def __get__(self, obj, objtype=None) -> A: 32 | ... 33 | 34 | def __set__(self, obj, value: A): 35 | ... 36 | 37 | 38 | # ---------------------------------------- 39 | # fields 40 | # ---------------------------------------- 41 | 42 | 43 | def variable( 44 | collection: str, 45 | default: tp.Any = dataclasses.MISSING, 46 | *, 47 | default_factory: tp.Any = dataclasses.MISSING, 48 | init: bool = True, 49 | repr: bool = True, 50 | hash: tp.Optional[bool] = None, 51 | compare: bool = True, 52 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 53 | ) -> tp.Any: 54 | return VariableField( 55 | collection=collection, 56 | default=default, 57 | default_factory=default_factory, 58 | init=init, 59 | repr=repr, 60 | hash=hash, 61 | compare=compare, 62 | metadata=metadata, 63 | ) 64 | 65 | 66 | def param( 67 | default: tp.Any = dataclasses.MISSING, 68 | *, 69 | default_factory: tp.Any = dataclasses.MISSING, 70 | init: bool = True, 71 | repr: bool = True, 72 | hash: tp.Optional[bool] = None, 73 | compare: bool = True, 74 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 75 | ) -> tp.Any: 76 | return variable( 77 | "params", 78 | default=default, 79 | default_factory=default_factory, 80 | init=init, 81 | repr=repr, 82 | hash=hash, 83 | compare=compare, 84 | metadata=metadata, 85 | ) 86 | 87 | 88 | @tp.overload 89 | def dataclass(cls: tp.Type[A]) -> tp.Type[A]: 90 | ... 91 | 92 | 93 | @tp.overload 94 | def dataclass( 95 | *, 96 | init: bool = True, 97 | repr: bool = True, 98 | eq: bool = True, 99 | order: bool = False, 100 | unsafe_hash: bool = False, 101 | frozen: bool = False, 102 | ) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: 103 | ... 104 | 105 | 106 | @tpe.dataclass_transform(field_specifiers=(variable, param, field, static_field)) 107 | def dataclass( 108 | cls: tp.Optional[tp.Type[A]] = None, 109 | init: bool = True, 110 | repr: bool = True, 111 | eq: bool = True, 112 | order: bool = False, 113 | unsafe_hash: bool = False, 114 | frozen: bool = False, 115 | ) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: 116 | ... 117 | -------------------------------------------------------------------------------- /ideas/pure_pytree/full/partitioning_full.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.tree_util as jtu 5 | 6 | A = tp.TypeVar("A") 7 | CollectionPredicate = tp.Callable[[str], bool] 8 | Leaf = tp.Any 9 | Leaves = tp.List[Leaf] 10 | KeyPath = tp.Tuple[tp.Hashable, ...] 
11 | LeafPredicate = tp.Callable[[tp.Any], bool] 12 | 13 | 14 | class Variable: 15 | __slots__ = ("_value", "_collection") 16 | 17 | def __init__(self, value: tp.Any, collection: str = "params"): 18 | self._value = value 19 | self._collection = collection 20 | 21 | @property 22 | def value(self) -> tp.Any: 23 | return self._value 24 | 25 | @property 26 | def collection(self) -> str: 27 | return self._collection 28 | 29 | @classmethod 30 | def from_value(cls, value: tp.Any) -> "Variable": 31 | return value if isinstance(value, Variable) else Variable(value) 32 | 33 | def copy(self) -> "Variable": 34 | return Variable(self.value, self.collection) 35 | 36 | def update(self, value: tp.Any) -> "Variable": 37 | if isinstance(value, Variable): 38 | if value.collection != self.collection: 39 | raise ValueError( 40 | f"Cannot update variable with value from a different collection. " 41 | f"Expected collection {self.collection}, got {value.collection}" 42 | ) 43 | value = value.value 44 | return Variable(value, self.collection) 45 | 46 | def __repr__(self) -> str: 47 | return f"Variable({self.value}, collection={self.collection})" 48 | 49 | 50 | def _flatten_variable_with_keys(variable: Variable): 51 | node = (jtu.GetAttrKey("value"), variable.value) 52 | return (node,), variable.collection 53 | 54 | 55 | def _flatten_variable(variable: Variable): 56 | return (variable.value,), variable.collection 57 | 58 | 59 | def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): 60 | return Variable(nodes[0], collection) 61 | 62 | 63 | jax.tree_util.register_pytree_with_keys( 64 | Variable, 65 | _flatten_variable_with_keys, 66 | _unflatten_variable, 67 | flatten_func=_flatten_variable, 68 | ) 69 | 70 | 71 | class Nothing: 72 | 73 | def __repr__(self) -> str: 74 | return "Nothing" # pragma: no cover 75 | 76 | 77 | def _nothing_flatten(x): 78 | return (), None 79 | 80 | 81 | def _nothing_unflatten(aux_data, children): 82 | return NOTHING 83 | 84 | 85 | NOTHING = Nothing() 86 | 87 | jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) 88 | 89 | 90 | class StrPath(tp.Tuple[str, ...]): 91 | pass 92 | 93 | 94 | class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): 95 | 96 | def __setitem__(self, key, value): 97 | raise TypeError("Partition is immutable") 98 | 99 | 100 | def _partition_flatten_with_keys( 101 | x: Partition, 102 | ) -> tp.Tuple[ 103 | tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] 
104 | ]: 105 | children = tuple((StrPath(key), value) for key, value in x.items()) 106 | return children, tuple(x.keys()) 107 | 108 | 109 | def _partition_unflatten(keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...]): 110 | return Partition(zip(keys, leaves)) 111 | 112 | 113 | jax.tree_util.register_pytree_with_keys( 114 | Partition, _partition_flatten_with_keys, _partition_unflatten 115 | ) 116 | 117 | 118 | def _key_path_to_str_gen(key_path: KeyPath) -> tp.Generator[str, None, None]: 119 | for key_entry in key_path: 120 | if isinstance(key_entry, StrPath): 121 | yield from key_entry 122 | elif isinstance(key_entry, jtu.SequenceKey): 123 | yield str(key_entry.idx) 124 | elif isinstance(key_entry, jtu.DictKey): # "['a']" 125 | yield str(key_entry.key) 126 | elif isinstance(key_entry, jtu.GetAttrKey): 127 | yield str(key_entry.name) 128 | elif isinstance(key_entry, jtu.FlattenedIndexKey): 129 | yield str(key_entry.key) 130 | elif hasattr(key_entry, "__dict__") and len(key_entry.__dict__) == 1: 131 | yield str(next(iter(key_entry.__dict__.values()))) 132 | else: 133 | yield str(key_entry) 134 | 135 | 136 | def _key_path_to_str_path(key_path: KeyPath) -> StrPath: 137 | return StrPath(_key_path_to_str_gen(key_path)) 138 | 139 | 140 | class StateDef(tp.Generic[A]): 141 | __slots__ = ("treedef",) 142 | 143 | def __init__(self, treedef: jtu.PyTreeDef): 144 | self.treedef = treedef 145 | 146 | def merge(self, *partitions: Partition) -> A: 147 | raise NotImplementedError 148 | 149 | 150 | def statedef_flatten(x: StateDef): 151 | return (), x.treedef 152 | 153 | 154 | def statedef_unflatten(treedef, children): 155 | return StateDef(treedef) 156 | 157 | 158 | jtu.register_pytree_node(StateDef, statedef_flatten, statedef_unflatten) 159 | 160 | 161 | def tree_partition( 162 | pytree: A, 163 | *predicates: CollectionPredicate, 164 | is_leaf: tp.Optional[LeafPredicate] = None, 165 | ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef[A]]: 166 | paths_leaves: tp.List[tp.Tuple[KeyPath, Leaf]] 167 | paths_leaves, treedef = jax.tree_util.tree_flatten_with_path( 168 | pytree, 169 | is_leaf=lambda x: (isinstance(x, Variable) or x is NOTHING) 170 | or (False if is_leaf is None else is_leaf(x)), 171 | ) 172 | 173 | leaves: tp.Tuple[Leaf, ...] 174 | paths, leaves = zip(*paths_leaves) 175 | paths = tuple(map(_key_path_to_str_path, paths)) 176 | 177 | # we have n + 1 partitions, where n is the number of predicates 178 | # the last partition is for values that don't match any predicate 179 | partition_leaves: tp.Tuple[Leaves, ...] 
= tuple( 180 | [NOTHING] * len(leaves) for _ in range(len(predicates) + 1) 181 | ) 182 | for j, leaf in enumerate(leaves): 183 | for i, predicate in enumerate(predicates): 184 | if isinstance(leaf, Variable) and predicate(leaf.collection): 185 | partition_leaves[i][j] = leaf 186 | break 187 | else: 188 | # if we didn't break, set leaf to last partition 189 | partition_leaves[-1][j] = leaf 190 | 191 | partitions = tuple(Partition(zip(paths, partition)) for partition in partition_leaves) 192 | return partitions, StateDef(treedef) 193 | 194 | 195 | def get_partition( 196 | pytree, 197 | predicate: CollectionPredicate, 198 | is_leaf: tp.Optional[LeafPredicate] = None, 199 | ) -> Partition: 200 | (partition, _rest), _treedef = tree_partition(pytree, predicate, is_leaf=is_leaf) 201 | return partition 202 | 203 | 204 | def _get_non_nothing( 205 | paths: tp.Tuple[StrPath, ...], 206 | leaves: tp.Tuple[tp.Union[Leaf, Nothing], ...], 207 | position: int, 208 | ): 209 | # check that all paths are the same 210 | paths_set = set(paths) 211 | if len(paths_set) != 1: 212 | raise ValueError( 213 | "All partitions must have the same paths, " 214 | f" at position [{position}] got " 215 | "".join(f"\n- {path}" for path in paths_set) 216 | ) 217 | non_null = [option for option in leaves if option is not NOTHING] 218 | if len(non_null) == 0: 219 | raise ValueError(f"Expected at least one non-null value for position [{position}]") 220 | elif len(non_null) > 1: 221 | raise ValueError(f"Expected at most one non-null value for position [{position}]") 222 | return non_null[0] 223 | 224 | 225 | def merge_partitions( 226 | partitions: tp.Sequence[Partition], treedef: jax.tree_util.PyTreeDef 227 | ): 228 | lenghts = [len(partition) for partition in partitions] 229 | if not all(length == lenghts[0] for length in lenghts): 230 | raise ValueError( 231 | "All partitions must have the same length, got " 232 | f"{', '.join(str(length) for length in lenghts)}" 233 | ) 234 | 235 | partition_paths = (list(partition.keys()) for partition in partitions) 236 | partition_leaves = (list(partition.values()) for partition in partitions) 237 | 238 | merged_leaves = [ 239 | _get_non_nothing(paths, leaves, i) 240 | for i, (paths, leaves) in enumerate( 241 | zip(zip(*partition_paths), zip(*partition_leaves)) 242 | ) 243 | ] 244 | 245 | return jax.tree_util.tree_unflatten(treedef, merged_leaves) 246 | -------------------------------------------------------------------------------- /ideas/pure_pytree/full/state_full.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from types import MappingProxyType 3 | 4 | import jax 5 | import jax.tree_util as jtu 6 | from pure.partitioning import Partition, StateDef, Variable 7 | 8 | Node = tp.Union[Variable, "State"] 9 | S = tp.TypeVar("S", bound="State") 10 | 11 | 12 | class State(tp.Mapping[str, Node]): 13 | __slots__ = ("_variables",) 14 | 15 | def __init__(self, *args, **kwargs: tp.Union[Node, jax.Array]): 16 | self._variables = { 17 | k: self._create_node_field(v) for k, v in dict(*args, **kwargs).items() 18 | } 19 | 20 | @staticmethod 21 | def _create_node_field(value: tp.Any) -> Node: 22 | if isinstance(value, State): 23 | return value 24 | else: 25 | return Variable.from_value(value) 26 | 27 | @staticmethod 28 | def _update_node_field(node: Node, value: tp.Any) -> Node: 29 | if isinstance(node, State) and isinstance(value, State): 30 | return value 31 | elif isinstance(node, Variable) and isinstance(value, Variable): 32 | return 
node.update(value) 33 | else: 34 | raise ValueError( 35 | f"Cannot update node of type {type(node).__name__} with " 36 | f"value of type {type(value).__name__}" 37 | ) 38 | 39 | def __getitem__(self, name: str) -> tp.Any: 40 | return self._variables[name].value 41 | 42 | __getattr__ = __getitem__ 43 | 44 | def __iter__(self) -> tp.Iterator[str]: 45 | return iter(self._variables) 46 | 47 | def __len__(self) -> int: 48 | return len(self._variables) 49 | 50 | def keys(self) -> tp.KeysView[str]: 51 | return self._variables.keys() 52 | 53 | def values(self) -> tp.ValuesView[Node]: 54 | return self._variables.values() 55 | 56 | def __repr__(self) -> str: 57 | return f"State({self._variables})" 58 | 59 | def update(self, *args, **kwargs: tp.Union[Node, tp.Any]) -> "State": 60 | raise NotImplementedError 61 | 62 | @tp.overload 63 | def partition(self) -> tp.Tuple[tp.Dict[str, Partition], StateDef["State"]]: 64 | ... 65 | 66 | @tp.overload 67 | def partition(self, collection: str) -> tp.Tuple[Partition, StateDef["State"]]: 68 | ... 69 | 70 | @tp.overload 71 | def partition( 72 | self, collection: str, *collections: str 73 | ) -> tp.Tuple[tp.Tuple[Partition, ...], StateDef["State"]]: 74 | ... 75 | 76 | def partition( 77 | self, *collections: str 78 | ) -> tp.Tuple[ 79 | tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], 80 | StateDef["State"], 81 | ]: 82 | raise NotImplementedError 83 | 84 | @tp.overload 85 | def get_partition(self, collection: str) -> Partition: 86 | ... 87 | 88 | @tp.overload 89 | def get_partition( 90 | self, collection: str, *collections: str 91 | ) -> tp.Tuple[Partition, ...]: 92 | ... 93 | 94 | def get_partition( 95 | self, *collections: str 96 | ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: 97 | raise NotImplementedError 98 | 99 | def update_partition(self, partition: Partition, *partitions: Partition) -> "State": 100 | raise NotImplementedError 101 | 102 | @tp.overload 103 | def pop(self, name: str) -> Node: 104 | ... 105 | 106 | @tp.overload 107 | def pop(self, name: str, *names: str) -> tp.Tuple[Node, ...]: 108 | ... 109 | 110 | def pop(self, *names: str) -> tp.Union[Node, tp.Tuple[Node, ...]]: 111 | if len(names) == 0: 112 | raise ValueError("pop expected at least 1 argument, got 0") 113 | elif len(names) == 1: 114 | name = names[0] 115 | return self._variables.pop(name) 116 | else: 117 | return tuple(self._variables.pop(name) for name in names) 118 | 119 | 120 | def _state_flatten_with_keys(state: State): 121 | nodes = tuple((jtu.GetAttrKey(name), variable) for name, variable in state.items()) 122 | names = tuple(state) 123 | return nodes, names 124 | 125 | 126 | def _state_unflatten(names: tp.Tuple[str, ...], nodes: tp.Tuple[Variable, ...]): 127 | return State(zip(names, nodes)) 128 | 129 | 130 | def _state_flatten(state: State): 131 | return tuple(state.values()), tuple(state) 132 | 133 | 134 | jtu.register_pytree_with_keys( 135 | State, _state_flatten_with_keys, _state_unflatten, flatten_func=_state_flatten 136 | ) 137 | -------------------------------------------------------------------------------- /ideas/pure_pytree/module.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | from pure_pytree.partitioning import Partition, PartitionDef, Variable 4 | 5 | A = tp.TypeVar("A", contravariant=True) 6 | M = tp.TypeVar("M", bound="Module") 7 | 8 | 9 | class Pytree: 10 | ... 11 | 12 | 13 | class Module(Pytree): 14 | 15 | def replace(self: M, **kwargs: tp.Any) -> M: 16 | ... 
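# `replace` returns an updated copy (modules here are immutable pytrees); e.g.
# BatchNorm in pure_pytree_example.py does `self.replace(mean=..., var=...)`
# to produce the next set of running statistics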
17 | 18 | @tp.overload 19 | def partition(self: M) -> tp.Tuple[tp.Dict[str, Partition], PartitionDef[M]]: 20 | ... 21 | 22 | @tp.overload 23 | def partition(self: M, collection: str) -> tp.Tuple[Partition, PartitionDef[M]]: 24 | ... 25 | 26 | @tp.overload 27 | def partition( 28 | self: M, collection: str, *collections: str 29 | ) -> tp.Tuple[tp.Tuple[Partition, ...], PartitionDef[M]]: 30 | ... 31 | 32 | def partition( 33 | self: M, *collections: str 34 | ) -> tp.Tuple[ 35 | tp.Union[tp.Dict[str, Partition], tp.Tuple[Partition, ...], Partition], 36 | PartitionDef[M], 37 | ]: 38 | ... 39 | 40 | @tp.overload 41 | def get_partition(self, collection: str) -> Partition: 42 | ... 43 | 44 | @tp.overload 45 | def get_partition( 46 | self, collection: str, *collections: str 47 | ) -> tp.Tuple[Partition, ...]: 48 | ... 49 | 50 | def get_partition( 51 | self, *collections: str 52 | ) -> tp.Union[Partition, tp.Tuple[Partition, ...]]: 53 | ... 54 | 55 | def merge(self: M, partition: Partition, *partitions: Partition) -> M: 56 | ... 57 | -------------------------------------------------------------------------------- /ideas/pure_pytree/partitioning.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.tree_util as jtu 5 | 6 | A = tp.TypeVar("A") 7 | CollectionPredicate = tp.Callable[[str], bool] 8 | Leaf = tp.Any 9 | Leaves = tp.List[Leaf] 10 | KeyPath = tp.Tuple[tp.Hashable, ...] 11 | LeafPredicate = tp.Callable[[tp.Any], bool] 12 | 13 | 14 | class Variable: 15 | __slots__ = ("_value", "_collection") 16 | 17 | def __init__(self, value: tp.Any, collection: str = "params"): 18 | ... 19 | 20 | @property 21 | def value(self) -> tp.Any: 22 | ... 23 | 24 | @property 25 | def collection(self) -> str: 26 | ... 27 | 28 | @classmethod 29 | def from_value(cls, value: tp.Any) -> "Variable": 30 | ... 31 | 32 | def copy(self) -> "Variable": 33 | ... 34 | 35 | def update(self, value: tp.Any) -> "Variable": 36 | ... 37 | 38 | def __repr__(self) -> str: 39 | return f"Variable({self.value}, collection={self.collection})" 40 | 41 | 42 | def _flatten_variable_with_keys(variable: Variable): 43 | ... 44 | 45 | 46 | def _flatten_variable(variable: Variable): 47 | ... 48 | 49 | 50 | def _unflatten_variable(collection: str, nodes: tp.Tuple[tp.Any]): 51 | ... 52 | 53 | 54 | jax.tree_util.register_pytree_with_keys( 55 | Variable, 56 | _flatten_variable_with_keys, 57 | _unflatten_variable, 58 | flatten_func=_flatten_variable, 59 | ) 60 | 61 | 62 | class Nothing: 63 | 64 | def __repr__(self) -> str: 65 | ... 66 | 67 | 68 | def _nothing_flatten(x): 69 | ... 70 | 71 | 72 | def _nothing_unflatten(aux_data, children): 73 | ... 74 | 75 | 76 | NOTHING = Nothing() 77 | 78 | jtu.register_pytree_node(Nothing, _nothing_flatten, _nothing_unflatten) 79 | 80 | 81 | class StrPath(tp.Tuple[str, ...]): 82 | pass 83 | 84 | 85 | class Partition(tp.Dict[tp.Tuple[str, ...], Leaf]): 86 | 87 | def __setitem__(self, key, value): 88 | raise TypeError("Partition is immutable") 89 | 90 | 91 | def _partition_flatten_with_keys( 92 | x: Partition, 93 | ) -> tp.Tuple[ 94 | tp.Tuple[tp.Tuple[StrPath, Leaf], ...], tp.Tuple[tp.Tuple[str, ...], ...] 95 | ]: 96 | ... 97 | 98 | 99 | def _partition_unflatten(keys: tp.Tuple[StrPath, ...], leaves: tp.Tuple[Leaf, ...]): 100 | ... 
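# registering Partition (like Variable above) as a pytree lets partitions pass
# directly through jax transforms such as jax.vmap and jax.lax.scan, which is
# how pure_pytree_example.py carries params and batch_stats between steps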
101 | 102 | 103 | jax.tree_util.register_pytree_with_keys( 104 | Partition, _partition_flatten_with_keys, _partition_unflatten 105 | ) 106 | 107 | 108 | class PartitionDef(tp.Generic[A]): 109 | __slots__ = ("treedef",) 110 | 111 | def __init__(self, treedef: jtu.PyTreeDef): 112 | ... 113 | 114 | @property 115 | def treedef(self) -> jtu.PyTreeDef: 116 | ... 117 | 118 | def merge(self, *partitions: Partition) -> A: 119 | ... 120 | 121 | 122 | def partitiondef_flatten(x: PartitionDef): 123 | ... 124 | 125 | 126 | def statedef_unflatten(treedef, children): 127 | ... 128 | 129 | 130 | jtu.register_pytree_node(PartitionDef, partitiondef_flatten, statedef_unflatten) 131 | 132 | 133 | def tree_partition( 134 | pytree: A, 135 | *predicates: CollectionPredicate, 136 | is_leaf: tp.Optional[LeafPredicate] = None, 137 | ) -> tp.Tuple[tp.Tuple[Partition, ...], PartitionDef[A]]: 138 | ... 139 | 140 | 141 | def get_partition( 142 | pytree, 143 | predicate: CollectionPredicate, 144 | is_leaf: tp.Optional[LeafPredicate] = None, 145 | ) -> Partition: 146 | ... 147 | 148 | 149 | def merge_partitions( 150 | partitions: tp.Sequence[Partition], partitiondef: PartitionDef[A] 151 | ) -> A: 152 | ... 153 | -------------------------------------------------------------------------------- /ideas/pure_pytree/rngs.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | 5 | KeyArray = tp.Union[jax.Array, jax.random.KeyArray] 6 | 7 | 8 | class RngStream: 9 | 10 | def __init__( 11 | self, key: KeyArray, count: int = 0, count_path: tp.Tuple[int, ...] = () 12 | ): 13 | ... 14 | 15 | @property 16 | def key(self) -> jax.random.KeyArray: 17 | ... 18 | 19 | @property 20 | def count(self) -> int: 21 | ... 22 | 23 | @property 24 | def count_path(self) -> tp.Tuple[int, ...]: 25 | ... 26 | 27 | def next(self) -> jax.random.KeyArray: 28 | ... 29 | 30 | def fork(self) -> "RngStream": 31 | ... 32 | 33 | 34 | class Rngs: 35 | 36 | def __init__(self, **streams: tp.Union[KeyArray, RngStream]): 37 | ... 38 | 39 | def make_rng(self, stream: str) -> jax.Array: 40 | ... 41 | 42 | @tp.overload 43 | def fork(self, stream: str) -> RngStream: 44 | ... 45 | 46 | @tp.overload 47 | def fork(self, stream: str, *streams: str) -> tp.Tuple[RngStream, ...]: 48 | ... 49 | 50 | def fork(self, *streams: str) -> tp.Union[RngStream, tp.Tuple[RngStream, ...]]: 51 | ... 
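# hypothetical usage of the interface above (mirrors pure_pytree_example.py):
#
#   rngs = Rngs(params=jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1))
#   key = rngs.make_rng("dropout")   # draw the next key from the "dropout" stream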
52 | -------------------------------------------------------------------------------- /ideas/pure_pytree_example.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Tuple 4 | 5 | import jax 6 | import pure_pytree as pure 7 | 8 | 9 | class Linear(pure.Module): 10 | kernel: jax.Array = pure.param() 11 | bias: jax.Array = pure.param() 12 | 13 | def __init__(self, din: int, dout: int, *, rngs: pure.Rngs): 14 | self.kernel = jax.random.uniform(rngs.make_rng("params"), (din, dout)) 15 | self.bias = jax.numpy.zeros((dout,)) 16 | 17 | def __call__(self, x): 18 | return x @ self.kernel + self.bias 19 | 20 | 21 | class BatchNorm(pure.Module): 22 | scale: jax.Array = pure.param() 23 | bias: jax.Array = pure.param() 24 | mean: jax.Array = pure.variable("batch_stats") 25 | var: jax.Array = pure.variable("batch_stats") 26 | mu: float = pure.static_field() 27 | 28 | def __init__(self, din: int, mu: float = 0.95, *, rngs: pure.Rngs): 29 | self.scale = jax.random.uniform(rngs.make_rng("params"), (din,)) 30 | self.bias = jax.numpy.zeros((din,)) 31 | self.mean = jax.numpy.zeros((din,)) 32 | self.var = jax.numpy.ones((din,)) 33 | self.mu = mu 34 | 35 | def __call__(self, x, use_running_averages: bool) -> Tuple[jax.Array, "BatchNorm"]: 36 | scale, bias = self.scale, self.bias 37 | if use_running_averages: 38 | mean, var = self.mean, self.var 39 | else: 40 | axis = tuple(range(0, x.ndim - 1)) 41 | mean = jax.numpy.mean(x, axis=axis) 42 | var = jax.numpy.var(x, axis=axis) 43 | # ema update 44 | self = self.replace( 45 | mean=self.mu * self.mean + (1 - self.mu) * mean, 46 | var=self.mu * self.var + (1 - self.mu) * var, 47 | ) 48 | 49 | x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias 50 | 51 | return x, self 52 | 53 | 54 | class Dropout(pure.Module): 55 | 56 | def __init__(self, rate: float): 57 | raise NotImplementedError 58 | 59 | def __call__(self, x, *, deterministic: bool, rngs: pure.Rngs) -> jax.Array: 60 | key = rngs.make_rng("dropout") 61 | ... 
62 | raise NotImplementedError 63 | 64 | 65 | class MLP(pure.Module): 66 | 67 | def __init__(self, din: int, dmid: int, dout: int, *, rngs: pure.Rngs): 68 | self.linear1 = Linear(din, dmid, rngs=rngs) 69 | self.bn1 = BatchNorm(dmid, rngs=rngs) 70 | self.dropout = Dropout(0.5) 71 | self.linear2 = Linear(dmid, dout, rngs=rngs) 72 | 73 | def __call__( 74 | self, x: jax.Array, *, train: bool, rngs: pure.Rngs 75 | ) -> Tuple[jax.Array, "MLP"]: 76 | x = self.linear1(x) 77 | x, bn1 = self.bn1(x, use_running_averages=not train) 78 | x = self.dropout(x, deterministic=not train, rngs=rngs) 79 | x = jax.nn.relu(x) 80 | x = self.linear2(x) 81 | return x, self.replace(bn1=bn1) 82 | 83 | 84 | rngs = pure.Rngs(params=jax.random.PRNGKey(0)) 85 | model = MLP(10, 20, 30, rngs=rngs) 86 | 87 | 88 | @jax.jit 89 | def train_step(model: MLP, key, batch): 90 | x, y = batch 91 | params = model.get_partition("params") 92 | rngs = pure.Rngs(dropout=key) 93 | 94 | def loss(params): 95 | _model = model.merge(params) 96 | y_pred, _model = _model(x, train=True, rngs=rngs)  # call the merged module so grads flow to params 97 | loss = jax.numpy.mean((y_pred - y) ** 2) 98 | return loss, _model 99 | 100 | grads, model = jax.grad(loss, has_aux=True)(params) 101 | params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) 102 | model = model.merge(params) 103 | 104 | return model 105 | 106 | 107 | # ---------------------------------------- 108 | # scan over layers + shared batchnorm 109 | # ---------------------------------------- 110 | 111 | n_layers = 10 112 | params_keys = jax.random.PRNGKey(0) 113 | params_keys = jax.random.split(params_keys, n_layers) 114 | 115 | 116 | @partial(jax.vmap, in_axes=0, out_axes=(0, None, None)) 117 | def create_state(params_key: jax.random.KeyArray): 118 | rngs = pure.Rngs(params=params_key) 119 | model = MLP(10, 20, 10, rngs=rngs) 120 | (params, batch_stats), modeldef = model.partition("params", "batch_stats") 121 | return params, batch_stats, modeldef 122 | 123 | 124 | params, batch_stats, modeldef = create_state(params_keys) 125 | x = jax.numpy.zeros((32, 10)) 126 | dropout_key = jax.random.PRNGKey(1) 127 | dropout_stream = pure.RngStream(jax.random.split(dropout_key, n_layers)) 128 | 129 | 130 | def scan_fn( 131 | carry: Tuple[jax.Array, pure.Partition], 132 | inputs: Tuple[pure.Partition, pure.RngStream], 133 | ): 134 | # extract args 135 | x, batch_stats = carry 136 | params, dropout_stream = inputs 137 | 138 | # create state and rngs 139 | model = pure.merge([params, batch_stats], modeldef) 140 | rngs = pure.Rngs(dropout=dropout_stream) 141 | 142 | # forward pass 143 | x, model = model(x, train=True, rngs=rngs) 144 | 145 | # partition state 146 | params, batch_stats = model.get_partition("params", "batch_stats") 147 | 148 | return (x, batch_stats), params 149 | 150 | 151 | (y, batch_stats), params = jax.lax.scan( 152 | scan_fn, (x, batch_stats), (params, dropout_stream) 153 | ) 154 | model = pure.merge([params, batch_stats], modeldef) 155 | -------------------------------------------------------------------------------- /ideas/shape_inference.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import random 6 | 7 | import nnx 8 | 9 | 10 | class Linear(nnx.Module): 11 | 12 | @tp.overload 13 | def __init__(self, *, din: int, dout: int, ctx: nnx.Context): 14 | ... 15 | 16 | @tp.overload 17 | def __init__(self, *, dout: int): 18 | ...
19 | 20 | @tp.overload 21 | def __init__( 22 | self, 23 | *, 24 | din: tp.Optional[int] = None, 25 | dout: int, 26 | ctx: tp.Optional[nnx.Context] = None, 27 | ): 28 | ... 29 | 30 | def __init__( 31 | self, 32 | *, 33 | din: tp.Optional[int] = None, 34 | dout: int, 35 | ctx: tp.Optional[nnx.Context] = None, 36 | ): 37 | self.dout = dout 38 | if din is not None: 39 | if ctx is None: 40 | raise ValueError("ctx must be provided if din is provided") 41 | self.init_variables(din, ctx) 42 | 43 | def init_variables(self, din: int, ctx: nnx.Context): 44 | key = ctx.make_rng("params") 45 | self.w = nnx.Param(random.uniform(key, (din, self.dout))) 46 | self.b = nnx.Param(jnp.zeros((self.dout,))) 47 | 48 | def __call__( 49 | self, x: jax.Array, *, ctx: tp.Optional[nnx.Context] = None 50 | ) -> jax.Array: 51 | if self.is_initializing and not hasattr(self, "w"): 52 | if ctx is None: 53 | raise ValueError("ctx must be provided to initialize module") 54 | self.init_variables(x.shape[-1], ctx) 55 | 56 | return x @ self.w + self.b 57 | 58 | 59 | class BatchNorm(nnx.Module): 60 | 61 | @tp.overload 62 | def __init__(self, *, mu: float = 0.95): 63 | ... 64 | 65 | @tp.overload 66 | def __init__(self, *, din: int, mu: float = 0.95, ctx: nnx.Context): 67 | ... 68 | 69 | @tp.overload 70 | def __init__( 71 | self, 72 | *, 73 | din: tp.Optional[int] = None, 74 | mu: float = 0.95, 75 | ctx: tp.Optional[nnx.Context] = None, 76 | ): 77 | ... 78 | 79 | def __init__( 80 | self, 81 | *, 82 | din: tp.Optional[int] = None, 83 | mu: float = 0.95, 84 | ctx: tp.Optional[nnx.Context] = None, 85 | ): 86 | self.mu = mu 87 | 88 | if din is not None: 89 | if ctx is None: 90 | raise ValueError("ctx must be provided if din is provided") 91 | self.init_variables(din, ctx) 92 | 93 | def init_variables(self, din: int, ctx: nnx.Context): 94 | self.scale = nnx.Param(jax.numpy.ones((din,))) 95 | self.bias = nnx.Param(jax.numpy.zeros((din,))) 96 | self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) 97 | self.var = nnx.BatchStat(jax.numpy.ones((din,))) 98 | 99 | def __call__( 100 | self, x, *, train: bool, ctx: tp.Optional[nnx.Context] = None 101 | ) -> jax.Array: 102 | if self.is_initializing and not hasattr(self, "scale"): 103 | if ctx is None: 104 | raise ValueError("ctx must be provided to initialize module") 105 | self.init_variables(x.shape[-1], ctx) 106 | 107 | if train: 108 | axis = tuple(range(x.ndim - 1)) 109 | mean = jax.numpy.mean(x, axis=axis) 110 | var = jax.numpy.var(x, axis=axis) 111 | # ema update 112 | self.mean = self.mu * self.mean + (1 - self.mu) * mean 113 | self.var = self.mu * self.var + (1 - self.mu) * var 114 | else: 115 | mean, var = self.mean, self.var 116 | 117 | scale, bias = self.scale, self.bias 118 | x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias 119 | return x 120 | 121 | 122 | class Dropout(nnx.Module): 123 | 124 | def __init__(self, rate: float): 125 | self.rate = rate 126 | 127 | def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array: 128 | if train: 129 | mask = random.bernoulli(ctx.make_rng("dropout"), (1 - self.rate), x.shape) 130 | x = x * mask / (1 - self.rate) 131 | return x 132 | 133 | 134 | # ---------------------------- 135 | # test Linear 136 | # ---------------------------- 137 | print("test Linear") 138 | 139 | # eager 140 | m1 = Linear(din=32, dout=10, ctx=nnx.context(params=0)) 141 | y = m1(x=jnp.ones((1, 32))) 142 | print(jax.tree_map(jnp.shape, m1.get_state())) 143 | 144 | # lazy 145 | m2 = Linear(dout=10) 146 | y = m2.init(x=jnp.ones((1, 32)), 
ctx=nnx.context(params=0)) 147 | print(jax.tree_map(jnp.shape, m2.get_state())) 148 | 149 | # usage 150 | y1 = m1(x=jnp.ones((1, 32))) 151 | y2 = m2(x=jnp.ones((1, 32))) 152 | 153 | # ---------------------------- 154 | # Test scan 155 | # ---------------------------- 156 | print("\ntest scan") 157 | 158 | 159 | class Block(nnx.Module): 160 | 161 | def __init__( 162 | self, 163 | din: tp.Optional[int] = None, 164 | dout: int = 10, 165 | ctx: tp.Optional[nnx.Context] = None, 166 | ): 167 | self.linear = Linear(din=din, dout=dout, ctx=ctx) 168 | self.bn = BatchNorm(din=dout if din is not None else None, ctx=ctx) 169 | self.dropout = Dropout(0.5) 170 | 171 | def __call__(self, x: jax.Array, _, *, train: bool, ctx: nnx.Context): 172 | x = self.linear(x, ctx=ctx) 173 | x = self.bn(x, train=train, ctx=ctx) 174 | x = self.dropout(x, train=train, ctx=ctx) 175 | x = jax.nn.gelu(x) 176 | return x, None 177 | 178 | 179 | MLP = nnx.Scan( 180 | Block, 181 | variable_axes={nnx.Param: 0}, 182 | variable_carry=nnx.BatchStat, 183 | split_rngs={"params": True, "dropout": True}, 184 | length=5, 185 | ) 186 | 187 | 188 | # eager 189 | mlp = MLP(din=10, dout=10, ctx=nnx.context(params=0)) 190 | y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, ctx=nnx.context(dropout=1)) 191 | print(f"{y.shape=}") 192 | print("state =", jax.tree_map(jnp.shape, mlp.get_state())) 193 | print() 194 | 195 | # lazy 196 | mlp = MLP(dout=10) 197 | mlp.init(jnp.ones((1, 10)), None, train=False, ctx=nnx.context(params=0)) 198 | y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, ctx=nnx.context(dropout=1)) 199 | print(f"{y.shape=}") 200 | print("state =", jax.tree_map(jnp.shape, mlp.get_state())) 201 | -------------------------------------------------------------------------------- /nnx/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.0" 2 | 3 | 4 | from .containers import ( 5 | BatchStat, 6 | Cache, 7 | Container, 8 | ContainerMetadata, 9 | Intermediate, 10 | Node, 11 | Param, 12 | Static, 13 | Variable, 14 | with_metadata, 15 | ) 16 | from .contextlib import Context, context 17 | from .dataclasses import ( 18 | dataclass, 19 | field, 20 | node_field, 21 | param_field, 22 | static_field, 23 | var_field, 24 | ) 25 | from .errors import TraceContextError 26 | from .helpers import Dict, Sequence, TrainState 27 | from .module import Module, ModuleDef, Pure, PureModule 28 | from .nn import initializers 29 | from .nn.activations import ( 30 | celu, 31 | elu, 32 | gelu, 33 | glu, 34 | hard_sigmoid, 35 | hard_silu, 36 | hard_swish, 37 | hard_tanh, 38 | leaky_relu, 39 | log_sigmoid, 40 | log_softmax, 41 | logsumexp, 42 | normalize, 43 | one_hot, 44 | relu, 45 | relu6, 46 | selu, 47 | sigmoid, 48 | silu, 49 | soft_sign, 50 | softmax, 51 | softplus, 52 | standardize, 53 | swish, 54 | tanh, 55 | ) 56 | from .nn.linear import Conv, Embed, Linear 57 | from .nn.normalization import BatchNorm, LayerNorm 58 | from .nn.stochastic import Dropout 59 | from .nodes import is_node, register_node_type 60 | from .partitioning import All, Not, buffers 61 | from .pytreelib import Pytree, TreeNode 62 | from .spmd import ( 63 | PARTITION_NAME, 64 | get_partition_spec, 65 | logical_axis_rules, 66 | logical_to_mesh, 67 | with_logical_constraint, 68 | with_logical_partitioning, 69 | ) 70 | from .state import State 71 | from .transforms import Remat, Scan, grad, jit, remat, scan 72 | -------------------------------------------------------------------------------- /nnx/containers.py: 
-------------------------------------------------------------------------------- 1 | import dataclasses 2 | import functools 3 | import typing as tp 4 | from abc import ABCMeta 5 | from functools import partial 6 | from typing import Any 7 | 8 | import jax.tree_util as jtu 9 | 10 | from nnx import nodes, reprlib 11 | 12 | A = tp.TypeVar("A") 13 | B = tp.TypeVar("B") 14 | F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) 15 | Sharding = tp.Tuple[tp.Optional[str], ...] 16 | 17 | 18 | @dataclasses.dataclass 19 | class ContainerMetadata(tp.Generic[A]): 20 | value: A 21 | metadata: tp.Mapping[str, tp.Any] 22 | 23 | 24 | class ContainerMetaclass(ABCMeta): 25 | 26 | def __call__(self, value: A, **metadata: tp.Any) -> A: 27 | if isinstance(value, Container): 28 | container = value 29 | value = container.value 30 | else: 31 | container = None 32 | 33 | obj = super().__call__(value, **metadata) 34 | 35 | if container is not None and not container.is_equivalent(obj): 36 | raise ValueError( 37 | f"input value of type '{type(container).__name__}' is not compatible " 38 | f"with return type '{type(obj).__name__}'" 39 | ) 40 | 41 | return obj 42 | 43 | 44 | class Container(tp.Generic[A], reprlib.Representable, metaclass=ContainerMetaclass): 45 | value: A 46 | 47 | def __init__(self, value: tp.Union[A, ContainerMetadata[A]], **metadata: tp.Any): 48 | if isinstance(value, ContainerMetadata): 49 | metadata.update(value.metadata) 50 | value = tp.cast(A, value.value) 51 | 52 | vars(self).update(metadata, value=value) 53 | 54 | if tp.TYPE_CHECKING: 55 | 56 | def __getattr__(self, name: str) -> tp.Any: 57 | ... 58 | 59 | def __eq__(self, other: object) -> bool: 60 | if not isinstance(other, Container): 61 | return False 62 | return type(self) is type(other) and vars(other) == vars(self) 63 | 64 | @tp.overload 65 | def replace(self, *, value: B, **kwargs) -> "Container[B]": 66 | ... 67 | 68 | @tp.overload 69 | def replace(self, **kwargs) -> "Container[A]": 70 | ... 
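# Expected behavior sketch (an assumption for illustration; the actual
# implementation follows below): `replace` keeps a container's metadata and
# only swaps the requested fields, and replacing `value` with a
# non-equivalent container type is rejected.
#
#   p = Param(jnp.zeros((3,)), sharding=("data",))
#   p2 = p.replace(value=jnp.ones((3,)))  # same metadata, new value
#   p.replace(value=BatchStat(1.0))       # ValueError: incompatible container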
71 | 72 | def replace(self, **kwargs) -> "Container[tp.Any]": 73 | if "value" in kwargs: 74 | value = kwargs["value"] 75 | if isinstance(value, Container): 76 | if not self.is_equivalent(value): 77 | raise ValueError( 78 | "Cannot replace value from incompatible container, " 79 | f"expected {self}, got {value}" 80 | ) 81 | kwargs["value"] = value.value 82 | 83 | attributes = vars(self).copy() 84 | # validate keys 85 | for key in kwargs: 86 | if key not in attributes: 87 | raise ValueError(f"Unknown metadata key {key!r}") 88 | attributes.update(**kwargs) 89 | node_type = type(self) 90 | return node_type(**attributes) 91 | 92 | def is_equivalent(self, other: tp.Any) -> bool: 93 | def metadata_fields(container: Container[tp.Any]) -> tp.Dict[str, tp.Any]: 94 | return {k: v for k, v in vars(container).items() if k != "value"} 95 | 96 | return type(self) is type(other) and metadata_fields(self) == metadata_fields(other) 97 | 98 | def copy(self: "Container[A]") -> "Container[A]": 99 | return type(self)(**vars(self)) 100 | 101 | def __nnx_repr__(self): 102 | yield reprlib.Object(type=type(self)) 103 | for name, value in vars(self).items(): 104 | yield reprlib.Attr(name, repr(value)) 105 | 106 | 107 | class NodeBase(Container[A]): 108 | 109 | def __init_subclass__(cls): 110 | super().__init_subclass__() 111 | 112 | def _node_flatten( 113 | x: NodeBase[tp.Any], 114 | *, 115 | with_keys: bool, 116 | ): 117 | attributes = vars(x).copy() 118 | value = attributes.pop("value") 119 | if with_keys: 120 | node = (jtu.GetAttrKey("value"), value) 121 | else: 122 | node = value 123 | 124 | return (node,), attributes 125 | 126 | def _node_unflatten( 127 | metadata: tp.Mapping[str, tp.Any], children: tp.Tuple[A] 128 | ) -> NodeBase[A]: 129 | return cls(children[0], **metadata) 130 | 131 | jtu.register_pytree_with_keys( 132 | cls, 133 | partial(_node_flatten, with_keys=True), 134 | _node_unflatten, 135 | flatten_func=partial(_node_flatten, with_keys=False), 136 | ) 137 | 138 | 139 | class Node(NodeBase[A]): 140 | pass 141 | 142 | 143 | class Variable(Node[A]): 144 | sharding: tp.Optional[Sharding] 145 | 146 | def __init__( 147 | self, 148 | value: tp.Union[A, ContainerMetadata[A]], 149 | sharding: tp.Optional[Sharding] = None, 150 | **metadata: Any, 151 | ): 152 | super().__init__(value, sharding=sharding, **metadata) 153 | 154 | 155 | class Param(Variable[A]): 156 | pass 157 | 158 | 159 | class BatchStat(Variable[A]): 160 | pass 161 | 162 | 163 | class Cache(Variable[A]): 164 | pass 165 | 166 | 167 | class Intermediate(Variable[A]): 168 | pass 169 | 170 | 171 | class Static(Container[A], reprlib.Representable): 172 | 173 | def __init__(self, value: A): 174 | super().__init__(value) 175 | 176 | def __hash__(self) -> int: 177 | return hash(self.value) 178 | 179 | 180 | def _static_flatten(x: Static[tp.Any]): 181 | return (), x.value 182 | 183 | 184 | def _static_unflatten(metadata: A, _) -> Static[A]: 185 | return Static(metadata) 186 | 187 | 188 | jtu.register_pytree_node(Static, _static_flatten, _static_unflatten) 189 | 190 | 191 | def with_metadata( 192 | initializer: F, 193 | **metadata: tp.Any, 194 | ) -> F: 195 | @functools.wraps(initializer) 196 | def wrapper(*args): 197 | return ContainerMetadata(initializer(*args), metadata=metadata) 198 | 199 | return wrapper # type: ignore 200 | 201 | 202 | # register nodes 203 | nodes.register_node_type(Node) 204 | -------------------------------------------------------------------------------- /nnx/contextlib.py: 
-------------------------------------------------------------------------------- 1 | import builtins 2 | import dataclasses 3 | import hashlib 4 | import typing as tp 5 | 6 | import jax 7 | import jax.tree_util as jtu 8 | 9 | from nnx import errors, tracers 10 | 11 | KeyArray = jax.Array 12 | Counts = tp.Tuple[int, ...] 13 | 14 | 15 | def _stable_hash(data: Counts) -> int: 16 | hash_str = " ".join(str(x) for x in data) 17 | _hash = hashlib.blake2s(hash_str.encode()) 18 | hash_bytes = _hash.digest() 19 | # uint32 is represented as 4 bytes in big endian 20 | return int.from_bytes(hash_bytes[:4], byteorder="big") 21 | 22 | 23 | class ContextDef: 24 | __slots__ = ("_rng_counts", "_flags") 25 | 26 | def __init__( 27 | self, 28 | rng_counts: tp.Tuple[tp.Tuple[str, Counts], ...], 29 | flags: tp.Tuple[tp.Tuple[str, bool], ...], 30 | ): 31 | self._rng_counts = rng_counts 32 | self._flags = flags 33 | 34 | def merge(self, keys: tp.Mapping[str, KeyArray]) -> "Context": 35 | rngs = { 36 | name: RngStream(keys[name], count=0, count_path=count_path) 37 | for name, count_path in self._rng_counts 38 | } 39 | return Context(rngs=rngs, flags=dict(self._flags)) 40 | 41 | 42 | class PureContext(tp.Tuple[tp.Dict[str, KeyArray], ContextDef]): 43 | 44 | @classmethod 45 | def new(cls, keys: tp.Dict[str, KeyArray], contextdef: ContextDef): 46 | return cls((keys, contextdef)) 47 | 48 | @property 49 | def keys(self) -> tp.Dict[str, KeyArray]: 50 | return self[0] 51 | 52 | @property 53 | def contextdef(self) -> ContextDef: 54 | return self[1] 55 | 56 | def merge(self): 57 | return self.contextdef.merge(self.keys) 58 | 59 | 60 | def _pure_context_flatten(pure_context: PureContext): 61 | return tuple(pure_context), None 62 | 63 | 64 | def _pure_context_unflatten( 65 | aux_data: None, 66 | children: tp.Tuple[tp.Dict[str, KeyArray], ContextDef], 67 | ) -> PureContext: 68 | return PureContext(children) 69 | 70 | 71 | jtu.register_pytree_node(PureContext, _pure_context_flatten, _pure_context_unflatten) 72 | 73 | 74 | @dataclasses.dataclass 75 | class RngStream: 76 | key: KeyArray 77 | count: int = 0 78 | count_path: Counts = () 79 | 80 | 81 | class Context: 82 | __slots__ = ("_rngs", "_flags", "_trace_state") 83 | 84 | def __init__( 85 | self, 86 | rngs: tp.Mapping[str, RngStream], 87 | flags: tp.Mapping[str, bool], 88 | ): 89 | self._rngs = rngs 90 | self._flags = flags 91 | self._trace_state = tracers.TraceState() 92 | 93 | def has_rng(self, name: str) -> bool: 94 | return name in self._rngs 95 | 96 | def make_rng(self, name: str) -> KeyArray: 97 | if name not in self._rngs: 98 | raise ValueError(f"Unknown Rng Stream: {name}") 99 | elif not self._trace_state.is_valid(): 100 | raise errors.TraceContextError("Cannot use Context from a different trace level") 101 | 102 | stream = self._rngs[name] 103 | fold_data = _stable_hash(stream.count_path + (stream.count,)) 104 | stream.count += 1 105 | return jax.random.fold_in(stream.key, fold_data) 106 | 107 | def copy(self) -> "Context": 108 | return Context(rngs=self._rngs, flags=self._flags) 109 | 110 | def has_flag(self, name: str) -> bool: 111 | return name in self._flags 112 | 113 | def get_flag(self, name: str) -> tp.Optional[bool]: 114 | return self._flags.get(name, None) 115 | 116 | def partition(self) -> PureContext: 117 | if not self._trace_state.is_valid(): 118 | raise errors.TraceContextError("Cannot use Context from a different trace level") 119 | 120 | def fork(stream) -> "RngStream": 121 | count_path = stream.count_path + (stream.count,) 122 | stream.count += 1 123 
| return RngStream(stream.key, count_path=count_path) 124 | 125 | rngs = {name: fork(stream) for name, stream in self._rngs.items()} 126 | keys = {name: stream.key for name, stream in rngs.items()} 127 | rng_counts = tuple((name, stream.count_path) for name, stream in rngs.items()) 128 | return PureContext.new(keys, ContextDef(rng_counts, tuple(self._flags.items()))) 129 | 130 | 131 | def context( 132 | params: tp.Union[int, KeyArray, RngStream, None] = None, 133 | *, 134 | flags: tp.Optional[tp.Mapping[str, bool]] = None, 135 | **rngs: tp.Union[int, KeyArray, RngStream], 136 | ) -> Context: 137 | _flags = flags or {} 138 | 139 | if params is not None: 140 | rngs["params"] = params 141 | 142 | _rngs = { 143 | name: RngStream(jax.random.PRNGKey(value)) 144 | if isinstance(value, int) 145 | else RngStream(value) 146 | if isinstance(value, jax.Array) 147 | else value 148 | for name, value in rngs.items() 149 | } 150 | 151 | return Context(rngs=_rngs, flags=_flags) 152 | 153 | 154 | if tp.TYPE_CHECKING: 155 | ellipsis = builtins.ellipsis 156 | else: 157 | ellipsis = tp.Any 158 | 159 | RngPredicate = tp.Callable[[str], bool] 160 | RngFilterLiteral = tp.Union[str, RngPredicate, ellipsis, None] 161 | RngFilter = tp.Union[ 162 | RngFilterLiteral, tp.Sequence[RngFilterLiteral], tp.Mapping[RngFilterLiteral, bool] 163 | ] 164 | 165 | 166 | def to_rng_predicate(filter: RngFilter) -> RngPredicate: 167 | if filter is None: 168 | return lambda _: False 169 | elif filter is ...: 170 | return lambda _: True 171 | elif callable(filter): 172 | return filter 173 | elif isinstance(filter, str): 174 | return lambda name: name == filter 175 | elif isinstance(filter, tp.Mapping): 176 | predicates = tuple( 177 | to_rng_predicate(filter) for filter, include in filter.items() if include 178 | ) 179 | return lambda name: any(predicate(name) for predicate in predicates) 180 | elif isinstance(filter, tp.Sequence): 181 | predicates = tuple(map(to_rng_predicate, filter)) 182 | return lambda name: any(predicate(name) for predicate in predicates) 183 | else: 184 | raise TypeError(f"Invalid rng filter: {filter}") 185 | -------------------------------------------------------------------------------- /nnx/dataclasses.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as tp 3 | 4 | import typing_extensions as tpe 5 | 6 | from nnx import containers 7 | 8 | A = tp.TypeVar("A") 9 | 10 | 11 | def field( 12 | *, 13 | default: tp.Any = dataclasses.MISSING, 14 | default_factory: tp.Any = dataclasses.MISSING, 15 | init: bool = True, 16 | repr: bool = True, 17 | hash: tp.Optional[bool] = None, 18 | compare: bool = True, 19 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 20 | ): 21 | return dataclasses.field( # type: ignore 22 | default=default, 23 | default_factory=default_factory, 24 | init=init, 25 | repr=repr, 26 | hash=hash, 27 | compare=compare, 28 | metadata=metadata, 29 | ) 30 | 31 | 32 | def node_field( 33 | *, 34 | default: tp.Any = dataclasses.MISSING, 35 | default_factory: tp.Any = dataclasses.MISSING, 36 | init: bool = True, 37 | repr: bool = True, 38 | hash: tp.Optional[bool] = None, 39 | compare: bool = True, 40 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 41 | ): 42 | if metadata is None: 43 | metadata = {} 44 | else: 45 | metadata = dict(metadata) 46 | 47 | if "nnx_container_fn" in metadata: 48 | raise ValueError("'nnx_container_fn' found in metadata") 49 | 50 | metadata["nnx_container_fn"] = lambda value: 
containers.Node(value) 51 | 52 | return field( 53 | default=default, 54 | default_factory=default_factory, 55 | init=init, 56 | repr=repr, 57 | hash=hash, 58 | compare=compare, 59 | metadata=metadata, 60 | ) 61 | 62 | 63 | def static_field( 64 | *, 65 | default: tp.Any = dataclasses.MISSING, 66 | default_factory: tp.Any = dataclasses.MISSING, 67 | init: bool = True, 68 | repr: bool = True, 69 | hash: tp.Optional[bool] = None, 70 | compare: bool = True, 71 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 72 | ): 73 | if metadata is None: 74 | metadata = {} 75 | else: 76 | metadata = dict(metadata) 77 | 78 | if "nnx_container_fn" in metadata: 79 | raise ValueError("'nnx_container_fn' found in metadata") 80 | 81 | metadata["nnx_container_fn"] = lambda value: containers.Static(value) 82 | 83 | return field( 84 | default=default, 85 | default_factory=default_factory, 86 | init=init, 87 | repr=repr, 88 | hash=hash, 89 | compare=compare, 90 | metadata=metadata, 91 | ) 92 | 93 | 94 | def var_field( 95 | variable_type: tp.Type[containers.Variable[tp.Any]], 96 | *, 97 | default: tp.Any = dataclasses.MISSING, 98 | default_factory: tp.Any = dataclasses.MISSING, 99 | init: bool = True, 100 | repr: bool = True, 101 | hash: tp.Optional[bool] = None, 102 | compare: bool = True, 103 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 104 | sharding: tp.Optional[containers.Sharding] = None, 105 | ) -> tp.Any: 106 | if metadata is None: 107 | metadata = {} 108 | else: 109 | metadata = dict(metadata) 110 | 111 | if "nnx_container_fn" in metadata: 112 | raise ValueError("'nnx_container_fn' found in metadata") 113 | 114 | metadata["nnx_container_fn"] = lambda value: variable_type(value, sharding=sharding) 115 | 116 | return field( 117 | default=default, 118 | default_factory=default_factory, 119 | init=init, 120 | repr=repr, 121 | hash=hash, 122 | compare=compare, 123 | metadata=metadata, 124 | ) 125 | 126 | 127 | def param_field( 128 | default: tp.Any = dataclasses.MISSING, 129 | *, 130 | default_factory: tp.Any = dataclasses.MISSING, 131 | init: bool = True, 132 | repr: bool = True, 133 | hash: tp.Optional[bool] = None, 134 | compare: bool = True, 135 | metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, 136 | ) -> tp.Any: 137 | return var_field( 138 | containers.Param, 139 | default=default, 140 | default_factory=default_factory, 141 | init=init, 142 | repr=repr, 143 | hash=hash, 144 | compare=compare, 145 | metadata=metadata, 146 | ) 147 | 148 | 149 | @tp.overload 150 | def dataclass(cls: tp.Type[A]) -> tp.Type[A]: 151 | ... 152 | 153 | 154 | @tp.overload 155 | def dataclass( 156 | *, 157 | init: bool = True, 158 | repr: bool = True, 159 | eq: bool = True, 160 | order: bool = False, 161 | unsafe_hash: bool = False, 162 | frozen: bool = False, 163 | ) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: 164 | ... 
165 | 166 | 167 | @tpe.dataclass_transform( 168 | field_specifiers=(field, node_field, static_field, var_field, param_field) 169 | ) 170 | def dataclass( 171 | cls: tp.Optional[tp.Type[A]] = None, 172 | init: bool = True, 173 | repr: bool = True, 174 | eq: bool = True, 175 | order: bool = False, 176 | unsafe_hash: bool = False, 177 | frozen: bool = False, 178 | ) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: 179 | decorator = dataclasses.dataclass( 180 | init=init, 181 | repr=repr, 182 | eq=eq, 183 | order=order, 184 | unsafe_hash=unsafe_hash, 185 | frozen=frozen, 186 | ) 187 | 188 | if cls is None: 189 | return decorator 190 | 191 | return decorator(cls) 192 | -------------------------------------------------------------------------------- /nnx/errors.py: -------------------------------------------------------------------------------- 1 | class TraceContextError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /nnx/helpers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing as tp 3 | 4 | import jax.numpy as jnp 5 | import optax 6 | 7 | from nnx import pytreelib 8 | from nnx.contextlib import Context 9 | from nnx.module import ApplyCaller, Module, ModuleDef, Pure 10 | from nnx.state import State 11 | 12 | A = tp.TypeVar("A") 13 | M = tp.TypeVar("M", bound=Module) 14 | 15 | 16 | class Dict(Module, tp.Mapping[str, A]): 17 | 18 | @tp.overload 19 | def __init__(self, __iterable: tp.Iterable[tp.Tuple[str, A]]): 20 | ... 21 | 22 | @tp.overload 23 | def __init__(self, __mapping: tp.Optional[tp.Mapping[str, A]] = None, **kwargs: A): 24 | ... 25 | 26 | def __init__(self, *args, **kwargs): 27 | for name, value in dict(*args, **kwargs).items(): 28 | setattr(self, name, value) 29 | 30 | def __getitem__(self, key) -> A: 31 | return getattr(self, key) 32 | 33 | def __setitem__(self, key, value): 34 | setattr(self, key, value) 35 | 36 | def __getattr__(self, key) -> A: 37 | return super().__getattribute__(key) 38 | 39 | def __setattr__(self, key, value): 40 | super().__setattr__(key, value) 41 | 42 | def __iter__(self) -> tp.Iterator[str]: 43 | return iter(vars(self)) 44 | 45 | def __len__(self) -> int: 46 | return len(vars(self)) 47 | 48 | 49 | class Sequence(Module, tp.Generic[A]): 50 | 51 | def __init__(self, iterable: tp.Iterable[A]): 52 | i = 0 53 | for i, value in enumerate(iterable): 54 | setattr(self, str(i), value) 55 | self._length = i + 1 56 | 57 | def __getitem__(self, key: int) -> A: 58 | if key >= len(self): 59 | raise IndexError(f"index {key} out of range for {self}") 60 | return getattr(self, str(key)) 61 | 62 | def __iter__(self) -> tp.Iterator[A]: 63 | for i in range(len(self)): 64 | yield getattr(self, str(i)) 65 | 66 | def __len__(self) -> int: 67 | return self._length 68 | 69 | def __call__(self, *args, ctx: tp.Optional[Context] = None, **kwargs) -> tp.Any: 70 | output: tp.Any = None 71 | 72 | for i, f in enumerate(self): 73 | if not callable(f): 74 | raise TypeError(f"Sequence[{i}] is not callable: {f}") 75 | if i > 0: 76 | if isinstance(output, tp.Tuple): 77 | args = output 78 | kwargs = {} 79 | elif isinstance(output, tp.Dict): 80 | args = () 81 | kwargs = output 82 | else: 83 | args = (output,) 84 | kwargs = {} 85 | if ctx is not None and has_keyword_arg(f, "ctx"): 86 | kwargs["ctx"] = ctx 87 | 88 | output = f(*args, **kwargs) 89 | 90 | return output 91 | 92 | 93 | class ModuleDefApply(tp.Protocol, tp.Generic[M]): 94 | 95 | def 
__call__(self, state: State, *states: State) -> ApplyCaller["Pure[M]"]: 96 | ... 97 | 98 | 99 | class TrainState(pytreelib.Pytree, tp.Generic[M]): 100 | 101 | def __init__( 102 | self, 103 | moduledef: ModuleDef[M], 104 | *, 105 | params: State, 106 | tx: optax.GradientTransformation, 107 | step: int = 0, 108 | **kwargs, 109 | ): 110 | self.moduledef = moduledef 111 | self.params: State = pytreelib.TreeNode(params) 112 | self.tx = tx 113 | self.opt_state = pytreelib.TreeNode(tx.init(self.params)) 114 | self.step = pytreelib.TreeNode(jnp.asarray(step)) 115 | for name, value in kwargs.items(): 116 | setattr(self, name, value) 117 | 118 | if tp.TYPE_CHECKING: 119 | 120 | def __getattr__(self, key: str) -> tp.Any: 121 | ... 122 | 123 | def apply( 124 | self, state: tp.Union[State, str], *states: tp.Union[State, str] 125 | ) -> ApplyCaller[Pure[State, M]]: 126 | states = (state, *states) 127 | 128 | _states = ( 129 | getattr(self, state) if isinstance(state, str) else state for state in states 130 | ) 131 | 132 | return self.moduledef.apply(*_states) 133 | 134 | def apply_gradients(self, grads: State, **kwargs) -> "TrainState[M]": 135 | updates, opt_state = self.tx.update(grads, self.opt_state, self.params) 136 | params = optax.apply_updates(self.params, updates) # type: ignore 137 | step = self.step + 1 138 | return self.replace( 139 | params=params, 140 | opt_state=opt_state, 141 | step=step, 142 | **kwargs, 143 | ) 144 | 145 | 146 | def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool: 147 | """Return True if func has keyword-only arguments with the given name.""" 148 | return any( 149 | param.name == name 150 | and param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD) 151 | for param in inspect.signature(func).parameters.values() 152 | ) 153 | -------------------------------------------------------------------------------- /nnx/ids.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """UUIDs for Flax internals.""" 16 | 17 | import threading 18 | 19 | 20 | class UUIDManager: 21 | """Globally unique counter-based id manager. 22 | 23 | We need globally unique key ids for Module and Variable object instances 24 | to preserve and recreate sharing-by-reference relationship when lifting 25 | transforms and adopting outside Modules. 26 | - Use of id() is unacceptable because these identifiers are literally 27 | pointers which can be recycled, so we rely on a globally unique counter id 28 | instead. 29 | - We need to handle copy/deepcopy uniqueness via a wrapped type. 
30 | """ 31 | 32 | def __init__(self): 33 | self._lock = threading.Lock() 34 | self._id = 0 35 | 36 | def __call__(self): 37 | with self._lock: 38 | self._id += 1 39 | return UUID(self._id) 40 | 41 | 42 | uuid = UUIDManager() 43 | 44 | 45 | class UUID: 46 | """Hashable wrapper for ids that handles uniqueness of copies.""" 47 | 48 | def __init__(self, rawid): 49 | self.id = rawid 50 | 51 | def __eq__(self, other): 52 | return isinstance(other, UUID) and other.id == self.id 53 | 54 | def __hash__(self): 55 | return hash(self.id) 56 | 57 | def __repr__(self): 58 | return f"UUID({self.id})" 59 | 60 | def __deepcopy__(self, memo): 61 | del memo 62 | return uuid() 63 | 64 | def __copy__(self): 65 | return uuid() 66 | -------------------------------------------------------------------------------- /nnx/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/3e4c750791ea39d26f5667ed26066fa6c13ff46b/nnx/nn/__init__.py -------------------------------------------------------------------------------- /nnx/nn/activations.py: -------------------------------------------------------------------------------- 1 | from jax.nn import ( 2 | celu, 3 | elu, 4 | gelu, 5 | glu, 6 | hard_sigmoid, 7 | hard_silu, 8 | hard_swish, 9 | hard_tanh, 10 | leaky_relu, 11 | log_sigmoid, 12 | log_softmax, 13 | logsumexp, 14 | normalize, 15 | one_hot, 16 | relu, 17 | relu6, 18 | selu, 19 | sigmoid, 20 | silu, 21 | soft_sign, 22 | softmax, 23 | softplus, 24 | standardize, 25 | swish, 26 | ) 27 | from jax.numpy import tanh 28 | 29 | __all__ = [ 30 | "celu", 31 | "elu", 32 | "gelu", 33 | "glu", 34 | "hard_sigmoid", 35 | "hard_silu", 36 | "hard_swish", 37 | "hard_tanh", 38 | "leaky_relu", 39 | "log_sigmoid", 40 | "log_softmax", 41 | "logsumexp", 42 | "normalize", 43 | "one_hot", 44 | "relu", 45 | "relu6", 46 | "selu", 47 | "sigmoid", 48 | "silu", 49 | "soft_sign", 50 | "softmax", 51 | "softplus", 52 | "standardize", 53 | "swish", 54 | "tanh", 55 | ] 56 | -------------------------------------------------------------------------------- /nnx/nn/dtypes.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | 6 | Dtype = Any 7 | Array = Any 8 | 9 | 10 | def canonicalize_dtype( 11 | *args, dtype: Optional[Dtype] = None, inexact: bool = True 12 | ) -> Dtype: 13 | """Canonicalize an optional dtype to the definitive dtype. 14 | 15 | If the ``dtype`` is None this function will infer the dtype from the input 16 | arguments using ``jnp.result_type``. If it is not None it will be returned 17 | unmodified, or an exception is raised if the dtype is invalid. 18 | 19 | 20 | Args: 21 | *args: JAX array compatible values. None values 22 | are ignored. 23 | dtype: Optional dtype override. If specified the arguments are cast to 24 | the specified dtype instead and dtype inference is disabled. 25 | inexact: When True, the output dtype must be a subdtype 26 | of `jnp.inexact`. Inexact dtypes are real or complex floating points. This 27 | is useful when you want to apply operations that don't work directly on 28 | integers like taking a mean for example. 29 | Returns: 30 | The dtype that *args should be cast to.
31 | """ 32 | if dtype is None: 33 | args_filtered = [jnp.asarray(x) for x in args if x is not None] 34 | dtype = jnp.result_type(*args_filtered) 35 | if inexact and not jnp.issubdtype(dtype, jnp.inexact): 36 | dtype = jnp.promote_types(jnp.float32, dtype) 37 | if inexact and not jnp.issubdtype(dtype, jnp.inexact): 38 | raise ValueError(f"Dtype must be inexact: {dtype}") 39 | return dtype 40 | 41 | 42 | def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]: 43 | """Promotes input arguments to a specified or inferred dtype. 44 | 45 | All args are cast to the same dtype. See ``canonicalize_dtype`` for how 46 | this dtype is determined. 47 | 48 | The behavior of promote_dtype is mostly a convenience wrapper around 49 | ``jax.numpy.promote_types``. The differences are that it automatically casts 50 | all input to the inferred dtypes, allows inference to be overridden by a 51 | forced dtype, and has an optional check to guarantee the resulting dtype is 52 | inexact. 53 | 54 | Args: 55 | *args: JAX array compatible values. None values 56 | are returned as is. 57 | dtype: Optional dtype override. If specified the arguments are cast to 58 | the specified dtype instead and dtype inference is disabled. 59 | inexact: When True, the output dtype must be a subdtype 60 | of `jnp.inexact`. Inexact dtypes are real or complex floating points. This 61 | is useful when you want to apply operations that don't work directly on 62 | integers like taking a mean for example. 63 | Returns: 64 | The arguments cast to arrays of the same dtype. 65 | """ 66 | dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) 67 | return [jnp.asarray(x, dtype) if x is not None else None for x in args] 68 | -------------------------------------------------------------------------------- /nnx/nn/initializers.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.nn.initializers import constant as constant 6 | from jax.nn.initializers import delta_orthogonal as delta_orthogonal 7 | from jax.nn.initializers import glorot_normal as glorot_normal 8 | from jax.nn.initializers import glorot_uniform as glorot_uniform 9 | from jax.nn.initializers import he_normal as he_normal 10 | from jax.nn.initializers import he_uniform as he_uniform 11 | from jax.nn.initializers import kaiming_normal as kaiming_normal 12 | from jax.nn.initializers import kaiming_uniform as kaiming_uniform 13 | from jax.nn.initializers import lecun_normal as lecun_normal 14 | from jax.nn.initializers import lecun_uniform as lecun_uniform 15 | from jax.nn.initializers import normal as normal 16 | from jax.nn.initializers import orthogonal as orthogonal 17 | from jax.nn.initializers import uniform as uniform 18 | from jax.nn.initializers import variance_scaling as variance_scaling 19 | from jax.nn.initializers import xavier_normal as xavier_normal 20 | from jax.nn.initializers import xavier_uniform as xavier_uniform 21 | 22 | Shape = tp.Sequence[int] 23 | DTypeLikeInexact = tp.Any 24 | KeyArray = jax.random.KeyArray 25 | Array = jax.Array 26 | 27 | 28 | class Initializer(tp.Protocol): 29 | 30 | @staticmethod 31 | def __call__( 32 | key: KeyArray, shape: Shape, dtype: DTypeLikeInexact = jnp.float_ 33 | ) -> Array: 34 | ... 35 | 36 | 37 | def zeros() -> Initializer: 38 | """Builds an initializer that returns a constant array full of zeros.
39 | 40 | >>> import jax, jax.numpy as jnp 41 | >>> from flax.linen.initializers import zeros_init 42 | >>> zeros_initializer = zeros_init() 43 | >>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) 44 | Array([[0., 0., 0.], 45 | [0., 0., 0.]], dtype=float32) 46 | """ 47 | return jax.nn.initializers.zeros 48 | 49 | 50 | def ones() -> Initializer: 51 | """Builds an initializer that returns a constant array full of ones. 52 | 53 | >>> import jax, jax.numpy as jnp 54 | >>> from flax.linen.initializers import ones_init 55 | >>> ones_initializer = ones_init() 56 | >>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32) 57 | Array([[1., 1.], 58 | [1., 1.], 59 | [1., 1.]], dtype=float32) 60 | """ 61 | return jax.nn.initializers.ones 62 | -------------------------------------------------------------------------------- /nnx/nn/stochastic.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional, Sequence 3 | 4 | import jax.numpy as jnp 5 | from jax import lax, random 6 | 7 | from nnx import contextlib 8 | from nnx.module import Module, first_from 9 | 10 | 11 | @dataclasses.dataclass 12 | class Dropout(Module): 13 | """Create a dropout layer. 14 | 15 | Attributes: 16 | rate: the dropout probability. (_not_ the keep rate!) 17 | broadcast_dims: dimensions that will share the same dropout mask 18 | deterministic: if false the inputs are scaled by `1 / (1 - rate)` and 19 | masked, whereas if true, no mask is applied and the inputs are returned 20 | as is. 21 | rng_collection: the rng collection name to use when requesting an rng key. 22 | """ 23 | 24 | rate: float 25 | broadcast_dims: Sequence[int] = () 26 | deterministic: Optional[bool] = None 27 | rng_collection: str = "dropout" 28 | 29 | def __call__( 30 | self, 31 | inputs, 32 | *, 33 | deterministic: Optional[bool] = None, 34 | ctx: Optional[contextlib.Context] = None, 35 | ): 36 | """Applies a random dropout mask to the input. 37 | 38 | Args: 39 | inputs: the inputs that should be randomly masked. 40 | deterministic: if false the inputs are scaled by `1 / (1 - rate)` and 41 | masked, whereas if true, no mask is applied and the inputs are returned 42 | as is. 43 | 44 | Returns: 45 | The masked inputs reweighted to preserve mean. 46 | """ 47 | deterministic = first_from( 48 | deterministic, 49 | self.deterministic, 50 | ctx and ctx.get_flag("deterministic"), 51 | ) 52 | 53 | if (self.rate == 0.0) or deterministic: 54 | return inputs 55 | 56 | # Prevent gradient NaNs in 1.0 edge-case. 57 | if self.rate == 1.0: 58 | return jnp.zeros_like(inputs) 59 | 60 | if ctx is None: 61 | raise ValueError( 62 | "Dropout needs to generate a random mask but no 'ctx' were provided." 63 | ) 64 | 65 | keep_prob = 1.0 - self.rate 66 | rng = ctx.make_rng(self.rng_collection) 67 | broadcast_shape = list(inputs.shape) 68 | for dim in self.broadcast_dims: 69 | broadcast_shape[dim] = 1 70 | mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) 71 | mask = jnp.broadcast_to(mask, inputs.shape) 72 | return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) 73 | -------------------------------------------------------------------------------- /nnx/nodes.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import numpy as np 5 | 6 | node_types: tp.Tuple[type, ...] 
= () 7 | 8 | 9 | def register_node_type(node_type: type) -> None: 10 | global node_types 11 | node_types += (node_type,) 12 | 13 | 14 | def is_node(obj: object) -> bool: 15 | return isinstance(obj, node_types) 16 | 17 | 18 | # register nodes 19 | register_node_type(jax.Array) 20 | register_node_type(np.ndarray) 21 | -------------------------------------------------------------------------------- /nnx/partitioning.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import dataclasses 3 | import typing as tp 4 | 5 | import jax 6 | import numpy as np 7 | 8 | import nnx 9 | 10 | if tp.TYPE_CHECKING: 11 | ellipsis = builtins.ellipsis 12 | else: 13 | ellipsis = tp.Any 14 | 15 | Path = str 16 | Predicate = tp.Callable[[Path, tp.Any], bool] 17 | FilterLiteral = tp.Union[type, Predicate, ellipsis, None] 18 | Filter = tp.Union[FilterLiteral, tp.Tuple[FilterLiteral, ...]] 19 | 20 | 21 | def to_predicate(filter: Filter) -> Predicate: 22 | if isinstance(filter, str): 23 | raise TypeError(f"Invalid filter of type '{type(filter).__name__}'") 24 | elif isinstance(filter, type): 25 | return OfType(filter) 26 | elif filter is Ellipsis: 27 | return Everything() 28 | elif filter is None: 29 | return Nothing() 30 | elif callable(filter): 31 | return filter 32 | elif isinstance(filter, tp.Tuple): 33 | return Any(*filter) 34 | else: 35 | raise TypeError(f"Invalid collection filter: {filter:!r}. ") 36 | 37 | 38 | @dataclasses.dataclass 39 | class OfType: 40 | type: type 41 | 42 | def __call__(self, path: Path, x: tp.Any): 43 | return isinstance(x, self.type) 44 | 45 | 46 | class Any: 47 | 48 | def __init__(self, *filters: Filter): 49 | self.predicates = tuple( 50 | to_predicate(collection_filter) for collection_filter in filters 51 | ) 52 | 53 | def __call__(self, path: Path, x: tp.Any): 54 | return any(predicate(path, x) for predicate in self.predicates) 55 | 56 | 57 | class All: 58 | 59 | def __init__(self, *filters: Filter): 60 | self.predicates = tuple( 61 | to_predicate(collection_filter) for collection_filter in filters 62 | ) 63 | 64 | def __call__(self, path: Path, x: tp.Any): 65 | return all(predicate(path, x) for predicate in self.predicates) 66 | 67 | 68 | class Not: 69 | 70 | def __init__(self, collection_filter: Filter): 71 | self.predicate = to_predicate(collection_filter) 72 | 73 | def __call__(self, path: Path, x: tp.Any): 74 | return not self.predicate(path, x) 75 | 76 | 77 | class Everything: 78 | 79 | def __call__(self, path: Path, x: tp.Any): 80 | return True 81 | 82 | 83 | class Nothing: 84 | 85 | def __call__(self, path: Path, x: tp.Any): 86 | return False 87 | 88 | 89 | buffers = (jax.Array, np.ndarray) 90 | -------------------------------------------------------------------------------- /nnx/pytreelib.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import dataclasses 3 | import importlib.util 4 | import inspect 5 | import typing as tp 6 | from abc import ABCMeta 7 | from copy import copy 8 | from functools import partial 9 | from types import MappingProxyType 10 | 11 | import jax 12 | 13 | from nnx import containers, nodes, reprlib 14 | 15 | A = tp.TypeVar("A") 16 | P = tp.TypeVar("P", bound="Pytree") 17 | 18 | 19 | class TreeNode(containers.NodeBase[A]): 20 | pass 21 | 22 | 23 | @contextlib.contextmanager 24 | def _mutable(obj: P) -> tp.Iterator[None]: 25 | vars(obj)["_pytree__is_mutable"] = True 26 | try: 27 | yield 28 | finally: 29 | del 
vars(obj)["_pytree__is_mutable"] 30 | 31 | 32 | @contextlib.contextmanager 33 | def _initializing(obj: P) -> tp.Iterator[None]: 34 | vars(obj)["_pytree__initializing"] = True 35 | try: 36 | yield 37 | finally: 38 | del vars(obj)["_pytree__initializing"] 39 | 40 | 41 | class PytreeMeta(ABCMeta): 42 | if not tp.TYPE_CHECKING: 43 | 44 | def __call__(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: 45 | return cls.call(*args, **kwargs) 46 | 47 | def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: 48 | obj: P = cls.__new__(cls, *args, **kwargs) 49 | vars(obj)["_pytree__sorted_fields"] = ["_pytree__sorted_fields"] 50 | 51 | with _mutable(obj), _initializing(obj): 52 | obj.__init__(*args, **kwargs) 53 | 54 | if dataclasses.is_dataclass(obj): 55 | assert isinstance(obj, Pytree) 56 | for field in dataclasses.fields(obj): 57 | if "nnx_container_fn" not in field.metadata: 58 | continue 59 | 60 | container_fn = field.metadata["nnx_container_fn"] 61 | value = vars(obj)[field.name] 62 | value = container_fn(value) 63 | vars(obj)[field.name] = value 64 | 65 | vars(obj)["_pytree__sorted_fields"] = sorted(vars(obj)) 66 | 67 | return obj 68 | 69 | 70 | class Pytree(reprlib.Representable, metaclass=PytreeMeta): 71 | _pytree__is_mutable: bool 72 | _pytree__class_is_mutable: bool 73 | _pytree__sorted_fields: tp.Tuple[str, ...] 74 | 75 | if not tp.TYPE_CHECKING: 76 | 77 | def __getattribute__(self, name: str) -> tp.Any: 78 | value = object.__getattribute__(self, name) 79 | if isinstance(value, containers.Container): 80 | return value.value 81 | return value 82 | 83 | def __setattr__(self, name: str, value: tp.Any) -> None: 84 | self._setattr(name, value) 85 | 86 | def _setattr(self: P, name: str, value: tp.Any): 87 | vars_dict = vars(self) 88 | if "_pytree__initializing" in vars_dict: 89 | pass 90 | elif name not in vars_dict: 91 | raise AttributeError(r"Cannot add new fields to an initialized Pytree") 92 | elif "_pytree__is_mutable" not in vars_dict and not self._pytree__class_is_mutable: 93 | raise AttributeError(f"{type(self)} is immutable, trying to update field {name}") 94 | 95 | if name in vars_dict and isinstance(vars_dict[name], containers.Container): 96 | vars_dict[name] = vars_dict[name].replace(value=value) 97 | else: 98 | if isinstance(value, containers.Container): 99 | value = value.copy() 100 | vars_dict[name] = value 101 | 102 | def __init_subclass__(cls, mutable: bool = False): 103 | super().__init_subclass__() 104 | # init class variables 105 | cls._pytree__is_mutable = False 106 | cls._pytree__class_is_mutable = mutable 107 | 108 | # TODO: clean up this in the future once minimal supported version is 0.4.7 109 | if hasattr(jax.tree_util, "register_pytree_with_keys"): 110 | if ( 111 | "flatten_func" 112 | in inspect.signature(jax.tree_util.register_pytree_with_keys).parameters 113 | ): 114 | jax.tree_util.register_pytree_with_keys( 115 | cls, 116 | partial( 117 | cls._pytree__flatten, 118 | with_key_paths=True, 119 | ), 120 | cls._pytree__unflatten, 121 | flatten_func=partial( 122 | cls._pytree__flatten, 123 | with_key_paths=False, 124 | ), 125 | ) 126 | else: 127 | jax.tree_util.register_pytree_with_keys( 128 | cls, 129 | partial( 130 | cls._pytree__flatten, 131 | with_key_paths=True, 132 | ), 133 | cls._pytree__unflatten, 134 | ) 135 | else: 136 | jax.tree_util.register_pytree_node( 137 | cls, 138 | partial( 139 | cls._pytree__flatten, 140 | with_key_paths=False, 141 | ), 142 | cls._pytree__unflatten, 143 | ) 144 | 145 | # flax serialization support 146 | if 
importlib.util.find_spec("flax") is not None: 147 | from flax import serialization 148 | 149 | serialization.register_serialization_state( 150 | cls, cls._to_flax_state_dict, cls._from_flax_state_dict 151 | ) 152 | 153 | @classmethod 154 | def _pytree__flatten( 155 | cls, 156 | pytree: "Pytree", 157 | *, 158 | with_key_paths: bool, 159 | ): 160 | all_vars = vars(pytree) 161 | static = {} 162 | node_values = [] 163 | node_names = [] 164 | 165 | for field in pytree._pytree__sorted_fields: 166 | value = all_vars[field] 167 | 168 | if nodes.is_node(value): 169 | node_names.append(field) 170 | if with_key_paths: 171 | node_values.append((jax.tree_util.GetAttrKey(field), value)) 172 | else: 173 | node_values.append(value) 174 | else: 175 | static[field] = value 176 | 177 | return node_values, (tuple(node_names), MappingProxyType(static)) 178 | 179 | @classmethod 180 | def _pytree__unflatten( 181 | cls: tp.Type[P], 182 | metadata: tp.Tuple[tp.Tuple[str, ...], tp.Mapping[str, tp.Any]], 183 | node_values: tp.Tuple[tp.Any, ...], 184 | ) -> P: 185 | node_names, static_fields = metadata 186 | pytree = object.__new__(cls) 187 | pytree.__dict__.update(zip(node_names, node_values)) 188 | pytree.__dict__.update(static_fields) 189 | return pytree 190 | 191 | @classmethod 192 | def _to_flax_state_dict(cls, pytree: "Pytree") -> tp.Dict[str, tp.Any]: 193 | from flax import serialization 194 | 195 | state_dict = { 196 | name: serialization.to_state_dict(getattr(pytree, name)) 197 | for name, value in vars(pytree).items() 198 | if nodes.is_node(value) 199 | } 200 | return state_dict 201 | 202 | @classmethod 203 | def _from_flax_state_dict( 204 | cls, 205 | pytree: P, 206 | state: tp.Dict[str, tp.Any], 207 | ) -> P: 208 | """Restore the state of a data class.""" 209 | from flax import serialization 210 | 211 | state = state.copy() # copy the state so we can pop the restored fields. 212 | updates = {} 213 | for name, value in vars(pytree).items(): 214 | if not nodes.is_node(value): 215 | continue 216 | if name not in state: 217 | raise ValueError( 218 | f"Missing field {name} in state dict while restoring" 219 | f" an instance of {type(pytree).__name__}," 220 | f" at path {serialization.current_path()}" 221 | ) 222 | value_state = state.pop(name) 223 | updates[name] = serialization.from_state_dict(value, value_state, name=name) 224 | if state: 225 | names = ",".join(state.keys()) 226 | raise ValueError( 227 | f'Unknown field(s) "{names}" in state dict while' 228 | f" restoring an instance of {type(pytree).__name__}" 229 | f" at path {serialization.current_path()}" 230 | ) 231 | return pytree.replace(**updates) 232 | 233 | def replace(self: P, **kwargs: tp.Any) -> P: 234 | """ 235 | Replace the values of the fields of the object with the values of the 236 | keyword arguments. If the object is a dataclass, `dataclasses.replace` 237 | will be used. Otherwise, a new object will be created with the same 238 | type as the original object. 
239 | """ 240 | if dataclasses.is_dataclass(self): 241 | return dataclasses.replace(self, **kwargs) 242 | 243 | unknown_keys = set(kwargs) - set(vars(self)) 244 | if unknown_keys and not self._pytree__class_is_mutable: 245 | raise ValueError( 246 | f"Trying to replace unknown fields {unknown_keys} " 247 | f"for '{type(self).__name__}'" 248 | ) 249 | 250 | pytree = copy(self) 251 | with _mutable(pytree): 252 | for key, value in kwargs.items(): 253 | setattr(pytree, key, value) 254 | 255 | return pytree 256 | 257 | def __nnx_repr__(self): 258 | yield reprlib.Object(type(self)) 259 | for name, value in vars(self).items(): 260 | yield reprlib.Attr(name, repr(value)) 261 | 262 | 263 | # register node types 264 | nodes.register_node_type(Pytree) 265 | nodes.register_node_type(TreeNode) 266 | -------------------------------------------------------------------------------- /nnx/reprlib.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import dataclasses 3 | import threading 4 | import typing as tp 5 | from abc import abstractmethod 6 | 7 | 8 | @dataclasses.dataclass 9 | class ReprContext(threading.local): 10 | indent_stack: tp.List[str] = dataclasses.field(default_factory=lambda: [""]) 11 | 12 | 13 | REPR_CONTEXT = ReprContext() 14 | 15 | 16 | @dataclasses.dataclass 17 | class Object: 18 | type: tp.Union[str, type] 19 | start: str = "(" 20 | end: str = ")" 21 | value_sep: str = "=" 22 | elem_indent: str = " " 23 | empty_repr: str = "" 24 | 25 | 26 | @dataclasses.dataclass 27 | class Attr: 28 | key: str 29 | value: tp.Union[str, tp.Any] 30 | start: str = "" 31 | end: str = "" 32 | 33 | 34 | class Representable: 35 | __slots__ = () 36 | 37 | @abstractmethod 38 | def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: 39 | raise NotImplementedError 40 | 41 | def __repr__(self) -> str: 42 | return get_repr(self) 43 | 44 | 45 | @contextlib.contextmanager 46 | def add_indent(indent: str) -> tp.Iterator[None]: 47 | REPR_CONTEXT.indent_stack.append(REPR_CONTEXT.indent_stack[-1] + indent) 48 | 49 | try: 50 | yield 51 | finally: 52 | REPR_CONTEXT.indent_stack.pop() 53 | 54 | 55 | def get_indent() -> str: 56 | return REPR_CONTEXT.indent_stack[-1] 57 | 58 | 59 | def get_repr(obj: Representable) -> str: 60 | if not isinstance(obj, Representable): 61 | raise TypeError(f"Object {obj!r} is not representable") 62 | 63 | iterator = obj.__nnx_repr__() 64 | config = next(iterator) 65 | if not isinstance(config, Object): 66 | raise TypeError(f"First item must be Config, got {type(config).__name__}") 67 | 68 | def _repr_elem(elem: tp.Any) -> str: 69 | if not isinstance(elem, Attr): 70 | raise TypeError(f"Item must be Elem, got {type(elem).__name__}") 71 | 72 | value = elem.value if isinstance(elem.value, str) else repr(elem.value) 73 | 74 | if "\n" in value and not isinstance(elem.value, Representable): 75 | value = value.replace("\n", "\n" + get_indent()) 76 | 77 | return f"{get_indent()}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}" 78 | 79 | with add_indent(config.elem_indent): 80 | elems = list(map(_repr_elem, iterator)) 81 | elems = ",\n".join(elems) 82 | 83 | if elems: 84 | elems = "\n" + elems + "\n" + get_indent() 85 | else: 86 | elems = config.empty_repr 87 | 88 | type_repr = config.type if isinstance(config.type, str) else config.type.__name__ 89 | 90 | return f"{type_repr}{config.start}{elems}{config.end}" 91 | -------------------------------------------------------------------------------- /nnx/state.py: 
-------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.tree_util as jtu 5 | 6 | from nnx import nodes, partitioning, reprlib 7 | from nnx.containers import Node 8 | 9 | A = tp.TypeVar("A") 10 | 11 | Leaf = tp.Any 12 | Path = str 13 | StateDict = tp.Dict[Path, tp.Any] 14 | StateMapping = tp.Mapping[Path, tp.Any] 15 | 16 | 17 | class State(tp.Mapping[Path, Leaf], reprlib.Representable): 18 | __slots__ = ("_mapping",) 19 | 20 | def __init__( 21 | self, 22 | __input: tp.Union[ 23 | tp.Mapping[Path, Leaf], 24 | tp.Iterator[tp.Tuple[Path, Leaf]], 25 | ], 26 | /, 27 | ): 28 | if isinstance(__input, tp.Mapping): 29 | self._mapping = dict(sorted(__input.items(), key=lambda x: x[0])) 30 | else: 31 | self._mapping = dict(sorted(__input, key=lambda x: x[0])) 32 | 33 | def __getitem__(self, __key: Path) -> Leaf: 34 | return self._mapping[__key] 35 | 36 | def __iter__(self) -> tp.Iterator[Path]: 37 | return iter(self._mapping) 38 | 39 | def __len__(self) -> int: 40 | return len(self._mapping) 41 | 42 | def __nnx_repr__(self): 43 | yield reprlib.Object(type(self), value_sep=": ", start="({", end="})") 44 | 45 | for k, v in self._mapping.items(): 46 | yield reprlib.Attr(repr(k), v) 47 | 48 | @tp.overload 49 | def partition(self, first: partitioning.Filter, /) -> "State": 50 | ... 51 | 52 | @tp.overload 53 | def partition( 54 | self, 55 | first: partitioning.Filter, 56 | second: partitioning.Filter, 57 | /, 58 | *filters: partitioning.Filter, 59 | ) -> tp.Tuple["State", ...]: 60 | ... 61 | 62 | def partition( 63 | self, first: partitioning.Filter, /, *filters: partitioning.Filter 64 | ) -> tp.Union["State", tp.Tuple["State", ...]]: 65 | filters = (first, *filters) 66 | *states, rest = _split_state(self, *filters) 67 | 68 | if rest: 69 | raise ValueError( 70 | f"Non-exhaustive filters, got a non-empty remainder: " 71 | f"{list(rest.keys())}.\nUse `...` to match all remaining elements." 72 | ) 73 | 74 | if len(states) == 1: 75 | states = State(states[0]) 76 | else: 77 | states = tuple(State(state) for state in states) 78 | return states 79 | 80 | @tp.overload 81 | def filter( 82 | self, 83 | first: partitioning.Filter, 84 | /, 85 | ) -> "State": 86 | ... 87 | 88 | @tp.overload 89 | def filter( 90 | self, 91 | first: partitioning.Filter, 92 | second: partitioning.Filter, 93 | /, 94 | *filters: partitioning.Filter, 95 | ) -> tp.Tuple["State", ...]: 96 | ... 
97 | 98 | def filter( 99 | self, 100 | first: partitioning.Filter, 101 | /, 102 | *filters: partitioning.Filter, 103 | ) -> tp.Union["State", tp.Tuple["State", ...]]: 104 | *states, _rest = _split_state(self, first, *filters) 105 | 106 | assert len(states) == len(filters) + 1 107 | 108 | if len(states) == 1: 109 | states = State(states[0]) 110 | else: 111 | states = tuple(State(state) for state in states) 112 | 113 | return states 114 | 115 | @staticmethod 116 | def merge(state: "State", /, *states: "State") -> "State": 117 | states = (state, *states) 118 | 119 | if len(states) == 1: 120 | return states[0] 121 | 122 | new_state: StateDict = {} 123 | 124 | for state in states: 125 | new_state.update(state) 126 | 127 | return State(new_state) 128 | 129 | def __or__(self, other: "State") -> "State": 130 | if not other: 131 | return self 132 | return State.merge(self, other) 133 | 134 | def __sub__(self, other: "State") -> "State": 135 | if not other: 136 | return self 137 | 138 | # create new State via __new__ to avoid __init__ sorting 139 | _mapping = {k: v for k, v in self.items() if k not in other} 140 | state = object.__new__(State) 141 | state._mapping = _mapping 142 | return state 143 | 144 | 145 | def _state_flatten_with_keys( 146 | x: State, 147 | ): 148 | children = tuple((jtu.DictKey(key), value) for key, value in x.items()) 149 | return children, tuple(x.keys()) 150 | 151 | 152 | def _state_unflatten( 153 | keys: tp.Tuple[Path, ...], 154 | leaves: tp.Tuple[Leaf, ...], 155 | ): 156 | state = object.__new__(State) 157 | state._mapping = dict(zip(keys, leaves)) 158 | return state 159 | 160 | 161 | jax.tree_util.register_pytree_with_keys( 162 | State, _state_flatten_with_keys, _state_unflatten 163 | ) 164 | 165 | 166 | def _split_state( 167 | state: StateMapping, 168 | *filters: partitioning.Filter, 169 | ) -> tp.Tuple[StateDict, ...]: 170 | for i, filter_ in enumerate(filters): 171 | if filter_ is ... and i != len(filters) - 1: 172 | raise ValueError( 173 | f"Ellipsis `...` can only be used as the last filter, " 174 | f"got it at index {i}." 175 | ) 176 | predicates = tuple(map(partitioning.to_predicate, filters)) 177 | 178 | # we have n + 1 states, where n is the number of predicates 179 | # the last state is for values that don't match any predicate 180 | states: tp.Tuple[StateDict, ...] 
= tuple({} for _ in range(len(predicates) + 1)) 181 | 182 | for path, value in state.items(): 183 | for i, predicate in enumerate(predicates): 184 | if predicate(path, value): 185 | states[i][path] = value 186 | break 187 | else: 188 | # if we didn't break, set leaf to last state 189 | states[-1][path] = value 190 | 191 | return states 192 | 193 | 194 | # register nodes 195 | nodes.register_node_type(State) 196 | -------------------------------------------------------------------------------- /nnx/tracers.py: -------------------------------------------------------------------------------- 1 | # Taken from flax/core/tracer.py 🏴‍☠️ 2 | 3 | import contextlib 4 | import dataclasses 5 | import threading 6 | import typing as tp 7 | 8 | import jax 9 | import jax.core 10 | from jax.core import MainTrace 11 | 12 | from nnx import reprlib 13 | 14 | 15 | @tp.runtime_checkable 16 | class Tracer(tp.Protocol): 17 | _trace: jax.core.Trace 18 | 19 | 20 | def get_top_trace(pytree: tp.Union[tp.Any, Tracer]) -> MainTrace: 21 | """Returns the main top trace of a sequence of tracers.""" 22 | if isinstance(pytree, Tracer): 23 | return pytree._trace.main 24 | 25 | return jax.core.find_top_trace(jax.tree_util.tree_leaves(pytree)).main 26 | 27 | 28 | def current_jax_trace() -> MainTrace: 29 | """Returns the innermost Jax tracer.""" 30 | return get_top_trace(()) 31 | 32 | 33 | def get_all_traces(pytree: tp.Union[tp.Any, Tracer]) -> tp.Set[MainTrace]: 34 | """Returns True if all tracers have the same main trace.""" 35 | if isinstance(pytree, Tracer): 36 | return {pytree._trace.main} 37 | else: 38 | return { 39 | trace._trace.main 40 | for trace in jax.tree_util.tree_leaves(pytree) 41 | if isinstance(trace, Tracer) 42 | } 43 | 44 | 45 | def trace_level(main): 46 | """Returns the level of the trace of -infinity if it is None.""" 47 | if main: 48 | return main.level 49 | return float("-inf") 50 | 51 | 52 | @dataclasses.dataclass 53 | class TraceContext(threading.local): 54 | nnx_trace_stack: tp.List[MainTrace] = dataclasses.field( 55 | default_factory=lambda: [current_jax_trace()] 56 | ) 57 | 58 | 59 | TRACE_CONTEXT = TraceContext() 60 | 61 | 62 | @contextlib.contextmanager 63 | def nnx_trace(trace: MainTrace): 64 | TRACE_CONTEXT.nnx_trace_stack.append(trace) 65 | try: 66 | yield 67 | finally: 68 | TRACE_CONTEXT.nnx_trace_stack.pop() 69 | 70 | 71 | def current_nnx_trace() -> MainTrace: 72 | return TRACE_CONTEXT.nnx_trace_stack[-1] 73 | 74 | 75 | class TraceState(reprlib.Representable): 76 | __slots__ = ["_jax_trace", "_nnx_trace"] 77 | 78 | def __init__(self): 79 | self._jax_trace = current_jax_trace() 80 | self._nnx_trace = current_nnx_trace() 81 | 82 | @property 83 | def jax_trace(self): 84 | return self._jax_trace 85 | 86 | @property 87 | def nnx_trace(self): 88 | return self._nnx_trace 89 | 90 | def is_valid(self) -> bool: 91 | return ( 92 | self._jax_trace is current_jax_trace() 93 | and self._nnx_trace is current_nnx_trace() 94 | ) 95 | 96 | def __nnx_repr__(self): 97 | yield reprlib.Object(f"{type(self).__name__}") 98 | yield reprlib.Attr("jax_trace", self._jax_trace) 99 | yield reprlib.Attr("nnx_trace", self._nnx_trace) 100 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "nnx" 3 | version = "0.0.8" 4 | description = "" 5 | authors = ["Cristian Garcia "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.9,<3.12" 10 | jax = 
"*" 11 | jaxlib = "*" 12 | optax = "*" 13 | typing-extensions = "*" 14 | 15 | 16 | [tool.poetry.group.test.dependencies] 17 | pytest = ">=7.2.2" 18 | pytest-cov = ">=4.0.0" 19 | flax = ">=0.6.10" 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | black = { version = "23.3.0", extras = ["jupyter"] } 24 | isort = "5.12.0" 25 | ipykernel = "^6.22.0" 26 | pre-commit = ">=3.3.2" 27 | pyink = "23.3.0" 28 | 29 | [tool.poetry.group.examples.dependencies] 30 | matplotlib = "^3.7.1" 31 | datasets = "^2.12.0" 32 | 33 | [build-system] 34 | requires = ["poetry-core"] 35 | build-backend = "poetry.core.masonry.api" 36 | 37 | [tool.coverage.report] 38 | exclude_lines = [ 39 | "pragma: no cover", 40 | "@overload", 41 | "@tp.overload", 42 | "@abstractmethod", 43 | ] 44 | 45 | [tool.pyink] 46 | pyink-indentation = 2 47 | -------------------------------------------------------------------------------- /scripts/deploy-docs.sh: -------------------------------------------------------------------------------- 1 | cp README.md docs/index.md 2 | mkdocs gh-deploy --clean -------------------------------------------------------------------------------- /scripts/run-all-examples.bash: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | for f in $(find examples -name "*.py"); do 4 | echo -e "\n---------------------------------" 5 | echo "$f" 6 | echo "---------------------------------" 7 | poetry run time python "$f" 8 | done 9 | -------------------------------------------------------------------------------- /scripts/update_version.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | import typer 5 | 6 | 7 | # NOTE: this script could be written bash using sed, but I'm not sure if it's worth it 8 | def main(release_name: str): 9 | release_name = release_name.replace("-create-release", "") 10 | 11 | # Update pyproject.toml 12 | pyproject_path = Path("pyproject.toml") 13 | pyproject_text = pyproject_path.read_text() 14 | pyproject_text = re.sub( 15 | r'version = ".*"', 16 | f'version = "{release_name}"', 17 | pyproject_text, 18 | count=1, 19 | ) 20 | pyproject_path.write_text(pyproject_text) 21 | 22 | # Update __init__.py 23 | init_path = Path("nnx/__init__.py") 24 | init_text = init_path.read_text() 25 | init_text = re.sub( 26 | r'__version__ = "(.*?)"', 27 | f'__version__ = "{release_name}"', 28 | init_text, 29 | count=1, 30 | ) 31 | init_path.write_text(init_text) 32 | 33 | 34 | if __name__ == "__main__": 35 | typer.run(main) 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/3e4c750791ea39d26f5667ed26066fa6c13ff46b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_containers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import pytest 4 | 5 | import nnx 6 | 7 | 8 | class TestContainers: 9 | 10 | def test_node_idenpotence(self): 11 | x = nnx.Node(1) 12 | x = nnx.Node(x) 13 | 14 | assert isinstance(x, nnx.Node) 15 | 16 | def test_variable_idenpotence(self): 17 | x = nnx.Variable(1) 18 | x = nnx.Variable(x) 19 | 20 | assert isinstance(x, nnx.Variable) 21 | assert x.value == 1 22 | 23 | def test_variable_cannot_change_collection(self): 24 | x = nnx.Param(1) 25 | 26 | 
with pytest.raises(ValueError, match="is not compatible with return type"): 27 | x = nnx.BatchStat(x) 28 | 29 | def test_container_cannot_change_type(self): 30 | x = nnx.Variable(1) 31 | 32 | with pytest.raises(ValueError, match="is not compatible with return type"): 33 | x = nnx.Node(x) 34 | 35 | x = nnx.Node(2) 36 | 37 | with pytest.raises(ValueError, match="is not compatible with return type"): 38 | x = nnx.Variable(x) 39 | 40 | def test_static_is_empty(self): 41 | leaves = jax.tree_util.tree_leaves(nnx.Static(1)) 42 | 43 | assert len(leaves) == 0 44 | 45 | def test_static_empty_pytree(self): 46 | static = nnx.Static(2) 47 | 48 | static = jax.tree_map(lambda x: x + 1, static) 49 | 50 | assert static.value == 2 51 | 52 | def test_static_array_not_jitable(self): 53 | @jax.jit 54 | def f(x): 55 | return x 56 | 57 | # first time you don't get an error due to a bug in jax 58 | f(nnx.Static(np.random.uniform(size=(10, 10)))) 59 | 60 | with pytest.raises(ValueError): 61 | f(nnx.Static(np.random.uniform(size=(10, 10)))) 62 | -------------------------------------------------------------------------------- /tests/test_context.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jax 4 | import numpy as np 5 | import pytest 6 | 7 | import nnx 8 | from nnx.contextlib import _stable_hash 9 | 10 | 11 | class TestContext: 12 | 13 | def test_hash(self): 14 | _hash = _stable_hash("hi") 15 | assert isinstance(_hash, int) 16 | 17 | def test_rng_stream(self): 18 | key0 = jax.random.PRNGKey(0) 19 | ctx = nnx.context(key0) 20 | assert ctx._rngs["params"].count == 0 21 | 22 | key1 = ctx.make_rng("params") 23 | assert ctx._rngs["params"].count == 1 24 | assert ctx._rngs["params"].key is key0 25 | assert not np.equal(key0, key1).all() 26 | 27 | key2 = ctx.make_rng("params") 28 | assert ctx._rngs["params"].count == 2 29 | assert ctx._rngs["params"].key is key0 30 | assert not np.equal(key1, key2).all() 31 | 32 | def test_rng_fork(self): 33 | key0 = jax.random.PRNGKey(0) 34 | ctx1 = nnx.context(key0) 35 | ctx2 = ctx1.partition().merge() 36 | 37 | assert ctx2._rngs["params"].count == 0 38 | assert ctx2._rngs["params"].count_path == (0,) 39 | 40 | key1 = ctx1.make_rng("params") 41 | key2 = ctx2.make_rng("params") 42 | 43 | assert not np.equal(key1, key2).all() 44 | 45 | def test_rng_trace_level_constraints(self): 46 | ctx = nnx.context(0) 47 | 48 | @jax.jit 49 | def f(): 50 | with pytest.raises( 51 | nnx.TraceContextError, 52 | match="Cannot use Context from a different trace level", 53 | ): 54 | ctx.make_rng("params") 55 | 56 | f() 57 | 58 | @jax.jit 59 | def f(): 60 | with pytest.raises( 61 | nnx.TraceContextError, 62 | match="Cannot use Context from a different trace level", 63 | ): 64 | ctx.partition() 65 | 66 | f() 67 | 68 | ctx1: Any = None 69 | 70 | @jax.jit 71 | def g(): 72 | nonlocal ctx1 73 | ctx1 = nnx.context(1) 74 | 75 | g() 76 | 77 | assert isinstance(ctx1, nnx.Context) 78 | with pytest.raises( 79 | nnx.TraceContextError, 80 | match="Cannot use Context from a different trace level", 81 | ): 82 | ctx1.make_rng("params") 83 | 84 | def test_partition_merge(self): 85 | ctx = nnx.context(dropout=0) 86 | 87 | keys, ctxdef = ctx.partition() 88 | 89 | assert "dropout" in keys 90 | assert ctxdef._rng_counts == (("dropout", (0,)),) 91 | 92 | ctx2 = ctxdef.merge(keys) 93 | 94 | key1 = ctx.make_rng("dropout") 95 | key2 = ctx2.make_rng("dropout") 96 | assert not np.equal(key1, key2).all() 97 | 98 | ctx3 = ctxdef.merge(keys) 99 | key3 = 
ctx3.make_rng("dropout") 100 | assert np.equal(key2, key3).all() 101 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | 5 | import nnx 6 | 7 | 8 | class TestHelpers: 9 | 10 | def test_train_state(self): 11 | m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) 12 | 13 | (params, batch_stats), moduledef = m.partition(nnx.Param, nnx.BatchStat) 14 | 15 | state = nnx.TrainState( 16 | moduledef, 17 | params=params, 18 | tx=optax.sgd(1.0), 19 | batch_stats=batch_stats, 20 | other=nnx.Node(100), 21 | int=200, 22 | static=nnx.Static(300), 23 | ) 24 | 25 | leaves = jax.tree_util.tree_leaves(state) 26 | 27 | assert 1 in leaves 28 | assert 2 in leaves 29 | assert 100 in leaves 30 | assert 200 not in leaves 31 | assert 300 not in leaves 32 | 33 | def test_train_state_methods(self): 34 | class Foo(nnx.Module): 35 | 36 | def __init__(self, *, ctx: nnx.Context): 37 | self.linear = nnx.Linear(2, 4, ctx=ctx) 38 | self.batch_norm = nnx.BatchNorm(4, ctx=ctx) 39 | 40 | def __call__(self, x: jax.Array, train: bool) -> jax.Array: 41 | x = self.linear(x) 42 | x = self.batch_norm(x, use_running_average=not train) 43 | return x 44 | 45 | module = Foo(ctx=nnx.context(0)) 46 | (params, batch_stats), moduledef = module.partition(nnx.Param, nnx.BatchStat) 47 | 48 | state = nnx.TrainState( 49 | moduledef, 50 | params=params, 51 | tx=optax.sgd(1.0), 52 | batch_stats=batch_stats, 53 | ) 54 | 55 | x = jax.numpy.ones((1, 2)) 56 | y, _updates = state.apply("params", "batch_stats")(x, train=True) 57 | 58 | assert y.shape == (1, 4) 59 | 60 | # fake gradient 61 | grads = jax.tree_map(jnp.ones_like, state.params) 62 | # test apply_gradients 63 | state = state.apply_gradients(grads) 64 | -------------------------------------------------------------------------------- /tests/test_ids.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from nnx import ids 4 | 5 | 6 | class TestIds: 7 | 8 | def test_hashable(self): 9 | id1 = ids.uuid() 10 | id2 = ids.uuid() 11 | assert id1 == id1 12 | assert id1 != id2 13 | assert hash(id1) != hash(id2) 14 | id1c = copy.copy(id1) 15 | id1dc = copy.deepcopy(id1) 16 | assert hash(id1) != hash(id1c) 17 | assert hash(id1) != hash(id1dc) 18 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | import nnx 8 | 9 | A = tp.TypeVar("A") 10 | 11 | 12 | class TestIntegration: 13 | 14 | def test_shared_modules(self): 15 | class Block(nnx.Module): 16 | 17 | def __init__(self, linear: nnx.Linear, *, ctx): 18 | self.linear = linear 19 | self.bn = nnx.BatchNorm(2, ctx=ctx) 20 | 21 | def __call__(self, x, *, ctx): 22 | x = self.linear(x) 23 | x = self.bn(x, ctx=ctx) 24 | return nnx.relu(x) 25 | 26 | class Model(nnx.Module): 27 | 28 | def __init__(self, *, ctx): 29 | shared = nnx.Linear(2, 2, ctx=ctx) 30 | self.block1 = Block(shared, ctx=ctx) 31 | self.block2 = Block(shared, ctx=ctx) 32 | 33 | def __call__(self, x, *, ctx): 34 | x = self.block1(x, ctx=ctx) 35 | x = self.block2(x, ctx=ctx) 36 | return x 37 | 38 | @nnx.jit 39 | def train_step(model: Model, x, y): 40 | @nnx.grad 41 | def loss_fn(model: Model): 42 | ctx = 
nnx.context(flags=dict(use_running_average=False)) 43 | y_pred = model(x, ctx=ctx) 44 | return jnp.mean((y - y_pred) ** 2) 45 | 46 | grads = loss_fn(model) 47 | model.update_state( 48 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) 49 | ) 50 | 51 | model = Model(ctx=nnx.context(0)) 52 | 53 | x = np.random.uniform(size=(4, 2)) 54 | y = np.random.uniform(size=(4, 2)) 55 | 56 | for _i in range(3): 57 | train_step(model, x, y) 58 | 59 | assert model.block1.linear is model.block2.linear 60 | assert model.block1.linear.bias is not None 61 | assert model.block1.bn is not model.block2.bn 62 | 63 | def test_shared_modules_pure(self): 64 | class Block(nnx.Module): 65 | 66 | def __init__(self, linear: nnx.Linear, *, ctx: nnx.Context): 67 | self.linear = linear 68 | self.bn = nnx.BatchNorm(2, ctx=ctx) 69 | 70 | def __call__(self, x, *, ctx: nnx.Context): 71 | x = self.linear(x) 72 | x = self.bn(x, ctx=ctx) 73 | return nnx.relu(x) 74 | 75 | class Model(nnx.Module): 76 | 77 | def __init__(self, *, ctx: nnx.Context): 78 | shared = nnx.Linear(2, 2, ctx=ctx) 79 | self.block1 = Block(shared, ctx=ctx) 80 | self.block2 = Block(shared, ctx=ctx) 81 | 82 | def __call__(self, x, *, ctx: nnx.Context): 83 | x = self.block1(x, ctx=ctx) 84 | x = self.block2(x, ctx=ctx) 85 | return x 86 | 87 | @jax.jit 88 | def train_step(pure_module: nnx.PureModule[Model], x, y): 89 | model = pure_module.merge() 90 | 91 | @nnx.grad 92 | def loss_fn(model: Model): 93 | ctx = nnx.context(flags=dict(use_running_average=False)) 94 | y_pred = model(x, ctx=ctx) 95 | return jnp.mean((y - y_pred) ** 2) 96 | 97 | grads = loss_fn(model) 98 | model.update_state( 99 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) 100 | ) 101 | 102 | return model.partition() 103 | 104 | pure_module = Model(ctx=nnx.context(0)).partition() 105 | 106 | x = np.random.uniform(size=(4, 2)) 107 | y = np.random.uniform(size=(4, 2)) 108 | 109 | for _i in range(3): 110 | pure_module = train_step(pure_module, x, y) 111 | 112 | model = pure_module.merge() 113 | 114 | assert model.block1.linear.bias is not None 115 | assert model.block2.linear.bias is not None 116 | assert model.block1.linear.kernel is model.block2.linear.kernel 117 | assert model.block1.linear.bias is model.block2.linear.bias 118 | assert model.block1.bn is not model.block2.bn 119 | 120 | def test_stateful_example(self): 121 | class State(nnx.Variable[A]): 122 | pass 123 | 124 | class Linear(nnx.Module): 125 | 126 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 127 | key = ctx.make_rng("params") 128 | self.w = nnx.Param(jax.random.uniform(key, (din, dout))) 129 | self.b = nnx.Param(jnp.zeros((dout,))) 130 | self.count = State(0) 131 | 132 | def __call__(self, x): 133 | self.count += 1 134 | return x @ self.w + self.b 135 | 136 | model = Linear(din=12, dout=2, ctx=nnx.context(0)) 137 | # forward pass 138 | x = jnp.ones((8, 12)) 139 | y = model(x) 140 | assert model.count == 1 141 | 142 | @nnx.jit 143 | def train_step(model, x, y): 144 | def loss_fn(model): 145 | y_pred = model(x) 146 | return jax.numpy.mean((y_pred - y) ** 2) 147 | 148 | # compute gradient 149 | grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) 150 | # SGD update 151 | model.update_state( 152 | jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads) 153 | ) 154 | 155 | # execute the training step 156 | train_step(model, x, y) 157 | assert model.count == 2 158 | 159 | def test_functional_example(self): 160 | class Count(nnx.Variable[A]): 161 | pass 162 | 
163 | class Linear(nnx.Module): 164 | 165 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 166 | key = ctx.make_rng("params") 167 | self.w = nnx.Param(jax.random.uniform(key, (din, dout))) 168 | self.b = nnx.Param(jnp.zeros((dout,))) 169 | self.count = Count(0) 170 | 171 | def __call__(self, x): 172 | self.count += 1 173 | return x @ self.w + self.b 174 | 175 | model = Linear(din=12, dout=2, ctx=nnx.context(0)) 176 | # forward pass 177 | x = jnp.ones((8, 12)) 178 | y = model(x) 179 | assert model.count == 1 180 | 181 | (params, counts), moduledef = model.partition(nnx.Param, Count) 182 | 183 | @jax.jit 184 | def train_step(params, counts, x, y): 185 | def loss_fn(params): 186 | y_pred, (updates, _) = moduledef.apply(params, counts)(x) 187 | loss = jax.numpy.mean((y_pred - y) ** 2) 188 | return loss, updates.filter(Count) 189 | 190 | # compute gradient 191 | grads, counts = jax.grad(loss_fn, has_aux=True)(params) 192 | # SGD update 193 | params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) 194 | 195 | return params, counts 196 | 197 | # execute the training step 198 | params, counts = train_step(params, counts, x, y) 199 | model = moduledef.merge(params, counts) 200 | assert model.count == 2 201 | 202 | def test_intermediates_example(self): 203 | class Linear(nnx.Module): 204 | 205 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 206 | key = ctx.make_rng("params") 207 | self.w = nnx.Param(jax.random.uniform(key, (din, dout))) 208 | self.b = nnx.Param(jnp.zeros((dout,))) 209 | 210 | def __call__(self, x): 211 | y = x @ self.w + self.b 212 | self.y = nnx.Intermediate(y) 213 | return y 214 | 215 | model = Linear(12, 2, ctx=nnx.context(0)) 216 | 217 | y = model(jnp.ones((8, 12))) 218 | 219 | intermediates = model.pop_state(nnx.Intermediate) 220 | 221 | assert "y" in intermediates 222 | 223 | def test_intermediates_example_functional(self): 224 | class Linear(nnx.Module): 225 | 226 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 227 | key = ctx.make_rng("params") 228 | self.w = nnx.Param(jax.random.uniform(key, (din, dout))) 229 | self.b = nnx.Param(jnp.zeros((dout,))) 230 | 231 | def __call__(self, x): 232 | y = x @ self.w + self.b 233 | self.y = nnx.Intermediate(y) 234 | return y 235 | 236 | model = Linear(12, 2, ctx=nnx.context(0)) 237 | 238 | state, moduledef = model.partition() 239 | 240 | y, (state, _) = moduledef.apply(state)(jnp.ones((8, 12))) 241 | 242 | intermediates, state = state.partition(nnx.Intermediate, ...) 243 | 244 | assert "y" in intermediates 245 | -------------------------------------------------------------------------------- /tests/test_partitioning.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import numpy as np 5 | import pytest 6 | 7 | import nnx 8 | 9 | 10 | class TestPartitioning: 11 | 12 | def test_partition(self): 13 | m = nnx.Dict( 14 | a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(2)]), 15 | b=nnx.Param(2), 16 | c=100, 17 | ) 18 | 19 | (params, rest), moduledef = m.partition(nnx.Param, ...) 
20 | 21 | assert len(params) == 2 22 | assert len(rest) == 1 23 | 24 | # check params 25 | assert params["a/0"].value == m.a[0] 26 | assert params["b"].value == m.b 27 | 28 | # check rest 29 | assert rest["a/1"].value == m.a[1] 30 | 31 | m2 = moduledef.merge(params, rest) 32 | 33 | assert m2.a[0] == m.a[0] 34 | assert m2.a[1] == m.a[1] 35 | assert m2.b == m.b 36 | assert m2.c == 100 37 | 38 | def test_complete_partitioning(self): 39 | m = nnx.Dict( 40 | a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), 41 | b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), 42 | ) 43 | 44 | # no error 45 | m.partition(nnx.Param, nnx.BatchStat, nnx.Node) 46 | 47 | def test_complete_partitioning_plus_ellipsis(self): 48 | m = nnx.Dict( 49 | a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), 50 | b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), 51 | ) 52 | 53 | # no error if additional ... is passed at the end 54 | m.partition(nnx.Param, nnx.BatchStat, nnx.Node, ...) 55 | 56 | def test_inclomplete_partition_error(self): 57 | m = nnx.Dict( 58 | a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), 59 | b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), 60 | ) 61 | 62 | with pytest.raises( 63 | ValueError, match="Non-exhaustive filters, got a non-empty remainder" 64 | ): 65 | m.partition(nnx.Param) 66 | 67 | def test_ellipsis_not_last_error(self): 68 | m = nnx.Dict( 69 | a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Node(3)]), 70 | b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), 71 | ) 72 | 73 | with pytest.raises( 74 | ValueError, match="Ellipsis `...` can only be used as the last filter," 75 | ): 76 | m.partition(..., nnx.Param) 77 | 78 | def test_update_from(self): 79 | m = nnx.Dict( 80 | a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]), 81 | b=nnx.Param(2), 82 | c=100, 83 | ) 84 | 85 | state = m.partition()[0] 86 | state = jax.tree_map(lambda x: x * 2, state) 87 | 88 | m.update_state(state) 89 | 90 | assert m.a[0] == 2 91 | assert m.a[1] == 6 92 | assert m.b == 4 93 | assert m.c == 100 94 | 95 | def test_update_from_with_array_leaf(self): 96 | m = nnx.Dict( 97 | a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]), 98 | b=nnx.Param(2), 99 | c=jax.numpy.array(100), 100 | ) 101 | 102 | pure_module: nnx.Pure = m.partition() 103 | pure_module = jax.tree_map(lambda x: x * 2, pure_module) 104 | 105 | m.update_state(pure_module.states) 106 | 107 | assert m.a[0] == 2 108 | assert m.a[1] == 6 109 | assert m.b == 4 110 | assert m.c == 200 111 | 112 | def test_grad_example(self): 113 | m = nnx.Dict( 114 | a=nnx.Sequence([nnx.Param(1.0), nnx.BatchStat(-10)]), 115 | b=nnx.Param(2.0), 116 | c=100, 117 | ) 118 | 119 | params = m.filter(nnx.Param) 120 | 121 | def loss(params): 122 | return sum(2 * p for p in jax.tree_util.tree_leaves(params)) 123 | 124 | grads = jax.grad(loss)(params) 125 | m.update_state(grads) 126 | 127 | assert m.a[0] == 2.0 128 | assert m.a[1] == -10 129 | assert m.b == 2.0 130 | assert m.c == 100 131 | 132 | def test_get_paritition(self): 133 | m = nnx.Dict( 134 | a=nnx.Sequence([nnx.Param(10.0), nnx.Param(20.0)]), 135 | b=nnx.Param(10.0), 136 | c=7, 137 | d=5.0, 138 | ) 139 | 140 | # test Variables not shared 141 | assert vars(m.a)["0"] is not vars(m)["b"] 142 | 143 | state = m.filter(nnx.Node) 144 | assert state["a/0"].value == m.a[0] 145 | assert state["a/1"].value == m.a[1] 146 | assert state["b"].value == m.b 147 | assert state["b"] is not state["a/0"] 148 | assert len(state) == 3 149 | -------------------------------------------------------------------------------- 
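Note (illustrative sketch, not part of the repository): the partitioning tests above exercise the State API defined in nnx/state.py. The snippet below condenses the same partition/filter/merge round trip shown in test_partition and test_grad_example, using the container types (nnx.Dict, nnx.Sequence, nnx.Param, nnx.BatchStat) those tests rely on; the concrete values and keys are illustrative.

import nnx

# a small module mixing Param and BatchStat leaves plus a static field
m = nnx.Dict(
    a=nnx.Sequence([nnx.Param(1.0), nnx.BatchStat(2.0)]),
    b=nnx.Param(3.0),
    c=100,  # static: never appears in any State
)

# split into (params, rest); the trailing `...` catches everything not matched earlier
(params, rest), moduledef = m.partition(nnx.Param, ...)
assert set(params.keys()) == {"a/0", "b"}
assert set(rest.keys()) == {"a/1"}

# filter behaves like partition but silently drops the remainder
params_only = m.filter(nnx.Param)
assert set(params_only.keys()) == {"a/0", "b"}

# merge the pieces back into an equivalent module
m2 = moduledef.merge(params, rest)
assert m2.b == m.b and m2.c == 100

# States also compose directly; `|` delegates to State.merge
assert len(params | rest) == 3

As _split_state in nnx/state.py enforces, partition raises on a non-exhaustive split (hence the trailing `...`), while filter keeps only what matches and discards the rest.
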
/tests/test_pytree.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | 3 | import jax 4 | import pytest 5 | from flax import serialization 6 | 7 | import nnx 8 | 9 | 10 | class TestPytree: 11 | 12 | def test_immutable_pytree(self): 13 | class Foo(nnx.Pytree): 14 | 15 | def __init__(self, y) -> None: 16 | self.x = 2 17 | self.y = nnx.Node(y) 18 | 19 | pytree = Foo(y=3) 20 | 21 | leaves = jax.tree_util.tree_leaves(pytree) 22 | assert leaves == [3] 23 | 24 | pytree = jax.tree_map(lambda x: x * 2, pytree) 25 | assert pytree.x == 2 26 | assert pytree.y == 6 27 | 28 | pytree = pytree.replace(x=3) 29 | assert pytree.x == 3 30 | assert pytree.y == 6 31 | 32 | with pytest.raises(AttributeError, match="is immutable, trying to update field"): 33 | pytree.x = 4 34 | 35 | def test_immutable_pytree_dataclass(self): 36 | @nnx.dataclass(frozen=True) 37 | class Foo(nnx.Pytree): 38 | y: int = nnx.node_field() 39 | x: int = nnx.field(default=2) 40 | 41 | pytree = Foo(y=3) 42 | 43 | leaves = jax.tree_util.tree_leaves(pytree) 44 | assert leaves == [3] 45 | 46 | pytree = jax.tree_map(lambda x: x * 2, pytree) 47 | assert pytree.x == 2 48 | assert pytree.y == 6 49 | 50 | pytree = pytree.replace(x=3) 51 | assert pytree.x == 3 52 | assert pytree.y == 6 53 | 54 | with pytest.raises(AttributeError, match="cannot assign to field"): 55 | pytree.x = 4 56 | 57 | def test_jit(self): 58 | @nnx.dataclass 59 | class Foo(nnx.Pytree): 60 | a: int = nnx.node_field() 61 | b: int = nnx.field() 62 | 63 | module = Foo(a=1, b=2) 64 | 65 | @jax.jit 66 | def f(m: Foo): 67 | return m.a + m.b 68 | 69 | assert f(module) == 3 70 | 71 | def test_flax_serialization(self): 72 | class Bar(nnx.Pytree): 73 | 74 | def __init__(self, a, b): 75 | self.a = a 76 | self.b = nnx.Node(b) 77 | 78 | @nnx.dataclass 79 | class Foo(nnx.Pytree): 80 | bar: Bar 81 | c: int = nnx.node_field() 82 | d: int = nnx.field() 83 | 84 | foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) 85 | 86 | state_dict = serialization.to_state_dict(foo) 87 | 88 | assert state_dict == { 89 | "bar": { 90 | "b": 2, 91 | }, 92 | "c": 3, 93 | } 94 | 95 | state_dict["bar"]["b"] = 5 96 | 97 | foo = serialization.from_state_dict(foo, state_dict) 98 | 99 | assert foo.bar.b == 5 100 | 101 | del state_dict["bar"]["b"] 102 | 103 | with pytest.raises(ValueError, match="Missing field"): 104 | serialization.from_state_dict(foo, state_dict) 105 | 106 | state_dict["bar"]["b"] = 5 107 | 108 | # add unknown field 109 | state_dict["x"] = 6 110 | 111 | with pytest.raises(ValueError, match="Unknown field"): 112 | serialization.from_state_dict(foo, state_dict) 113 | 114 | def test_generics(self): 115 | T = TypeVar("T") 116 | 117 | class MyClass(nnx.Pytree, Generic[T]): 118 | 119 | def __init__(self, x: T): 120 | self.x = x 121 | 122 | MyClass[int] 123 | 124 | def test_key_paths(self): 125 | @nnx.dataclass 126 | class Bar(nnx.Pytree): 127 | a: int = nnx.node_field(default=1) 128 | b: int = nnx.field(default=2) 129 | 130 | @nnx.dataclass 131 | class Foo(nnx.Pytree): 132 | x: int = nnx.node_field(default=3) 133 | y: int = nnx.field(default=4) 134 | z: Bar = nnx.node_field(default_factory=Bar) 135 | 136 | foo = Foo() 137 | 138 | path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) 139 | path_values = [(list(map(str, path)), value) for path, value in path_values] 140 | 141 | assert path_values[0] == ([".x", ".value"], 3) 142 | assert path_values[1] == ([".z", ".value", ".a", ".value"], 1) 143 | 144 | def 
test_replace_unknown_fields_error(self): 145 | class Foo(nnx.Pytree): 146 | pass 147 | 148 | with pytest.raises(ValueError, match="Trying to replace unknown fields"): 149 | Foo().replace(y=1) 150 | 151 | def test_dataclass_inheritance(self): 152 | @nnx.dataclass 153 | class A(nnx.Pytree): 154 | a: int = nnx.node_field(default=1) 155 | b: int = nnx.field(default=2) 156 | 157 | @nnx.dataclass 158 | class B(A): 159 | c: int = nnx.node_field(default=3) 160 | 161 | pytree = B() 162 | leaves = jax.tree_util.tree_leaves(pytree) 163 | assert leaves == [1, 3] 164 | 165 | def test_pytree_with_new(self): 166 | class A(nnx.Pytree): 167 | 168 | def __init__(self, a): 169 | self.a = a 170 | 171 | def __new__(cls, a): 172 | return super().__new__(cls) 173 | 174 | pytree = A(a=1) 175 | 176 | pytree = jax.tree_map(lambda x: x * 2, pytree) 177 | 178 | def test_deterministic_order(self): 179 | class A(nnx.Pytree): 180 | 181 | def __init__(self, order: bool): 182 | if order: 183 | self.a = 1 184 | self.b = 2 185 | else: 186 | self.b = 2 187 | self.a = 1 188 | 189 | p1 = A(order=True) 190 | p2 = A(order=False) 191 | 192 | leaves1 = jax.tree_util.tree_leaves(p1) 193 | leaves2 = jax.tree_util.tree_leaves(p2) 194 | 195 | assert leaves1 == leaves2 196 | 197 | 198 | class TestMutablePytree: 199 | 200 | def test_pytree(self): 201 | class Foo(nnx.Pytree, mutable=True): 202 | 203 | def __init__(self, y) -> None: 204 | self.x = 2 205 | self.y = nnx.Node(y) 206 | 207 | pytree = Foo(y=3) 208 | 209 | leaves = jax.tree_util.tree_leaves(pytree) 210 | assert leaves == [3] 211 | 212 | pytree = jax.tree_map(lambda x: x * 2, pytree) 213 | assert pytree.x == 2 214 | assert pytree.y == 6 215 | 216 | pytree = pytree.replace(x=3) 217 | assert pytree.x == 3 218 | assert pytree.y == 6 219 | 220 | # test mutation 221 | pytree.x = 4 222 | assert pytree.x == 4 223 | 224 | def test_no_new_fields_after_init(self): 225 | class Foo(nnx.Pytree, mutable=True): 226 | 227 | def __init__(self, x): 228 | self.x = nnx.Node(x) 229 | 230 | foo = Foo(x=1) 231 | foo.x = 2 232 | 233 | with pytest.raises(AttributeError, match=r"Cannot add new fields to"): 234 | foo.y = 2 235 | 236 | def test_pytree_dataclass(self): 237 | @nnx.dataclass 238 | class Foo(nnx.Pytree, mutable=True): 239 | y: int = nnx.node_field() 240 | x: int = nnx.field(default=2) 241 | 242 | pytree: Foo = Foo(y=3) 243 | 244 | leaves = jax.tree_util.tree_leaves(pytree) 245 | assert leaves == [3] 246 | 247 | pytree = jax.tree_map(lambda x: x * 2, pytree) 248 | assert pytree.x == 2 249 | assert pytree.y == 6 250 | 251 | pytree = pytree.replace(x=3) 252 | assert pytree.x == 3 253 | assert pytree.y == 6 254 | 255 | # test mutation 256 | pytree.x = 4 257 | assert pytree.x == 4 258 | -------------------------------------------------------------------------------- /tests/test_spmd.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | from jax._src import test_util as jtu 5 | from jax.experimental import mesh_utils 6 | from jax.sharding import Mesh, PartitionSpec 7 | 8 | import nnx 9 | 10 | 11 | class TestSPMD: 12 | 13 | @jtu.skip_on_devices("cpu", "gpu") 14 | def test_init(self): 15 | class Foo(nnx.Module): 16 | 17 | def __init__(self): 18 | self.w = nnx.Param( 19 | nnx.with_logical_partitioning( 20 | lambda: jnp.ones((8, 2)), 21 | sharding=("row", "col"), 22 | )() 23 | ) 24 | 25 | def __call__(self, x): 26 | return x @ self.w 27 | 28 | @jax.jit 29 | def create_module(): 30 | return Foo().partition() 31 
| 32 | mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ("model", "data")) 33 | 34 | with mesh, nnx.logical_axis_rules([("row", "model"), ("col", "data")]): 35 | m: Foo = create_module().merge() 36 | 37 | assert m.w.shape == (8, 2) 38 | assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) 39 | 40 | def test_get_partition_spec(self): 41 | class Foo(nnx.Module): 42 | 43 | def __init__(self): 44 | self.w = nnx.Param( 45 | nnx.with_logical_partitioning( 46 | lambda: jnp.ones((8, 2)), 47 | sharding=("row", "col"), 48 | )() 49 | ) 50 | 51 | def __call__(self, x): 52 | return x @ self.w 53 | 54 | params, moduledef = Foo().partition() 55 | state = nnx.TrainState( 56 | moduledef, 57 | params=params, 58 | tx=optax.adam(1e-3), 59 | ) 60 | logical_state_spec = nnx.get_partition_spec(state) 61 | 62 | assert logical_state_spec.params["w"] == PartitionSpec("row", "col") 63 | assert logical_state_spec.opt_state[0].mu["w"] == PartitionSpec("row", "col") 64 | assert logical_state_spec.opt_state[0].nu["w"] == PartitionSpec("row", "col") 65 | 66 | with nnx.logical_axis_rules([("row", "model"), ("col", "data")]): 67 | state_spec = nnx.logical_to_mesh(logical_state_spec) 68 | 69 | assert state_spec.params["w"] == PartitionSpec("model", "data") 70 | assert state_spec.opt_state[0].mu["w"] == PartitionSpec("model", "data") 71 | assert state_spec.opt_state[0].nu["w"] == PartitionSpec("model", "data") 72 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import pytest 7 | 8 | import nnx 9 | 10 | 11 | class TestJIT: 12 | 13 | def test_jit(self): 14 | m = nnx.Dict(a=nnx.Param(1)) 15 | 16 | @nnx.jit 17 | def g(m: nnx.Dict): 18 | m.a = 2 19 | return 1.0 20 | 21 | out = g(m) 22 | 23 | assert m.a == 2 24 | assert out == 1.0 25 | 26 | def test_jit_stateless(self): 27 | m = nnx.Dict(a=nnx.Param(1)) 28 | 29 | @partial(nnx.jit, stateful=False) 30 | def g(m: nnx.Dict): 31 | m.a = 2 32 | return 1.0 33 | 34 | out = g(m) 35 | 36 | assert m.a == 1 37 | assert out == 1.0 38 | 39 | 40 | class TestGrad: 41 | 42 | def test_grad(self): 43 | p1 = nnx.Param(10.0) 44 | p2 = nnx.Param(20.0) 45 | 46 | m = nnx.Dict( 47 | a=nnx.Sequence([p1, p2]), 48 | b=p1, 49 | c=7, 50 | d=5.0, 51 | ) 52 | 53 | @nnx.grad 54 | def f(m: nnx.Dict): 55 | # sum all params 56 | return m["a"][0] + m["a"][1] + m["b"] 57 | 58 | grads = f(m) 59 | 60 | assert isinstance(grads, nnx.State) 61 | assert grads["a/0"].value == 1.0 62 | assert isinstance(grads["a/0"], nnx.Node) 63 | assert grads["a/1"].value == 1.0 64 | assert isinstance(grads["a/1"], nnx.Node) 65 | assert grads["b"].value == 1.0 66 | assert isinstance(grads["b"], nnx.Node) 67 | assert len(grads) == 3 68 | 69 | m.update_state(grads) 70 | 71 | assert m["a"][0] == 1.0 72 | assert m["a"][1] == 1.0 73 | assert m["b"] == 1.0 74 | assert m["c"] == 7 75 | assert m["d"] == 5.0 76 | 77 | def test_grad_with_multiple_ref_types(self): 78 | m = nnx.Dict( 79 | a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), 80 | b=nnx.Param(10.0), 81 | c=7, 82 | d=5.0, 83 | ) 84 | 85 | @nnx.grad 86 | def f(m: nnx.Dict): 87 | # sum all params 88 | return m.a[0] + m.a[1] + m.b 89 | 90 | grads = f(m) 91 | 92 | assert isinstance(grads, nnx.State) 93 | assert grads["a/0"].value == 1.0 94 | assert isinstance(grads["a/0"], nnx.Param) 95 | assert len(grads) == 2 96 | 97 | 
m.update_state(grads) 98 | 99 | assert m.a[0] == 1.0 100 | assert m.a[1] == 20.0 101 | assert m.b == 1.0 102 | assert m.c == 7 103 | assert m.d == 5.0 104 | 105 | def test_grad_with_type_predicate(self): 106 | m = nnx.Dict( 107 | a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), 108 | b=nnx.Param(10.0), 109 | c=7, 110 | d=5.0, 111 | ) 112 | 113 | @partial(nnx.grad, wrt=nnx.BatchStat) 114 | def f(m: nnx.Dict): 115 | # sum all params 116 | return m.a[0] + m.a[1] + m.b 117 | 118 | grads = f(m) 119 | 120 | assert isinstance(grads, nnx.State) 121 | assert grads["a/1"].value == 1.0 122 | assert isinstance(grads["a/1"], nnx.BatchStat) 123 | assert len(grads) == 1 124 | 125 | m.update_state(grads) 126 | 127 | assert m.a[0] == 10.0 128 | assert m.a[1] == 1.0 129 | assert m.b == 10.0 130 | assert m.c == 7 131 | assert m.d == 5.0 132 | 133 | 134 | class TestScan: 135 | 136 | def test_basic(self): 137 | class Block(nnx.Module): 138 | 139 | def __init__(self, *, ctx: nnx.Context): 140 | self.linear = nnx.Linear(3, 3, ctx=ctx) 141 | self.node = jnp.ones((2,)) 142 | 143 | def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: 144 | jax.debug.print("x={x}", x=x) 145 | x = self.linear(x) 146 | x = nnx.gelu(x) 147 | return x, None 148 | 149 | MLP = nnx.Scan(Block, variable_axes={nnx.Param: 0}, split_rngs="params", length=5) 150 | 151 | module = MLP(ctx=nnx.context(0)) 152 | 153 | assert module.scan_module.linear.kernel.shape == (5, 3, 3) 154 | assert module.scan_module.linear.bias.shape == (5, 3) 155 | assert module.scan_module.node.shape == (2,) 156 | 157 | x = jnp.ones((1, 3)) 158 | y, out = module.call(x, None) 159 | 160 | assert y.shape == (1, 3) 161 | assert out is None 162 | 163 | def test_complex(self): 164 | class Block(nnx.Module): 165 | 166 | def __init__(self, *, ctx: nnx.Context): 167 | self.linear = nnx.Linear(3, 3, ctx=ctx) 168 | self.bn = nnx.BatchNorm(3, ctx=ctx) 169 | self.dropout = nnx.Dropout(0.5) 170 | self.node = jnp.ones((2,)) 171 | 172 | def __call__( 173 | self, x: jax.Array, _, *, ctx: nnx.Context 174 | ) -> tp.Tuple[jax.Array, None]: 175 | jax.debug.print("x={x}", x=x) 176 | x = self.linear(x) 177 | x = self.bn(x, ctx=ctx) 178 | x = self.dropout(x, ctx=ctx) 179 | x = nnx.gelu(x) 180 | return x, None 181 | 182 | MLP = nnx.Scan( 183 | Block, 184 | variable_axes={nnx.Param: 0}, 185 | # variable_carry="batch_stats", 186 | split_rngs=["params", "dropout"], 187 | length=5, 188 | ) 189 | 190 | module = MLP(ctx=nnx.context(0)) 191 | 192 | assert module.scan_module.linear.kernel.shape == (5, 3, 3) 193 | assert module.scan_module.linear.bias.shape == (5, 3) 194 | assert module.scan_module.node.shape == (2,) 195 | 196 | x = jnp.ones((1, 3)) 197 | ctx = nnx.context( 198 | dropout=1, flags=dict(deterministic=False, use_running_average=False) 199 | ) 200 | y, out = module.call(x, None, ctx=ctx) 201 | 202 | assert y.shape == (1, 3) 203 | assert out is None 204 | 205 | def test_complex_decorator(self): 206 | scan_over_layers = partial( 207 | nnx.scan, 208 | variable_axes={nnx.Param: 0}, 209 | split_rngs=["params", "dropout"], 210 | length=5, 211 | ) 212 | 213 | class Block(nnx.Module): 214 | 215 | @scan_over_layers 216 | def __init__(self, *, ctx: nnx.Context): 217 | self.linear = nnx.Linear(3, 3, ctx=ctx) 218 | self.bn = nnx.BatchNorm(3, ctx=ctx) 219 | self.dropout = nnx.Dropout(0.5) 220 | self.node = jnp.ones((2,)) 221 | 222 | @scan_over_layers 223 | def __call__( 224 | self, x: jax.Array, _, *, ctx: nnx.Context 225 | ) -> tp.Tuple[jax.Array, None]: 226 | 
jax.debug.print("x={x}", x=x) 227 | x = self.linear(x) 228 | x = self.bn(x, ctx=ctx) 229 | x = self.dropout(x, ctx=ctx) 230 | x = nnx.gelu(x) 231 | return x, None 232 | 233 | module = Block(ctx=nnx.context(0)) 234 | 235 | assert module.linear.kernel.shape == (5, 3, 3) 236 | assert module.linear.bias.shape == (5, 3) 237 | assert module.node.shape == (2,) 238 | 239 | x = jnp.ones((1, 3)) 240 | ctx = nnx.context( 241 | dropout=1, flags=dict(deterministic=False, use_running_average=False) 242 | ) 243 | y, out = module(x, None, ctx=ctx) 244 | 245 | assert y.shape == (1, 3) 246 | assert out is None 247 | 248 | def test_scan_with_sharding(self): 249 | class Block(nnx.Module): 250 | 251 | def __init__(self, *, ctx: nnx.Context): 252 | self.linear = nnx.Linear( 253 | 3, 254 | 3, 255 | kernel_init=nnx.with_metadata( 256 | nnx.initializers.lecun_normal(), 257 | sharding=("din", "dout"), 258 | ), 259 | bias_init=nnx.with_metadata( 260 | nnx.initializers.zeros(), 261 | sharding=("dout",), 262 | ), 263 | ctx=ctx, 264 | ) 265 | 266 | def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: 267 | x = self.linear(x) 268 | 269 | # test sharding layer axes is not present inside scan 270 | state = self.linear.get_state() 271 | assert state["kernel"].value.shape == (3, 3) 272 | assert state["kernel"].sharding == ("din", "dout") 273 | assert state["bias"].value.shape == (3,) 274 | assert state["bias"].sharding == ("dout",) 275 | 276 | return x, None 277 | 278 | MLP = nnx.Scan( 279 | Block, 280 | variable_axes={nnx.Param: 0}, 281 | split_rngs=["params"], 282 | length=5, 283 | metadata_params={nnx.PARTITION_NAME: "layers"}, 284 | ) 285 | 286 | m = MLP(ctx=nnx.context(0)) 287 | 288 | # test sharding layers axes is set 289 | state = m.get_state() 290 | assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) 291 | assert state["scan_module/linear/kernel"].sharding == ("layers", "din", "dout") 292 | assert state["scan_module/linear/bias"].value.shape == (5, 3) 293 | assert state["scan_module/linear/bias"].sharding == ("layers", "dout") 294 | 295 | x = jnp.ones((1, 3)) 296 | y, out = m.call(x, None) 297 | 298 | # test sharding axes is preserved 299 | state = m.get_state() 300 | assert state["scan_module/linear/kernel"].value.shape == (5, 3, 3) 301 | assert state["scan_module/linear/kernel"].sharding == ("layers", "din", "dout") 302 | assert state["scan_module/linear/bias"].value.shape == (5, 3) 303 | assert state["scan_module/linear/bias"].sharding == ("layers", "dout") 304 | 305 | 306 | class TestRemat: 307 | 308 | def test_basic_remat(self): 309 | RematLinear = nnx.Remat(nnx.Linear) 310 | 311 | module = RematLinear(2, 3, ctx=nnx.context(0)) 312 | 313 | y = module.call(jnp.ones((1, 2))) 314 | 315 | assert y.shape == (1, 3) 316 | 317 | def test_remat_decorator(self): 318 | class RematLinear(nnx.Module): 319 | 320 | @nnx.remat 321 | def __init__(self, din: int, dout: int, *, ctx: nnx.Context): 322 | self.linear = nnx.Linear(din, dout, ctx=ctx) 323 | 324 | @nnx.remat 325 | def __call__(self, x: jax.Array) -> jax.Array: 326 | return self.linear(x) 327 | 328 | module = RematLinear(2, 3, ctx=nnx.context(0)) 329 | 330 | y = module(jnp.ones((1, 2))) 331 | 332 | assert y.shape == (1, 3) 333 | 334 | def test_remat_with_scan(self): 335 | class LinearBlock(nnx.Module): 336 | 337 | def __init__(self, *, ctx: nnx.Context): 338 | self.linear = nnx.Linear(3, 3, ctx=ctx) 339 | 340 | def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: 341 | x = self.linear(x) 342 | return x, None 343 | 344 | 
RematLinear = nnx.Remat(LinearBlock) 345 | 346 | ScanRematLinear = nnx.Scan( 347 | RematLinear, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 348 | ) 349 | 350 | m = ScanRematLinear(ctx=nnx.context(0)) 351 | 352 | assert m.scan_module.remat_module.linear.kernel.shape == (5, 3, 3) 353 | assert m.scan_module.remat_module.linear.bias.shape == (5, 3) 354 | 355 | y, _ = m.call.call(jnp.ones((1, 3)), None) 356 | assert y.shape == (1, 3) 357 | 358 | y, _ = m(jnp.ones((1, 3)), None) 359 | assert y.shape == (1, 3) 360 | 361 | def test_remat_with_scan_decorator(self): 362 | scan = partial( 363 | nnx.scan, variable_axes={nnx.Param: 0}, split_rngs="params", length=5 364 | ) 365 | 366 | class ScanLinear(nnx.Module): 367 | 368 | @scan 369 | def __init__(self, *, ctx: nnx.Context): 370 | self.linear = nnx.Linear(3, 3, ctx=ctx) 371 | 372 | @scan 373 | @nnx.remat 374 | def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: 375 | x = self.linear(x) 376 | return x, None 377 | 378 | m = ScanLinear(ctx=nnx.context(0)) 379 | 380 | assert m.linear.kernel.shape == (5, 3, 3) 381 | assert m.linear.bias.shape == (5, 3) 382 | 383 | y, _ = m(jnp.ones((1, 3)), None) 384 | assert y.shape == (1, 3) 385 | -------------------------------------------------------------------------------- /tests/test_variable.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import jax 4 | import pytest 5 | 6 | import nnx 7 | 8 | A = tp.TypeVar("A") 9 | 10 | 11 | class TestVariable: 12 | 13 | def test_value(self): 14 | r1 = nnx.Node(1) 15 | assert r1.value == 1 16 | 17 | r2 = jax.tree_map(lambda x: x + 1, r1) 18 | 19 | assert r1.value == 1 20 | assert r2.value == 2 21 | assert r1 is not r2 22 | --------------------------------------------------------------------------------
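Note (illustrative sketch, not part of the repository): the transform and integration tests above repeatedly use the same stateful training pattern built from nnx.jit, nnx.grad, Module.filter and Module.update_state. The sketch below condenses that pattern with an nnx.Linear layer; the shapes, the 0.1 learning rate and the zero targets are illustrative, and the API usage mirrors tests/test_integration.py (test_shared_modules, test_stateful_example).

import jax
import jax.numpy as jnp
import nnx

# initialize a layer; nnx.context(0) seeds the "params" rng stream
model = nnx.Linear(2, 4, ctx=nnx.context(0))
x = jnp.ones((8, 2))
y = jnp.zeros((8, 4))

@nnx.jit  # stateful by default: in-place updates to `model` are propagated back out
def train_step(model, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return jnp.mean((y_pred - y) ** 2)

    # nnx.grad differentiates with respect to the module's Param leaves and returns a State
    grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
    # SGD step applied in place
    model.update_state(
        jax.tree_map(lambda w, g: w - 0.1 * g, model.filter(nnx.Param), grads)
    )

train_step(model, x, y)

The purely functional variant (as in test_functional_example and test_shared_modules_pure) instead round-trips through model.partition(), moduledef.apply(...) and merge() under plain jax.jit, leaving the original module untouched.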