├── .github └── workflows │ ├── ci.yml │ ├── docs.yml │ └── pypi.yml ├── .gitignore ├── .pylintrc ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── api │ ├── basics.rst │ ├── experimental.rst │ ├── modules.rst │ ├── transformations.rst │ └── utilities.rst ├── conf.py ├── index.rst ├── notebooks │ ├── basics.ipynb │ ├── jax_transformations.ipynb │ ├── operators.ipynb │ ├── performance.ipynb │ ├── training.ipynb │ └── understanding.ipynb └── requirements.txt ├── examples ├── basics.py ├── char_rnn.py ├── dataclass_module.py ├── denoising_diffusion │ ├── README.md │ ├── data_loader.py │ ├── model.py │ ├── requirements.txt │ └── train.py ├── graph_module.py ├── lazy_module.py ├── mnist.py ├── mnist_mixed_precision.py ├── notebooks │ ├── DCGAN.ipynb │ ├── VAE.ipynb │ ├── adversarial_examples.ipynb │ ├── fine_tuning_resnet18.ipynb │ ├── mixed_precision.ipynb │ ├── pretrained_resnet18.py │ └── test_pretrained_resnet18.ipynb ├── transformer │ ├── data.py │ ├── model.py │ └── train.py └── wave_gru │ ├── README.md │ ├── data_loader.py │ ├── model.py │ ├── prepare_data.sh │ ├── requirements.txt │ └── train.py ├── images └── pax_logo.png ├── pax ├── __init__.py ├── _src │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── base.py │ │ ├── graph_module.py │ │ ├── mixed_precision.py │ │ ├── module.py │ │ ├── module_and_value.py │ │ ├── mutable.py │ │ ├── pure.py │ │ ├── rng.py │ │ ├── safe_module.py │ │ ├── threading_local.py │ │ ├── transforms.py │ │ ├── utility_modules.py │ │ └── utils.py │ ├── nets │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── transformer.py │ ├── nn │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── batch_norm.py │ │ ├── conv.py │ │ ├── dropout.py │ │ ├── ema.py │ │ ├── embed.py │ │ ├── group_norm.py │ │ ├── identity.py │ │ ├── lambda_module.py │ │ ├── layer_norm.py │ │ ├── linear.py │ │ ├── pool.py │ │ ├── recurrent.py │ │ ├── rng_seq.py │ │ └── sequential.py │ └── utils.py ├── experimental │ ├── __init__.py │ └── graph.py ├── nets.py ├── py.typed └── utils.py ├── setup.py └── tests ├── test_auto_modules.py ├── test_counter.py ├── test_deepscan.py ├── test_finetune.py ├── test_freeze_unfreeze.py ├── test_graph_module.py ├── test_immutability.py ├── test_jax_transform.py ├── test_mixed_precision.py ├── test_multithread.py ├── test_nets.py ├── test_nn.py ├── test_optim.py ├── test_pax.py ├── test_performance.py ├── test_pure.py ├── test_summary.py ├── test_training.py ├── test_transforms.py └── test_utils.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test-ubuntu: 13 | name: "Test on ${{ matrix.python-version }} on ${{ matrix.os }}" 14 | runs-on: "${{ matrix.os }}" 15 | strategy: 16 | matrix: 17 | python-version: [3.7, 3.8, 3.9] 18 | os: [ubuntu-latest] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install .[test] 29 | - name: Test with pytest 30 | run: | 31 | pip install pytest pytest-xdist 32 | pytest -n auto -k "not perf" tests 33 | # pytest -n 1 -k "perf" tests 34 | - name: Test with pytype 35 | run: | 36 | pip install pytype 37 | pytype pax tests 38 | 
-------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # Source: https://raw.githubusercontent.com/deepmind/dm-haiku/0a28e731938ef932ed6c33555fb1051bea0b29bd/.github/workflows/docs.yml 2 | # Apache-2.0 License 3 | 4 | name: docs 5 | 6 | on: 7 | pull_request: 8 | branches: 9 | - main 10 | push: 11 | branches: 12 | - main 13 | 14 | jobs: 15 | test-ubuntu: 16 | name: "docs on ${{ matrix.python-version }} on ${{ matrix.os }}" 17 | runs-on: "${{ matrix.os }}" 18 | strategy: 19 | matrix: 20 | python-version: [3.7, 3.8, 3.9] 21 | os: [ubuntu-latest] 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | sudo apt install -y pandoc 31 | python -m pip install --upgrade pip 32 | pip install .[test] 33 | pip install -r docs/requirements.txt 34 | - name: Test doctests 35 | run: | 36 | cd docs 37 | make doctest 38 | - name: Test docs to HTML 39 | run: | 40 | cd docs 41 | make html -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | sphinx: 7 | builder: html 8 | configuration: docs/conf.py 9 | fail_on_warning: false 10 | 11 | python: 12 | version: 3.7 13 | install: 14 | - requirements: docs/requirements.txt 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Thông Nguyễn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | ![logo](images/pax_logo.png) 3 |
4 | 5 | [**Introduction**](#introduction) 6 | | [**Getting started**](#gettingstarted) 7 | | [**Functional programming**](#functional) 8 | | [**Examples**](https://github.com/ntt123/pax/tree/main/examples/) 9 | | [**Modules**](#modules) 10 | | [**Fine-tuning**](#finetune) 11 | 12 | ![pytest](https://github.com/ntt123/pax/workflows/pytest/badge.svg) 13 | ![docs](https://readthedocs.org/projects/pax/badge/?version=main) 14 | ![pypi](https://img.shields.io/pypi/v/pax3) 15 | 16 | 17 | ## Introduction 18 | 19 | PAX is a [JAX]-based library for training neural networks. 20 | 21 | PAX modules are registered as JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), therefore, they can be input or output of JAX transformations such as `jax.jit`, `jax.grad`, etc. This makes programming with modules very convenient and easy to understand. 22 | 23 | ## Installation 24 | 25 | Install from PyPI: 26 | 27 | ```bash 28 | pip install pax3 29 | ``` 30 | 31 | Or install the latest version from Github: 32 | 33 | ```bash 34 | pip install git+https://github.com/ntt123/pax.git 35 | 36 | ## or test mode to run tests and examples 37 | pip install git+https://github.com/ntt123/pax.git#egg=pax3[test] 38 | ``` 39 | 40 | 41 | ## Getting started 42 | 43 | 44 | Below is a simple example of a `Linear` module. 45 | 46 | ```python 47 | import jax.numpy as jnp 48 | import pax 49 | 50 | class Linear(pax.Module): 51 | weight: jnp.ndarray 52 | bias: jnp.ndarray 53 | parameters = pax.parameters_method("weight", "bias") 54 | 55 | def __init__(self): 56 | super().__init__() 57 | self.weight = jnp.array(0.0) 58 | self.bias = jnp.array(0.0) 59 | 60 | def __call__(self, x): 61 | return self.weight * x + self.bias 62 | ``` 63 | 64 | The implementation is very similar to a normal python class. However, we need an additional line 65 | 66 | ```python 67 | parameters = pax.parameters_method("weight", "bias") 68 | ``` 69 | 70 | to declare that `weight` and `bias` are *trainable parameters* of the Linear module. 71 | 72 | ## PAX functional programming 73 | 74 | ### `pax.pure` 75 | 76 | A PAX module can have internal states. For example, below is a simple `Counter` module with an internal counter. 77 | 78 | ```python 79 | class Counter(pax.Module): 80 | count : jnp.ndarray 81 | 82 | def __init__(self): 83 | super().__init__() 84 | self.count = jnp.array(0) 85 | 86 | def __call__(self): 87 | self.count = self.count + 1 88 | return self.count 89 | ``` 90 | 91 | However, PAX *aims* to guarantee that modules will have no side effects from the outside point of view. 92 | Therefore, the modifications of these internal states are restricted. For example, we get an error when trying to call `Counter` directly. 93 | 94 | ```python 95 | counter = Counter() 96 | count = counter() 97 | # ... 98 | # ----> 9 self.count = self.count + 1 99 | # ... 100 | # ValueError: Cannot modify a module in immutable mode. 101 | # Please do this computation inside a function decorated by `pax.pure`. 102 | ``` 103 | 104 | Only functions decorated by `pax.pure` are allowed to modify input module's internal states. 105 | 106 | ```python 107 | @pax.pure 108 | def update_counter(counter: Counter): 109 | count = counter() 110 | return counter, count 111 | 112 | counter, count = update_counter(counter) 113 | print(counter.count, count) 114 | # 1 1 115 | ``` 116 | 117 | Note that we have to return `counter` in the output of `update_counter`, otherwise, the `counter` object will not be updated. 
This is because `pax.pure` only provides `update_counter` a copy of the `counter` object. 118 | 119 | 120 | ### `pax.purecall` 121 | 122 | For convenience, PAX provides the `pax.purecall` function. 123 | It is a shortcut for `pax.pure(lambda f, x: [f, f(x)])`. 124 | 125 | Instead of implementing an `update_counter` function, we can do the same thing with: 126 | 127 | ```python 128 | counter, count = pax.purecall(counter) 129 | print(counter.count, count) 130 | # 2, 2 131 | ``` 132 | 133 | ### Replacing parts 134 | 135 | PAX provides utility methods to modify a module in a functional way. 136 | 137 | The `replace` method creates a new module with attributes replaced. 138 | For example, to replace `weight` and `bias` of a `pax.Linear` module: 139 | 140 | ```python 141 | fc = pax.Linear(2, 2) 142 | fc = fc.replace(weight=jnp.ones((2,2)), bias=jnp.zeros((2,))) 143 | ``` 144 | 145 | The `replace_node` method replaces a pytree node of a module: 146 | 147 | ```python 148 | f = pax.Sequential( 149 | pax.Linear(2, 3), 150 | pax.Linear(3, 4), 151 | ) 152 | 153 | f = f.replace_node(f[-1], pax.Linear(3, 5)) 154 | print(f.summary()) 155 | # Sequential 156 | # ├── Linear(in_dim=2, out_dim=3, with_bias=True) 157 | # └── Linear(in_dim=3, out_dim=5, with_bias=True) 158 | ``` 159 | 160 | ## PAX and other libraries 161 | 162 | PAX learns a lot from other libraries: 163 | - PAX borrows the idea that _a module is also a pytree_ from [treex] and [equinox]. 164 | - PAX uses the concept of _trainable parameters_ and _non-trainable states_ from [dm-haiku]. 165 | - PAX has similar methods to PyTorch such as `model.apply()`, `model.parameters()`, `model.eval()`, etc. 166 | - PAX uses [objax]'s approach to implement optimizers as modules. 167 | - PAX uses [jmp] library for supporting mixed precision. 168 | - And of course, PAX is heavily influenced by [jax] functional programming approach. 169 | 170 | 171 | ## Examples 172 | 173 | A good way to learn about ``PAX`` is to see examples in the [examples/](./examples) directory. 174 | 175 | 176 |
177 | Click to expand 178 | 179 | | Path | Description | 180 | |----------|-----------------------| 181 | | ``char_rnn.py`` | train an RNN language model on TPU. | 182 | | ``transformer/`` | train a Transformer language model on TPU. | 183 | | ``mnist.py`` | train an image classifier on the `MNIST` dataset. | 184 | | ``notebooks/VAE.ipynb`` | train a variational autoencoder. | 185 | | ``notebooks/DCGAN.ipynb`` | train a DCGAN model on the `Celeb-A` dataset. | 186 | | ``notebooks/fine_tuning_resnet18.ipynb`` | fine-tune a pretrained ResNet18 model on the `cats vs dogs` dataset. | 187 | | ``notebooks/mixed_precision.ipynb`` | train a U-Net image segmentation model with mixed precision. | 188 | | ``mnist_mixed_precision.py`` | train an image classifier with mixed precision. | 189 | | ``wave_gru/`` | train a WaveGRU vocoder: convert mel-spectrograms to waveforms. | 190 | | ``denoising_diffusion/`` | train a denoising diffusion model on the `Celeb-A` dataset. | 191 | 192 |
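Most of these examples share the same basic training recipe. The following is a minimal sketch of that recipe, adapted from `examples/basics.py` and reusing the `Linear` module from the Getting started section above; the synthetic data is made up purely for illustration.

```python
import jax
import jax.numpy as jnp
import opax
import pax

def loss_fn(model: Linear, x, y):
    # run the module functionally: purecall returns the updated module and its output
    model, y_hat = pax.purecall(model, x)
    return jnp.mean(jnp.square(y_hat - y)), model

@jax.jit
def train_step(model, optimizer, x, y):
    # gradients w.r.t. the trainable parameters; the updated model is carried as aux
    (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    model, optimizer = opax.apply_gradients(model, optimizer, grads)
    return model, optimizer, loss

pax.seed_rng_key(42)
net = Linear()  # the toy module defined in "Getting started"
opt = opax.adam(1e-1).init(net.parameters())
x = jax.random.normal(pax.next_rng_key(), (32, 1))
y = 2.0 * x + 1.0  # synthetic regression targets
for _ in range(100):
    net, opt, loss = train_step(net, opt, x, y)
```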
193 | 194 | 195 | 196 | 197 | ## Modules 198 | 199 | At the moment, PAX includes: 200 | 201 | * ``pax.Embed``, 202 | * ``pax.Linear``, 203 | * ``pax.{GRU, LSTM}``, 204 | * ``pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm}``, 205 | * ``pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose}``, 206 | * ``pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}``. 207 | 208 | ## Optimizers 209 | 210 | PAX has its optimizers implemented in a separate library [opax](https://github.com/ntt123/opax). The `opax` library supports many common optimizers such as `adam`, `adamw`, `sgd`, `rmsprop`. Visit opax's GitHub repository for more information. 211 | 212 | 213 | ## Fine-tunning models 214 | 215 | PAX's Module provides the ``pax.freeze_parameters`` transformation to convert all trainable parameters to non-trainable states. 216 | 217 | ```python 218 | net = pax.Sequential( 219 | pax.Linear(28*28, 64), 220 | jax.nn.relu, 221 | pax.Linear(64, 10), 222 | ) 223 | 224 | net = pax.freeze_parameters(net) 225 | net = net.set(-1, pax.Linear(64, 2)) 226 | ``` 227 | 228 | After this, ``net.parameters()`` will only return trainable parameters of the last layer. 229 | 230 | 231 | [jax]: https://github.com/google/jax 232 | [objax]: https://github.com/google/objax 233 | [dm-haiku]: https://github.com/deepmind/dm-haiku 234 | [optax]: https://github.com/deepmind/optax 235 | [jmp]: https://github.com/deepmind/jmp 236 | [pytorch]: https://github.com/pytorch/pytorch 237 | [treex]: https://github.com/cgarciae/treex 238 | [equinox]: https://github.com/patrick-kidger/equinox 239 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/api/basics.rst: -------------------------------------------------------------------------------- 1 | PAX Basics 2 | ========== 3 | 4 | .. currentmodule:: pax 5 | 6 | .. autosummary:: 7 | Module 8 | EmptyNode 9 | pure 10 | purecall 11 | seed_rng_key 12 | next_rng_key 13 | 14 | 15 | 16 | PAX's Module 17 | ------------ 18 | 19 | .. currentmodule:: pax 20 | 21 | .. autoclass:: Module 22 | :members: 23 | __init__, 24 | parameters, 25 | training, 26 | train, 27 | eval, 28 | update_parameters, 29 | replace, 30 | replace_node, 31 | summary, 32 | apply, 33 | state_dict, 34 | load_state_dict, 35 | __or__, 36 | __mod__ 37 | 38 | 39 | 40 | .. autoclass:: ParameterModule 41 | :members: 42 | 43 | 44 | .. autoclass:: StateModule 45 | :members: 46 | 47 | 48 | .. autoclass:: EmptyNode 49 | :members: 50 | 51 | 52 | Purify functions and methods 53 | ---------------------------- 54 | 55 | .. currentmodule:: pax 56 | 57 | .. autofunction:: pure 58 | 59 | .. 
autofunction:: purecall 60 | 61 | 62 | Random Number Generator 63 | ----------------------- 64 | 65 | .. autosummary:: 66 | 67 | seed_rng_key 68 | next_rng_key 69 | 70 | 71 | seed_rng_key 72 | ~~~~~~~~~~~~ 73 | 74 | .. autofunction:: seed_rng_key 75 | 76 | 77 | next_rng_key 78 | ~~~~~~~~~~~~ 79 | 80 | .. autofunction:: next_rng_key 81 | -------------------------------------------------------------------------------- /docs/api/experimental.rst: -------------------------------------------------------------------------------- 1 | Experimental 2 | ============ 3 | 4 | .. currentmodule:: pax.experimental 5 | 6 | 7 | .. autosummary:: 8 | mutable 9 | Flattener 10 | LazyModule 11 | graph.build_graph_module 12 | default_mp_policy 13 | apply_scaled_gradients 14 | save_weights_to_dict 15 | load_weights_from_dict 16 | 17 | 18 | Mutable 19 | ------- 20 | 21 | .. autofunction:: mutable 22 | 23 | 24 | Flattener 25 | --------- 26 | 27 | .. autoclass:: Flattener 28 | :members: 29 | 30 | 31 | Graph API 32 | --------- 33 | 34 | .. currentmodule:: pax.experimental.graph 35 | 36 | .. autoclass:: Node 37 | :members: 38 | 39 | .. autoclass:: InputNode 40 | :members: 41 | 42 | .. autoclass:: GraphModule 43 | :members: 44 | 45 | .. autofunction:: build_graph_module 46 | 47 | 48 | Lazy Module 49 | ----------- 50 | 51 | .. currentmodule:: pax.experimental 52 | 53 | .. autoclass:: LazyModule 54 | :members: 55 | 56 | 57 | Mixed Precision 58 | --------------- 59 | 60 | .. currentmodule:: pax.experimental 61 | 62 | .. autofunction:: default_mp_policy 63 | .. autofunction:: apply_scaled_gradients 64 | 65 | 66 | Save and load weights 67 | --------------------- 68 | 69 | .. currentmodule:: pax.experimental 70 | 71 | .. autofunction:: save_weights_to_dict 72 | .. autofunction:: load_weights_from_dict 73 | -------------------------------------------------------------------------------- /docs/api/modules.rst: -------------------------------------------------------------------------------- 1 | Common Modules 2 | ============== 3 | 4 | .. currentmodule:: pax 5 | 6 | .. autosummary:: 7 | Linear 8 | Conv1D 9 | Conv2D 10 | Conv1DTranspose 11 | Conv2DTranspose 12 | BatchNorm1D 13 | BatchNorm2D 14 | LayerNorm 15 | GroupNorm 16 | Sequential 17 | VanillaRNN 18 | LSTM 19 | GRU 20 | MultiHeadAttention 21 | Identity 22 | avg_pool 23 | max_pool 24 | 25 | 26 | 27 | 28 | Linear 29 | ------ 30 | 31 | 32 | .. autoclass:: Linear 33 | :members: 34 | 35 | 36 | Dropout 37 | ------- 38 | 39 | .. autoclass:: Dropout 40 | :members: 41 | 42 | 43 | Embed 44 | ----- 45 | 46 | .. autoclass:: Embed 47 | :members: 48 | 49 | 50 | Convolution 51 | ----------- 52 | 53 | Conv1D 54 | ~~~~~~ 55 | 56 | .. autoclass:: Conv1D 57 | :members: 58 | 59 | Conv2D 60 | ~~~~~~ 61 | 62 | .. autoclass:: Conv2D 63 | :members: 64 | 65 | Conv1DTranspose 66 | ~~~~~~~~~~~~~~~ 67 | 68 | .. autoclass:: Conv1DTranspose 69 | :members: 70 | 71 | Conv2DTranspose 72 | ~~~~~~~~~~~~~~~ 73 | 74 | .. autoclass:: Conv2DTranspose 75 | :members: 76 | 77 | 78 | Normalization 79 | ------------- 80 | 81 | 82 | BatchNorm1D 83 | ~~~~~~~~~~~ 84 | 85 | .. autoclass:: BatchNorm1D 86 | :members: 87 | 88 | BatchNorm2D 89 | ~~~~~~~~~~~ 90 | 91 | .. autoclass:: BatchNorm2D 92 | :members: 93 | 94 | 95 | 96 | LayerNorm 97 | ~~~~~~~~~ 98 | 99 | 100 | .. autoclass:: LayerNorm 101 | :members: 102 | 103 | 104 | GroupNorm 105 | ~~~~~~~~~ 106 | 107 | 108 | .. autoclass:: GroupNorm 109 | :members: 110 | 111 | 112 | 113 | Recurrent 114 | --------- 115 | 116 | 117 | VanillaRNN 118 | ~~~~~~~~~~ 119 | 120 | .. 
autoclass:: VanillaRNN 121 | :members: 122 | 123 | 124 | LSTM 125 | ~~~~ 126 | 127 | .. autoclass:: LSTM 128 | :members: 129 | 130 | 131 | GRU 132 | ~~~ 133 | 134 | .. autoclass:: GRU 135 | :members: 136 | 137 | 138 | Pool 139 | ---- 140 | 141 | avg_pool 142 | ~~~~~~~~ 143 | 144 | .. autofunction:: avg_pool 145 | 146 | 147 | max_pool 148 | ~~~~~~~~ 149 | 150 | .. autofunction:: max_pool 151 | 152 | 153 | 154 | 155 | MultiHeadAttention 156 | ------------------ 157 | 158 | .. autoclass:: MultiHeadAttention 159 | :members: 160 | 161 | 162 | Utilities 163 | --------- 164 | 165 | Sequential 166 | ~~~~~~~~~~ 167 | 168 | .. autoclass:: Sequential 169 | :members: 170 | 171 | 172 | RngSeq 173 | ~~~~~~ 174 | 175 | .. autoclass:: RngSeq 176 | :members: 177 | 178 | 179 | Lambda 180 | ~~~~~~ 181 | 182 | .. autoclass:: Lambda 183 | 184 | 185 | Identity 186 | ~~~~~~~~ 187 | 188 | .. autoclass:: Identity 189 | :members: 190 | 191 | EMA 192 | ~~~ 193 | 194 | .. autoclass:: EMA 195 | :members: 196 | -------------------------------------------------------------------------------- /docs/api/transformations.rst: -------------------------------------------------------------------------------- 1 | Module Transformations 2 | ====================== 3 | 4 | .. currentmodule:: pax 5 | 6 | A module transformation is a pure function that inputs PAX's modules and outputs PAX's modules. 7 | 8 | .. autosummary:: 9 | 10 | update_parameters 11 | enable_train_mode 12 | enable_eval_mode 13 | select_parameters 14 | freeze_parameters 15 | unfreeze_parameters 16 | apply_mp_policy 17 | unwrap_mp_policy 18 | 19 | update_parameters 20 | ----------------- 21 | 22 | .. autofunction:: update_parameters 23 | 24 | 25 | enable_train_mode 26 | ----------------- 27 | 28 | .. autofunction:: enable_train_mode 29 | 30 | 31 | enable_eval_mode 32 | ---------------- 33 | 34 | .. autofunction:: enable_eval_mode 35 | 36 | 37 | select_parameters 38 | ----------------- 39 | 40 | .. autofunction:: select_parameters 41 | 42 | 43 | freeze_parameters 44 | ----------------- 45 | 46 | .. autofunction:: freeze_parameters 47 | 48 | 49 | unfreeze_parameters 50 | ------------------- 51 | 52 | .. autofunction:: unfreeze_parameters 53 | 54 | 55 | apply_mp_policy 56 | --------------- 57 | 58 | .. autofunction:: apply_mp_policy 59 | 60 | 61 | unwrap_mp_policy 62 | ---------------- 63 | 64 | .. autofunction:: unwrap_mp_policy 65 | -------------------------------------------------------------------------------- /docs/api/utilities.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. currentmodule:: pax 5 | 6 | 7 | .. autosummary:: 8 | parameters_method 9 | grad 10 | value_and_grad 11 | scan 12 | build_update_fn 13 | 14 | 15 | parameters_method 16 | ----------------- 17 | 18 | .. autofunction:: parameters_method 19 | 20 | 21 | grad 22 | ---- 23 | 24 | .. autofunction:: grad 25 | 26 | 27 | value_and_grad 28 | -------------- 29 | 30 | .. autofunction:: value_and_grad 31 | 32 | 33 | scan 34 | ---- 35 | 36 | .. autofunction:: scan 37 | 38 | 39 | build_update_fn 40 | --------------- 41 | 42 | .. autofunction:: build_update_fn 43 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is an adaptation from 2 | # https://raw.githubusercontent.com/deepmind/dm-haiku/main/docs/conf.py 3 | # which is under Apache License, Version 2.0. 
4 | 5 | 6 | # Configuration file for the Sphinx documentation builder. 7 | # 8 | # This file only contains a selection of the most common options. For a full 9 | # list see the documentation: 10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 11 | 12 | # -- Path setup -------------------------------------------------------------- 13 | 14 | # If extensions (or modules to document with autodoc) are in another directory, 15 | # add these directories to sys.path here. If the directory is relative to the 16 | # documentation root, use os.path.abspath to make it absolute, like shown here. 17 | # 18 | import doctest 19 | import inspect 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath("..")) 24 | 25 | import pax 26 | import sphinxcontrib.katex as katex 27 | 28 | # -- Project information ----------------------------------------------------- 29 | 30 | project = "PAX" 31 | copyright = "2021, Thông Nguyễn" 32 | author = "Thông Nguyễn" 33 | 34 | 35 | # -- General configuration --------------------------------------------------- 36 | master_doc = "index" 37 | 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.autosummary", 45 | "sphinx.ext.doctest", 46 | "sphinx.ext.inheritance_diagram", 47 | "sphinx.ext.intersphinx", 48 | "sphinx.ext.linkcode", 49 | "sphinx.ext.napoleon", 50 | "sphinxcontrib.bibtex", 51 | "sphinxcontrib.katex", 52 | "sphinx_autodoc_typehints", 53 | "nbsphinx", 54 | "IPython.sphinxext.ipython_console_highlighting", 55 | ] 56 | 57 | 58 | # Add any paths that contain templates here, relative to this directory. 59 | templates_path = ["_templates"] 60 | 61 | # List of patterns, relative to source directory, that match files and 62 | # directories to ignore when looking for source files. 63 | # This pattern also affects html_static_path and html_extra_path. 64 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 65 | 66 | 67 | # -- Options for autodoc ----------------------------------------------------- 68 | 69 | autodoc_default_options = { 70 | "member-order": "bysource", 71 | "special-members": True, 72 | "exclude-members": "__repr__, __str__, __weakref__", 73 | } 74 | 75 | 76 | # -- Options for HTML output ------------------------------------------------- 77 | 78 | 79 | # The theme to use for HTML and HTML Help pages. See the documentation for 80 | # a list of builtin themes. 81 | # 82 | html_theme = "sphinx_rtd_theme" 83 | 84 | # Add any paths that contain custom static files (such as style sheets) here, 85 | # relative to this directory. They are copied after the builtin static files, 86 | # so a file named "default.css" will overwrite the builtin "default.css". 
87 | # html_static_path = ["_static"] 88 | 89 | 90 | # -- Options for doctest ----------------------------------------------------- 91 | 92 | doctest_test_doctest_blocks = "true" 93 | doctest_global_setup = """ 94 | import jax 95 | import jax.numpy as jnp 96 | import pax 97 | import opax 98 | pax.seed_rng_key(42) 99 | """ 100 | doctest_default_flags = ( 101 | doctest.ELLIPSIS 102 | | doctest.IGNORE_EXCEPTION_DETAIL 103 | | doctest.DONT_ACCEPT_TRUE_FOR_1 104 | | doctest.NORMALIZE_WHITESPACE 105 | ) 106 | 107 | 108 | # -- Options for katex ------------------------------------------------------ 109 | 110 | # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html 111 | latex_macros = r""" 112 | \def \d #1{\operatorname{#1}} 113 | """ 114 | 115 | # Translate LaTeX macros to KaTeX and add to options for HTML builder 116 | katex_macros = katex.latex_defs_to_katex_macros(latex_macros) 117 | katex_options = "macros: {" + katex_macros + "}" 118 | 119 | # Add LaTeX macros for LATEX builder 120 | latex_elements = {"preamble": latex_macros} 121 | 122 | 123 | # -- Source code links ------------------------------------------------------- 124 | 125 | 126 | def linkcode_resolve(domain, info): 127 | """Resolve a GitHub URL corresponding to Python object.""" 128 | if domain != "py": 129 | return None 130 | 131 | try: 132 | mod = sys.modules[info["module"]] 133 | except ImportError: 134 | return None 135 | 136 | obj = mod 137 | try: 138 | for attr in info["fullname"].split("."): 139 | obj = getattr(obj, attr) 140 | except AttributeError: 141 | return None 142 | else: 143 | obj = inspect.unwrap(obj) 144 | 145 | try: 146 | filename = inspect.getsourcefile(obj) 147 | except TypeError: 148 | return None 149 | 150 | try: 151 | source, lineno = inspect.getsourcelines(obj) 152 | except OSError: 153 | return None 154 | 155 | return "https://github.com/ntt123/pax/blob/main/pax/%s#L%d#L%d" % ( 156 | os.path.relpath(filename, start=os.path.dirname(pax.__file__)), 157 | lineno, 158 | lineno + len(source) - 1, 159 | ) 160 | 161 | 162 | # -- nbsphinx configuration -------------------------------------------------- 163 | 164 | nbsphinx_execute = "never" 165 | nbsphinx_codecell_lexer = "ipython" 166 | nbsphinx_kernel_name = "python" 167 | nbsphinx_timeout = 180 168 | nbsphinx_prolog = r""" 169 | {% set docname = 'docs/' + env.doc2path(env.docname, base=None) %} 170 | 171 | .. only:: html 172 | 173 | .. role:: raw-html(raw) 174 | :format: html 175 | 176 | .. nbinfo:: 177 | 178 | Interactive online version: 179 | :raw-html:`Open In Colab` 180 | """ 181 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/ntt123/pax/tree/main/docs 2 | 3 | 4 | .. PAX documentation master file, created by 5 | sphinx-quickstart on Fri Sep 3 01:09:13 2021. 6 | You can adapt this file completely to your liking, but it should at least 7 | contain the root `toctree` directive. 8 | 9 | PAX documentation 10 | ================= 11 | 12 | PAX is a stateful pytree library for training neural networks using JAX. It is designed to be simple 13 | and easy to use while preserving benefits of JAX. 14 | 15 | 16 | Installation 17 | ------------ 18 | 19 | To install the latest version:: 20 | 21 | pip install git+https://github.com/ntt123/pax.git 22 | 23 | 24 | .. 
toctree:: 25 | :caption: Guides 26 | :maxdepth: 1 27 | 28 | notebooks/basics 29 | notebooks/training 30 | notebooks/operators 31 | notebooks/understanding 32 | notebooks/jax_transformations 33 | notebooks/performance 34 | 35 | 36 | .. toctree:: 37 | :caption: API Documentation 38 | :maxdepth: 1 39 | 40 | api/basics 41 | api/modules 42 | api/transformations 43 | api/utilities 44 | api/experimental 45 | 46 | 47 | 48 | 49 | PAX is licensed under the MIT License. 50 | 51 | Indices 52 | ======= 53 | 54 | * :ref:`genindex` 55 | -------------------------------------------------------------------------------- /docs/notebooks/operators.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Operators\n", 8 | "\n", 9 | "There are a few operators that help to clean up the implementation.\n", 10 | "\n", 11 | "\n", 12 | "\n", 13 | "| Text | Operator |\n", 14 | "| ----------- | ----------- |\n", 15 | "| `mod, z = pax.purecall(mod, x, y)` | `mod, z = mod % (x, y)` |\n", 16 | "| `mod.parameters()` | `~mod` |\n", 17 | "| `pax.update_pytree(mod1, mod2)` | `mod1 | mod2` |\n", 18 | "| `mod1.update_parameters(mod2)` | `mod1 | ~mod2` |\n", 19 | "| `f = pax.Sequential(mod1, mod2)` | `f = pax.Sequential() >> mod1 >> mod2` |\n", 20 | "\n", 21 | "\n", 22 | "\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [] 29 | } 30 | ], 31 | "metadata": { 32 | "interpreter": { 33 | "hash": "4f946df053fbf2b937619d3c5458e7af74262f9a954d8797ba0b27400bcafe06" 34 | }, 35 | "kernelspec": { 36 | "display_name": "Python 3.8.6 64-bit", 37 | "name": "python3" 38 | }, 39 | "language_info": { 40 | "codemirror_mode": { 41 | "name": "ipython", 42 | "version": 3 43 | }, 44 | "file_extension": ".py", 45 | "mimetype": "text/x-python", 46 | "name": "python", 47 | "nbconvert_exporter": "python", 48 | "pygments_lexer": "ipython3", 49 | "version": "3.8.6" 50 | }, 51 | "orig_nbformat": 4 52 | }, 53 | "nbformat": 4, 54 | "nbformat_minor": 2 55 | } 56 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16 2 | ipykernel==5.3.4 3 | ipython==7.16.3 4 | Jinja2==2.11.3 5 | jq==1.1.1 6 | markupsafe==2.0.1 7 | matplotlib==3.3.3 8 | nbsphinx==0.8.0 9 | pandoc==1.0.2 10 | pygments==2.7.4 11 | seaborn==0.11.1 12 | sphinx_rtd_theme==0.5.0 13 | sphinx-autodoc-typehints==1.11.1 14 | sphinx==3.3.0 15 | sphinxcontrib-bibtex==1.0.0 16 | sphinxcontrib-katex==0.7.1 17 | 18 | 19 | # pax requirements 20 | jax 21 | jaxlib 22 | jmp 23 | numpy -------------------------------------------------------------------------------- /examples/basics.py: -------------------------------------------------------------------------------- 1 | """PAX basic stuffs.""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import opax 6 | import pax 7 | from opax import GradientTransformation 8 | 9 | 10 | class Linear(pax.Module): 11 | """A linear module with counter.""" 12 | 13 | weight: jnp.ndarray 14 | bias: jnp.ndarray 15 | counter: jnp.ndarray 16 | parameters = pax.parameters_method("weight", "bias") 17 | 18 | def __init__(self): 19 | super().__init__() 20 | self.weight = jax.random.normal(pax.next_rng_key(), (1,)) 21 | self.bias = jax.random.normal(pax.next_rng_key(), (1,)) 22 | self.counter = jnp.array(0) 23 | 24 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 25 | 
self.counter = self.counter + 1 26 | x = self.weight * x + self.bias 27 | return x 28 | 29 | 30 | def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray): 31 | model, y_hat = pax.purecall(model, x) 32 | loss = jnp.mean(jnp.square(y_hat - y)) 33 | return loss, model 34 | 35 | 36 | @jax.jit 37 | def train_step(model: Linear, optimizer: GradientTransformation, x, y): 38 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y) 39 | model, optimizer = opax.apply_gradients(model, optimizer, grads) 40 | return model, optimizer, loss 41 | 42 | 43 | def main(): 44 | # random seed 45 | pax.seed_rng_key(42) 46 | 47 | # model & optimizer 48 | net = Linear() 49 | print(net.summary()) 50 | opt = opax.adam(1e-1).init(net.parameters()) 51 | 52 | # data 53 | x = jax.random.normal(pax.next_rng_key(), (32, 1)) 54 | y = jax.random.normal(pax.next_rng_key(), (32, 1)) 55 | 56 | # training loop 57 | for _ in range(10): 58 | net, opt, loss = train_step(net, opt, x, y) 59 | print(f"step {net.counter:>2} loss {loss:.3f}") 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /examples/char_rnn.py: -------------------------------------------------------------------------------- 1 | """Train a rnn language model on TPU (if available).""" 2 | 3 | import inspect 4 | import os 5 | from functools import partial 6 | from typing import List, Tuple 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import jax.tools.colab_tpu 11 | import opax 12 | import pax 13 | import tensorflow as tf 14 | from tqdm.auto import tqdm 15 | 16 | pax.seed_rng_key(42) 17 | 18 | 19 | def setup_tpu_device(): 20 | print("Setting up TPU cores") 21 | jax.tools.colab_tpu.setup_tpu() 22 | print(jax.devices()) 23 | 24 | 25 | if "COLAB_TPU_ADDR" in os.environ: 26 | # TPU config 27 | setup_tpu_device() 28 | steps_per_update = 50 29 | num_devices = jax.device_count() 30 | batch_size = 32 * num_devices * steps_per_update 31 | seq_len = 128 32 | vocab_size = 256 33 | hidden_dim = 512 34 | num_steps = 50_000 35 | else: 36 | # CPU/GPU config 37 | steps_per_update = 1 38 | num_devices = jax.device_count() 39 | batch_size = 1 * num_devices * steps_per_update 40 | seq_len = 64 41 | vocab_size = 256 42 | hidden_dim = 256 43 | num_steps = 20_000 44 | 45 | 46 | class LM(pax.Module): 47 | """A RNN language model.""" 48 | 49 | lstm: pax.Module 50 | embed: pax.Module 51 | output: pax.Module 52 | 53 | vocab_size: int 54 | hidden_dim: int 55 | 56 | def __init__(self, vocab_size: int, hidden_dim: int): 57 | """ 58 | Arguments: 59 | vocab_size: int, size of the alphabet. 60 | hidden_dim: int, number of LSTM cells. 
61 | """ 62 | super().__init__() 63 | self.vocab_size = vocab_size 64 | self.hidden_dim = hidden_dim 65 | self.embed = pax.Embed(vocab_size, hidden_dim) 66 | self.lstm = pax.LSTM(hidden_dim, hidden_dim) 67 | self.output = pax.Linear(hidden_dim, vocab_size) 68 | 69 | def __call__(self, x): 70 | x = self.embed(x) 71 | hx, x = pax.scan( 72 | self.lstm, 73 | self.lstm.initial_state(x.shape[0]), 74 | x, 75 | time_major=False, 76 | ) 77 | del hx 78 | logits = self.output(x) 79 | return logits 80 | 81 | def inference(self, prompt: List[int] = [], length=32): 82 | hx = self.lstm.initial_state(1) 83 | if len(prompt) == 0: 84 | prompt = [0] 85 | 86 | x = jnp.array([prompt[0]], dtype=jnp.int32) 87 | 88 | total_len = len(prompt) + length 89 | 90 | out = [x] 91 | 92 | @jax.jit 93 | def step(x, hx): 94 | x = self.embed(x) 95 | hx, x = self.lstm(hx, x) 96 | logits = self.output(x) 97 | return logits, hx 98 | 99 | for i in range(1, total_len): 100 | logits, hx = step(x, hx) 101 | if i >= len(prompt): 102 | x = jnp.argmax(logits, axis=-1) 103 | else: 104 | x = jnp.array([prompt[i]], dtype=jnp.int32) 105 | out.append(x) 106 | return jnp.concatenate(out) 107 | 108 | 109 | def loss_fn(model: LM, batch: jnp.ndarray): 110 | inputs = batch[:, :-1] 111 | targets = batch[:, 1:] 112 | 113 | logits = model(inputs) 114 | log_pr = jax.nn.log_softmax(logits, axis=-1) 115 | targets = jax.nn.one_hot(targets, num_classes=model.vocab_size) 116 | loss = -jnp.mean(jnp.sum(targets * log_pr, axis=-1)) 117 | return loss 118 | 119 | 120 | def update_step(model_and_optimizer: Tuple[LM, pax.Module], batch: jnp.ndarray): 121 | model, optimizer = model_and_optimizer 122 | loss, grads = jax.value_and_grad(loss_fn)(model, batch) 123 | grads = jax.lax.pmean(grads, axis_name="i") 124 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads) 125 | return (model, optimizer), loss 126 | 127 | 128 | @partial(jax.pmap, axis_name="i") 129 | def update_fn(model, optimizer, multi_batch: jnp.ndarray): 130 | (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch) 131 | return model, optimizer, jnp.mean(losses) 132 | 133 | 134 | net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim) 135 | 136 | optimizer = opax.chain( 137 | opax.clip_by_global_norm(1.0), 138 | opax.adam(1e-4), 139 | ).init(net.parameters()) 140 | 141 | # replicate on multiple devices 142 | net = jax.device_put_replicated(net, jax.devices()) 143 | print(net.summary()) 144 | optimizer = jax.device_put_replicated(optimizer, jax.devices()) 145 | 146 | 147 | def tokenize(text): 148 | t = [0] + [ord(c) for c in text] # ASCII, 0 is the [START] token 149 | return t 150 | 151 | 152 | def detokenize(tokens): 153 | text = [chr(t) if t != 0 else "[START]" for t in tokens] 154 | return "".join(text) 155 | 156 | 157 | data = inspect.getsource(LM) # a _true_ AGI learns about itself. 
158 | data_token = tokenize(data) 159 | test_prompt = "class LM(pax.Module):" 160 | 161 | tfdata = ( 162 | tf.data.Dataset.from_tensors(data_token) 163 | .repeat() 164 | .map( 165 | lambda x: tf.image.random_crop(x, [seq_len + 1]), 166 | num_parallel_calls=tf.data.AUTOTUNE, 167 | ) 168 | .batch(batch_size) 169 | .prefetch(tf.data.AUTOTUNE) 170 | .as_numpy_iterator() 171 | ) 172 | 173 | loss_accum = 0.0, 0 174 | tr = tqdm(range(0, 1 + num_steps, steps_per_update), desc="training") 175 | for step in tr: 176 | batch = next(tfdata) 177 | # (num_devices,) is for jax.pmap, (steps_per_update,) is for pax.scan 178 | batch = jnp.reshape(batch, (num_devices, steps_per_update, -1) + batch.shape[1:]) 179 | net, optimizer, losses = update_fn(net, optimizer, batch) 180 | loss_accum = (loss_accum[0] + jnp.mean(losses), loss_accum[1] + 1) 181 | if step % 1000 == 0: 182 | loss = loss_accum[0] / loss_accum[1] 183 | loss_accum = 0.0, 0 184 | # eval on a single device 185 | eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval()) 186 | out = eval_net.inference( 187 | prompt=tokenize(test_prompt), 188 | length=(100 if step < num_steps else 1000), 189 | ) 190 | text = detokenize(out.tolist()) 191 | tr.write( 192 | f"[step {step}] loss {loss:.3f}\n" 193 | f"Prompt: {test_prompt}\n" 194 | f"========\n" 195 | f"{text}\n" 196 | f"========" 197 | ) 198 | 199 | del tfdata # needed to avoid exception 200 | -------------------------------------------------------------------------------- /examples/dataclass_module.py: -------------------------------------------------------------------------------- 1 | """How to implement a PAX module using python dataclass.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Callable, Optional 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import pax 9 | 10 | 11 | @dataclass 12 | class Linear(pax.Module): 13 | """A linear module""" 14 | 15 | in_dim: int 16 | out_dim: int 17 | with_bias: bool = True 18 | name: Optional[str] = None 19 | weight: jnp.ndarray = field(init=False, repr=False) 20 | bias: Optional[jnp.ndarray] = field(init=False, repr=False) 21 | counter: jnp.ndarray = field(init=False) 22 | w_init: Callable = field(default=jax.nn.initializers.normal(), repr=False) 23 | b_init: Callable = field(default=jax.nn.initializers.zeros, repr=False) 24 | parameters = pax.parameters_method("weight", "bias") 25 | 26 | def __post_init__(self): 27 | self.weight = self.w_init(pax.next_rng_key(), (self.in_dim, self.out_dim)) 28 | self.bias = None 29 | if self.with_bias: 30 | self.bias = self.b_init(pax.next_rng_key(), (self.out_dim,)) 31 | self.counter = jnp.array(0) 32 | 33 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 34 | self.counter += 1 35 | x = jnp.dot(x, self.weight) 36 | if self.with_bias: 37 | x = x + self.bias 38 | return x 39 | 40 | 41 | pax.seed_rng_key(42) 42 | 43 | fc = Linear(3, 4, name="fc1") 44 | 45 | print("Before:", fc) 46 | dummy_x = jnp.empty((32, 3)) 47 | fc, y = pax.purecall(fc, dummy_x) 48 | assert y.shape == (32, 4) 49 | print("After :", fc) 50 | -------------------------------------------------------------------------------- /examples/denoising_diffusion/README.md: -------------------------------------------------------------------------------- 1 | ## Denoising Diffusion Model 2 | 3 | We transcribe the PyTorch model at https://github.com/lucidrains/denoising-diffusion-pytorch. 4 | 5 | The implementation is almost identical to the PyTorch version. 6 | The difference is at how PAX manages random keys. 
PAX's version uses a `RngSeq` submodule to generates new random keys when needed. 7 | 8 | To train model: 9 | 10 | ```sh 11 | pip install -r requirements.txt 12 | python3 train.py 13 | ``` -------------------------------------------------------------------------------- /examples/denoising_diffusion/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | ### load celeb_a dataset 7 | 8 | # This is a hack to use a custom link to celeb-a dataset in tensorflow-datasets. 9 | # replace the ``tfds.image.CelebA._split_generators`` method by the following method 10 | # which uses our custom links. 11 | 12 | IMG_ALIGNED_DATA = ( 13 | "https://drive.google.com/uc?export=download&" 14 | "id=1iQRFaGXRiPBd-flIm0u-u8Jy6CfJ_q6j" 15 | ) 16 | 17 | EVAL_LIST = ( 18 | "https://drive.google.com/uc?export=download&" 19 | "id=1ab9MDLOblszbKKXoDe8jumFsSkn6lIX1" 20 | ) 21 | # Landmark coordinates: left_eye, right_eye etc. 22 | LANDMARKS_DATA = ( 23 | "https://drive.google.com/uc?export=download&" 24 | "id=1y8qfK-jaq1QWl9v_n_mBNIMu5-h3UXK4" 25 | ) 26 | 27 | # Attributes in the image (Eyeglasses, Mustache etc). 28 | ATTR_DATA = ( 29 | "https://drive.google.com/uc?export=download&" 30 | "id=1BPfcVuIqrAsJAgG40-XGWU7g2wmmQU30" 31 | ) 32 | 33 | 34 | def _split_generators(self, dl_manager): 35 | downloaded_dirs = dl_manager.download( 36 | { 37 | "img_align_celeba": IMG_ALIGNED_DATA, 38 | "list_eval_partition": EVAL_LIST, 39 | "list_attr_celeba": ATTR_DATA, 40 | "landmarks_celeba": LANDMARKS_DATA, 41 | } 42 | ) 43 | 44 | # Load all images in memory (~1 GiB) 45 | # Use split to convert: `img_align_celeba/000005.jpg` -> `000005.jpg` 46 | all_images = { 47 | os.path.split(k)[-1]: img 48 | for k, img in dl_manager.iter_archive(downloaded_dirs["img_align_celeba"]) 49 | } 50 | 51 | return [ 52 | tfds.core.SplitGenerator( 53 | name=tfds.Split.TRAIN, 54 | gen_kwargs={ 55 | "file_id": 0, 56 | "downloaded_dirs": downloaded_dirs, 57 | "downloaded_images": all_images, 58 | }, 59 | ), 60 | tfds.core.SplitGenerator( 61 | name=tfds.Split.VALIDATION, 62 | gen_kwargs={ 63 | "file_id": 1, 64 | "downloaded_dirs": downloaded_dirs, 65 | "downloaded_images": all_images, 66 | }, 67 | ), 68 | tfds.core.SplitGenerator( 69 | name=tfds.Split.TEST, 70 | gen_kwargs={ 71 | "file_id": 2, 72 | "downloaded_dirs": downloaded_dirs, 73 | "downloaded_images": all_images, 74 | }, 75 | ), 76 | ] 77 | 78 | 79 | img_mean = 0.5 80 | img_scale = 0.5 81 | image_size = 64 # size of input image: 64x64 82 | 83 | tfds.image.CelebA._split_generators = _split_generators 84 | 85 | 86 | def load_celeb_a(): 87 | ds = tfds.load("celeb_a") 88 | 89 | def img_ops(x): 90 | img = tf.cast(x["image"], tf.float32) / 255.0 91 | img = tf.image.resize( 92 | img, (image_size * 2, image_size), preserve_aspect_ratio=True 93 | ) 94 | img = tf.image.crop_to_bounding_box(img, 7, 0, 64, 64) 95 | img = (img - img_mean) / img_scale 96 | return img 97 | 98 | dataset = ( 99 | ds["train"].concatenate(ds["validation"]).concatenate(ds["test"]).map(img_ops) 100 | ) 101 | return dataset 102 | -------------------------------------------------------------------------------- /examples/denoising_diffusion/requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | fire 3 | jax 4 | opax 5 | pax3 6 | pillow 7 | tensorflow 8 | tqdm -------------------------------------------------------------------------------- 
/examples/denoising_diffusion/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import fire 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import opax 8 | import pax 9 | import tensorflow as tf 10 | from PIL import Image 11 | from tqdm.auto import tqdm 12 | 13 | from data_loader import load_celeb_a 14 | from model import GaussianDiffusion, UNet 15 | 16 | 17 | def make_image_grid(images, padding=2): 18 | """Place images in a square grid.""" 19 | n = images.shape[0] 20 | size = int(math.sqrt(n)) 21 | assert size * size == n, "expecting a square grid" 22 | img = images[0] 23 | 24 | H = img.shape[0] * size + padding * (size + 1) 25 | W = img.shape[1] * size + padding * (size + 1) 26 | out = np.zeros((H, W, img.shape[-1]), dtype=img.dtype) 27 | for i in range(n): 28 | x = i % size 29 | y = i // size 30 | xstart = x * (img.shape[0] + padding) + padding 31 | xend = xstart + img.shape[0] 32 | ystart = y * (img.shape[1] + padding) + padding 33 | yend = ystart + img.shape[1] 34 | out[xstart:xend, ystart:yend, :] = images[i] 35 | return out 36 | 37 | 38 | def train( 39 | batch_size: int = 32, 40 | learning_rate: float = 1e-4, 41 | num_training_steps: int = 10_000, 42 | log_freq: int = 1000, 43 | image_size: int = 64, 44 | random_seed: int = 42, 45 | ): 46 | 47 | pax.seed_rng_key(random_seed) 48 | 49 | model = UNet(dim=64, dim_mults=(1, 2, 4, 8)) 50 | 51 | diffusion = GaussianDiffusion( 52 | model, 53 | image_size=image_size, 54 | timesteps=1000, 55 | loss_type="l1", # L1 or L2 56 | ) 57 | 58 | dataset = load_celeb_a() 59 | 60 | dataloader = ( 61 | dataset.repeat() 62 | .shuffle(batch_size * 100) 63 | .batch(batch_size) 64 | .take(num_training_steps) 65 | .prefetch(tf.data.AUTOTUNE) 66 | ) 67 | 68 | def loss_fn(model, inputs): 69 | model, loss = pax.purecall(model, inputs) 70 | return loss, (loss, model) 71 | 72 | update_fn = pax.utils.build_update_fn(loss_fn) 73 | fast_update_fn = jax.jit(update_fn) 74 | 75 | optimizer = opax.adam(learning_rate)(diffusion.parameters()) 76 | 77 | total_loss = 0.0 78 | tr = tqdm(dataloader) 79 | for step, batch in enumerate(tr, 1): 80 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch) 81 | diffusion, optimizer, loss = fast_update_fn(diffusion, optimizer, batch) 82 | total_loss = total_loss + loss 83 | 84 | if step % log_freq == 0: 85 | loss = total_loss / log_freq 86 | total_loss = 0.0 87 | tr.write(f"[step {step:05d}] train loss {loss:.3f}") 88 | 89 | imgs = jax.device_get(diffusion.eval().sample(16)) 90 | imgs = ((imgs * 0.5 + 0.5) * 255).astype(jnp.uint8) 91 | imgs = make_image_grid(imgs) 92 | im = Image.fromarray(imgs) 93 | im.save(f"sample_{step:05d}.png") 94 | 95 | 96 | if __name__ == "__main__": 97 | fire.Fire(train) 98 | -------------------------------------------------------------------------------- /examples/graph_module.py: -------------------------------------------------------------------------------- 1 | """A model as a directed graph.""" 2 | 3 | import jax 4 | import pax 5 | import jax.numpy as jnp 6 | from pax.experimental.graph import Node, build_graph_module 7 | 8 | pax.seed_rng_key(42) 9 | 10 | 11 | def residual_net(x: Node): 12 | _, D = x.shape 13 | y = x >> pax.Linear(D, D) >> jax.nn.relu >> pax.Linear(D, D) >> pax.Dropout(0.2) 14 | z = (x | y) >> jax.lax.add 15 | return z 16 | 17 | 18 | inputs = jnp.ones((3, 8)) 19 | net = build_graph_module(residual_net)(inputs) 20 | print(net.summary()) 21 | net, _ = pax.purecall(net, inputs) 22 | 
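# A hedged follow-up sketch (not part of the original example): the built
# GraphModule is an ordinary pax.Module, so the usual PAX utilities apply.
# For instance, gradients of a made-up scalar loss w.r.t. the module:
def toy_loss(model, x):
    model, y = pax.purecall(model, x)
    return jnp.mean(jnp.square(y)), model

grads, net = pax.grad(toy_loss, has_aux=True)(net, inputs)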
-------------------------------------------------------------------------------- /examples/lazy_module.py: -------------------------------------------------------------------------------- 1 | """A forward function that builds the model on the fly.""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import opax 6 | import pax 7 | 8 | 9 | @pax.pure 10 | def forward(net: pax.experimental.LazyModule, x): 11 | fc1 = net.get_or_create("fc1", lambda: pax.Linear(1, 1)) 12 | x = jax.nn.relu(fc1(x)) 13 | fc2 = net.get_or_create("fc2", lambda: pax.Linear(1, 1)) 14 | x = fc2(x) 15 | return net, x 16 | 17 | 18 | def loss_fn(model, x: jnp.ndarray, y: jnp.ndarray): 19 | model, y_hat = forward(model, x) 20 | loss = jnp.mean(jnp.square(y_hat - y)) 21 | return loss, model 22 | 23 | 24 | @jax.jit 25 | def train_step(model, optimizer: opax.GradientTransformation, x, y): 26 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y) 27 | model, optimizer = opax.apply_gradients(model, optimizer, grads) 28 | return model, optimizer, loss 29 | 30 | 31 | def train(): 32 | "train a lazy model." 33 | 34 | pax.seed_rng_key(42) 35 | 36 | # data 37 | x = jax.random.normal(pax.next_rng_key(), (32, 1)) 38 | y = jax.random.normal(pax.next_rng_key(), (32, 1)) 39 | 40 | # model & optimizer 41 | net, _ = forward(pax.experimental.LazyModule(), x) 42 | print(net.summary()) 43 | opt = opax.adam(1e-1)(net.parameters()) 44 | 45 | # training loop 46 | for step in range(10): 47 | net, opt, loss = train_step(net, opt, x, y) 48 | print(f"step {step} loss {loss:.3f}") 49 | 50 | return net 51 | 52 | 53 | if __name__ == "__main__": 54 | train() 55 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | """train a handwritten digit classifier.""" 2 | 3 | import pickle 4 | from pathlib import Path 5 | from typing import Mapping 6 | 7 | import fire 8 | import jax 9 | import jax.numpy as jnp 10 | import opax 11 | import pax 12 | import tensorflow_datasets as tfds 13 | from opax import GradientTransformation 14 | from tqdm.auto import tqdm 15 | 16 | Batch = Mapping[str, jnp.ndarray] 17 | 18 | 19 | class ConvNet(pax.Module): 20 | """ConvNet module.""" 21 | 22 | layers: pax.Sequential 23 | 24 | def __init__(self): 25 | super().__init__() 26 | self.layers = pax.Sequential() 27 | for i in range(5): 28 | self.layers >>= pax.Conv2D((1 if i == 0 else 32), 32, 6, padding="VALID") 29 | self.layers >>= pax.BatchNorm2D(32, True, True, 0.9) 30 | self.layers >>= jax.nn.relu 31 | self.layers >>= pax.Conv2D(32, 10, 3, padding="VALID") 32 | 33 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 34 | x = self.layers(x) 35 | return jnp.squeeze(x, (1, 2)) 36 | 37 | 38 | def loss_fn(model: ConvNet, batch: Batch): 39 | x = batch["image"].astype(jnp.float32) / 255 40 | target = batch["label"] 41 | model, logits = pax.purecall(model, x) 42 | log_pr = jax.nn.log_softmax(logits, axis=-1) 43 | log_pr = jnp.sum(jax.nn.one_hot(target, log_pr.shape[-1]) * log_pr, axis=-1) 44 | loss = -jnp.mean(log_pr) 45 | return loss, model 46 | 47 | 48 | @jax.jit 49 | def test_loss_fn(model: ConvNet, batch: Batch): 50 | model = model.eval() 51 | return loss_fn(model, batch)[0] 52 | 53 | 54 | @jax.jit 55 | def update_fn(model: ConvNet, optimizer: GradientTransformation, batch: Batch): 56 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, batch) 57 | params = model.parameters() 58 | optimizer, updates = 
pax.purecall(optimizer, grads, params) 59 | params = params.map(jax.lax.sub, updates) 60 | model = model.update_parameters(params) 61 | return model, optimizer, loss 62 | 63 | 64 | def load_dataset(split: str): 65 | """Loads the dataset as a tensorflow dataset.""" 66 | ds = tfds.load("mnist:3.*.*", split=split) 67 | return ds 68 | 69 | 70 | def save_ckpt(epoch: int, model: ConvNet, path: Path): 71 | model = jax.device_get(model) 72 | with open(path, "wb") as f: 73 | pickle.dump({"epoch": epoch, "state_dict": model.state_dict()}, f) 74 | 75 | 76 | def load_ckpt(model: ConvNet, path: Path): 77 | """Load model from saved tree leaves""" 78 | with open(path, "rb") as f: 79 | dic = pickle.load(f) 80 | return dic["epoch"], model.load_state_dict(dic["state_dict"]) 81 | 82 | 83 | def train( 84 | batch_size=32, 85 | num_epochs=10, 86 | learning_rate=1e-4, 87 | weight_decay=1e-4, 88 | ckpt_dir="/tmp", 89 | ): 90 | pax.seed_rng_key(42) 91 | 92 | # model 93 | net = ConvNet() 94 | print(net.summary()) 95 | 96 | # optimizer 97 | optimizer = opax.chain( 98 | opax.clip_by_global_norm(1.0), 99 | opax.adamw(learning_rate=learning_rate, weight_decay=weight_decay), 100 | ).init(net.parameters()) 101 | 102 | # data 103 | train_data = load_dataset("train").shuffle(10 * batch_size).batch(batch_size) 104 | test_data = load_dataset("test").shuffle(10 * batch_size).batch(batch_size) 105 | 106 | # resume from the latest checkpoint 107 | ckpts = sorted(Path(ckpt_dir).glob("pax_mnist_ckpt_*.pickle")) 108 | if len(ckpts) > 0: 109 | print("loading checkpoint at", ckpts[-1]) 110 | last_epoch, net = load_ckpt(net, ckpts[-1]) 111 | else: 112 | last_epoch = -1 113 | 114 | # training loop 115 | for epoch in range(last_epoch + 1, num_epochs): 116 | losses = 0.0 117 | 118 | # training 119 | for batch in tqdm(train_data, desc="train", leave=False): 120 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch) 121 | net, optimizer, loss = update_fn(net, optimizer, batch) 122 | losses = losses + loss 123 | loss = losses / len(train_data) 124 | 125 | # testing 126 | test_losses = 0.0 127 | for batch in tqdm(test_data, desc="test", leave=False): 128 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch) 129 | test_losses = test_losses + test_loss_fn(net, batch) 130 | test_loss = test_losses / len(test_data) 131 | 132 | save_ckpt(epoch, net, Path(ckpt_dir) / f"pax_mnist_ckpt_{epoch:02d}.pickle") 133 | # logging 134 | print(f"[Epoch {epoch}] train loss {loss:.3f} test loss {test_loss:.3f}") 135 | 136 | 137 | if __name__ == "__main__": 138 | fire.Fire(train) 139 | -------------------------------------------------------------------------------- /examples/mnist_mixed_precision.py: -------------------------------------------------------------------------------- 1 | """train a handwritten digit classifier with mixed precision.""" 2 | 3 | from typing import List, Mapping, Tuple 4 | 5 | import fire 6 | import jax 7 | import jax.numpy as jnp 8 | import jmp 9 | import opax 10 | import pax 11 | import tensorflow_datasets as tfds 12 | from opax.transform import GradientTransformation 13 | from tqdm.auto import tqdm 14 | 15 | Batch = Mapping[str, jnp.ndarray] 16 | 17 | 18 | class ConvNet(pax.Module): 19 | """ConvNet module.""" 20 | 21 | layers: List[Tuple[pax.Conv2D, pax.BatchNorm2D]] 22 | output: pax.Conv2D 23 | 24 | def __init__(self): 25 | super().__init__() 26 | self.layers = [] 27 | for i in range(5): 28 | conv_in = 1 if i == 0 else 32 29 | conv = pax.Conv2D(conv_in, 32, 6, padding="VALID") 30 | bn = pax.BatchNorm2D(32) 31 | 
self.layers.append((conv, bn)) 32 | 33 | self.output = pax.Conv2D(32, 10, 3, padding="VALID") 34 | 35 | def __call__(self, x: jnp.ndarray): 36 | for conv, bn in self.layers: 37 | x = bn(conv(x)) 38 | x = jax.nn.relu(x) 39 | x = self.output(x) 40 | return jnp.squeeze(x, (1, 2)) 41 | 42 | 43 | def loss_fn(model: ConvNet, batch: Batch, loss_scale: jmp.LossScale): 44 | x = batch["image"].astype(jnp.float32) / 255 45 | target = batch["label"] 46 | model, logits = pax.purecall(model, x) 47 | log_pr = jax.nn.log_softmax(logits, axis=-1) 48 | log_pr = jnp.sum(jax.nn.one_hot(target, log_pr.shape[-1]) * log_pr, axis=-1) 49 | loss = -jnp.mean(log_pr) 50 | return loss_scale.scale(loss), (loss, model) 51 | 52 | 53 | @jax.jit 54 | def test_loss_fn(model: ConvNet, batch: Batch): 55 | model = model.eval() 56 | return loss_fn(model, batch, jmp.NoOpLossScale())[0] 57 | 58 | 59 | def apply_gradients_w_loss_scale( 60 | model: pax.Module, 61 | optimizer: opax.GradientTransformation, 62 | loss_scale: jmp.LossScale, 63 | grads: pax.Module, 64 | ): 65 | grads = loss_scale.unscale(grads) 66 | skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale) 67 | if skip_nonfinite_updates: 68 | grads_finite = jmp.all_finite(grads) 69 | loss_scale = loss_scale.adjust(grads_finite) 70 | model, optimizer = opax.apply_gradients( 71 | model, optimizer, grads=grads, all_finite=grads_finite 72 | ) 73 | else: 74 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads) 75 | return model, optimizer, loss_scale 76 | 77 | 78 | @jax.jit 79 | def update_fn( 80 | model: ConvNet, 81 | optimizer: GradientTransformation, 82 | loss_scale: jmp.LossScale, 83 | batch: Batch, 84 | ): 85 | grad_fn = pax.grad(loss_fn, has_aux=True) 86 | grads, (loss, model) = grad_fn(model, batch, loss_scale=loss_scale) 87 | return apply_gradients_w_loss_scale(model, optimizer, loss_scale, grads) + (loss,) 88 | 89 | 90 | def load_dataset(split: str): 91 | """Loads the dataset as a tensorflow dataset.""" 92 | ds = tfds.load("mnist:3.*.*", split=split) 93 | return ds 94 | 95 | 96 | def mp_policy_fn(mod): 97 | half = jmp.half_dtype() 98 | full = jnp.float32 99 | linear_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=full) 100 | bn_policy = jmp.Policy(compute_dtype=full, param_dtype=full, output_dtype=full) 101 | 102 | if isinstance(mod, pax.Conv2D): 103 | return pax.apply_mp_policy(mod, mp_policy=linear_policy) 104 | elif isinstance(mod, pax.BatchNorm2D): 105 | return pax.apply_mp_policy(mod, mp_policy=bn_policy) 106 | else: 107 | return mod # unchanged 108 | 109 | 110 | def train(batch_size=32, num_epochs=5, learning_rate=1e-4, weight_decay=1e-4): 111 | pax.seed_rng_key(42) 112 | 113 | net = ConvNet() 114 | net = net.apply(mp_policy_fn) 115 | print(net.summary()) 116 | optimizer = opax.chain( 117 | opax.clip_by_global_norm(1.0), 118 | opax.adamw(learning_rate=learning_rate, weight_decay=weight_decay), 119 | ).init(net.parameters()) 120 | 121 | loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15), period=2000) 122 | 123 | train_data = ( 124 | load_dataset("train") 125 | .shuffle(10 * batch_size) 126 | .batch(batch_size, drop_remainder=True) 127 | ) 128 | test_data = load_dataset("test").batch(batch_size, drop_remainder=True) 129 | 130 | for epoch in range(0, num_epochs): 131 | losses = 0.0 132 | for batch in tqdm(train_data, desc="train", leave=False): 133 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch) 134 | net, optimizer, loss_scale, loss = update_fn( 135 | net, optimizer, loss_scale, batch 136 
| ) 137 | losses = losses + loss 138 | loss = losses / len(train_data) 139 | 140 | test_losses = 0.0 141 | for batch in tqdm(test_data, desc="eval", leave=False): 142 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch) 143 | test_losses = test_losses + test_loss_fn(net, batch) 144 | test_loss = test_losses / len(test_data) 145 | 146 | print( 147 | f"[Epoch {epoch}] train loss {loss:.3f} test loss" 148 | f" {test_loss:.3f} loss scale {loss_scale.loss_scale}" 149 | ) 150 | 151 | 152 | if __name__ == "__main__": 153 | fire.Fire(train) 154 | -------------------------------------------------------------------------------- /examples/notebooks/pretrained_resnet18.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pax 3 | import torchvision 4 | 5 | IMAGENET_MEAN = np.array((0.485, 0.456, 0.406)) 6 | IMAGENET_STD = np.array((0.229, 0.224, 0.225)) 7 | 8 | 9 | def convert_conv(conv, name=None): 10 | """Return a pax.Conv2D module with weights from pretrained ``conv``.""" 11 | weight = conv.weight.data.contiguous().permute(2, 3, 1, 0).contiguous().numpy()[:] 12 | 13 | pax_conv = pax.Conv2D( 14 | in_features=conv.in_channels, 15 | out_features=conv.out_channels, 16 | kernel_shape=conv.kernel_size, 17 | stride=conv.stride, 18 | with_bias=False, 19 | padding=[(conv.padding[0],) * 2, (conv.padding[1],) * 2], 20 | data_format="NCHW", 21 | name=name, 22 | ) 23 | assert pax_conv.weight.shape == weight.shape 24 | return pax_conv.replace(weight=weight) 25 | 26 | 27 | def convert_bn(bn, name=None): 28 | """Return a pax.BatchNorm2D module from pretrained ``bn``.""" 29 | weight = bn.weight.data.numpy()[None, :, None, None] 30 | bias = bn.bias.data.numpy()[None, :, None, None] 31 | running_mean = bn.running_mean.data.numpy()[None, :, None, None] 32 | running_var = bn.running_var.data.numpy()[None, :, None, None] 33 | 34 | pax_bn = pax.BatchNorm2D( 35 | num_channels=bias.shape[1], 36 | create_offset=True, 37 | create_scale=True, 38 | decay_rate=0.9, 39 | eps=1e-5, 40 | data_format="NCHW", 41 | name=name, 42 | ) 43 | assert pax_bn.scale.shape == weight.shape 44 | assert pax_bn.offset.shape == bias.shape 45 | assert pax_bn.ema_mean.averages.shape == running_mean.shape 46 | assert pax_bn.ema_var.averages.shape == running_var.shape 47 | 48 | pax_bn = pax_bn.replace(scale=weight, offset=bias) 49 | pax_bn = pax_bn.replace_node(pax_bn.ema_mean.averages, running_mean) 50 | pax_bn = pax_bn.replace_node(pax_bn.ema_var.averages, running_var) 51 | return pax_bn 52 | 53 | 54 | def convert_basic_block(block): 55 | conv1 = convert_conv(block.conv1, name="conv1") 56 | bn1 = convert_bn(block.bn1, name="bn1") 57 | conv2 = convert_conv(block.conv2, name="conv2") 58 | bn2 = convert_bn(block.bn2, name="bn2") 59 | 60 | if block.downsample is not None: 61 | conv0 = convert_conv(block.downsample[0], name="proj_conv") 62 | bn0 = convert_bn(block.downsample[1], name="proj_bn") 63 | return ((conv1, bn1), (conv2, bn2)), (conv0, bn0) 64 | else: 65 | return (((conv1, bn1), (conv2, bn2)),) 66 | 67 | 68 | def convert_block_group(group): 69 | out = [] 70 | for i in range(len(group)): 71 | out.append(convert_basic_block(group[i])) 72 | return out 73 | 74 | 75 | def convert_linear(linear): 76 | weight = linear.weight.data.numpy() 77 | bias = linear.bias.data.numpy() 78 | pax_linear = pax.Linear( 79 | in_dim=weight.shape[1], out_dim=weight.shape[0], with_bias=True 80 | ) 81 | weight = np.transpose(weight) 82 | assert pax_linear.bias.shape == bias.shape 83 | assert 
pax_linear.weight.shape == weight.shape 84 | 85 | return pax_linear.replace(weight=weight, bias=bias) 86 | 87 | 88 | def load_pretrained_resnet18(): 89 | resnet18 = pax.nets.ResNet18(3, 1000) 90 | resnet18_pt = torchvision.models.resnet18(pretrained=True).eval() 91 | pax_resnet = [ 92 | convert_conv(resnet18_pt.conv1), 93 | convert_bn(resnet18_pt.bn1), 94 | convert_block_group(resnet18_pt.layer1), 95 | convert_block_group(resnet18_pt.layer2), 96 | convert_block_group(resnet18_pt.layer3), 97 | convert_block_group(resnet18_pt.layer4), 98 | convert_linear(resnet18_pt.fc), 99 | ] 100 | 101 | def replace_parts(resnet18): 102 | # replace resnet18 part by part 103 | resnet18.initial_conv = pax_resnet[0] 104 | resnet18.initial_batchnorm = pax_resnet[1] 105 | for i in range(len(resnet18.block_groups)): 106 | bg = resnet18.block_groups[i] 107 | for j in range(len(bg.blocks)): 108 | b = bg.blocks[j] 109 | mods = pax_resnet[2 + i][j] 110 | b.layers = mods[0] 111 | if b.use_projection: 112 | b.proj_conv = mods[1][0] 113 | b.proj_batchnorm = mods[1][1] 114 | 115 | resnet18.logits = pax_resnet[-1] 116 | # make sure we are in `eval` mode when doing evaluation. 117 | return resnet18.eval() 118 | 119 | return pax.pure(replace_parts)(resnet18) 120 | -------------------------------------------------------------------------------- /examples/transformer/data.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | def tokenize(text): 7 | t = [0] + [ord(c) for c in text] # ASCII, 0 is the [START] token 8 | return t 9 | 10 | 11 | def detokenize(tokens): 12 | text = [chr(t) if t != 0 else "[START]" for t in tokens] 13 | return "".join(text) 14 | 15 | 16 | def _device_put_sharded(sharded_tree, devices): 17 | leaves, treedef = jax.tree_util.tree_flatten(sharded_tree) 18 | n = leaves[0].shape[0] 19 | return jax.device_put_sharded( 20 | [ 21 | jax.tree_util.tree_unflatten(treedef, [l[i] for l in leaves]) 22 | for i in range(n) 23 | ], 24 | devices, 25 | ) 26 | 27 | 28 | # Source: https://github.com/deepmind/dm-haiku/blob/8fad8c7503c5f56fa9ea9b53f71b7082704e3a3e/examples/imagenet/dataset.py#L163 29 | def double_buffer(ds, num_devices, steps_per_update): 30 | """Keeps at least two batches on the accelerator. 31 | The current GPU allocator design reuses previous allocations. For a training 32 | loop this means batches will (typically) occupy the same region of memory as 33 | the previous batch. An issue with this is that it means we cannot overlap a 34 | host->device copy for the next batch until the previous step has finished and 35 | the previous batch has been freed. 36 | By double buffering we ensure that there are always two batches on the device. 37 | This means that a given batch waits on the N-2'th step to finish and free, 38 | meaning that it can allocate and copy the next batch to the accelerator in 39 | parallel with the N-1'th step being executed. 40 | Args: 41 | ds: Iterable of batches of numpy arrays. 42 | Yields: 43 | Batches of sharded device arrays. 
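Note: `num_devices` and `steps_per_update` control how each incoming batch is
reshaped to `(num_devices, steps_per_update, ...)` before being put on the
devices, so `num_devices` is expected to match `len(jax.devices())`.
Example (an illustrative sketch added for clarity; `fake_ds` and its shapes are
hypothetical and not part of the original code):

    import numpy as np
    fake_ds = (np.zeros((8, 64), dtype=np.int32) for _ in range(10))
    for sharded in double_buffer(fake_ds, num_devices=1, steps_per_update=2):
        # each yielded leaf has shape (num_devices, steps_per_update, 4, 64)
        # and already lives on the accelerator
        ...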
44 | """ 45 | batch = None 46 | devices = jax.devices() 47 | for next_batch in ds: 48 | assert next_batch is not None 49 | next_batch = np.reshape( 50 | next_batch, (num_devices, steps_per_update, -1) + next_batch.shape[1:] 51 | ) 52 | next_batch = _device_put_sharded(next_batch, devices) 53 | if batch is not None: 54 | yield batch 55 | batch = next_batch 56 | if batch is not None: 57 | yield batch 58 | 59 | 60 | def make_data_loader(data, seq_len, batch_size, num_devices, steps_per_update): 61 | data_token = tokenize(data) 62 | data_token = [0] * seq_len + data_token 63 | 64 | tfdata = ( 65 | tf.data.Dataset.from_tensors(data_token) 66 | .repeat() 67 | .map( 68 | lambda x: tf.image.random_crop(x, [seq_len + 1]), 69 | num_parallel_calls=tf.data.AUTOTUNE, 70 | ) 71 | .batch(batch_size) 72 | .prefetch(tf.data.AUTOTUNE) 73 | .as_numpy_iterator() 74 | ) 75 | 76 | return double_buffer(tfdata, num_devices, steps_per_update) 77 | -------------------------------------------------------------------------------- /examples/transformer/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Sequence 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import pax 7 | from pax.nets import Transformer 8 | 9 | 10 | def positional_encoding(x): 11 | _, L, D = x.shape 12 | position = jnp.arange(0, L, dtype=x.dtype)[:, None] 13 | div_term = jnp.exp(jnp.arange(0, D, 2, dtype=x.dtype) * (-math.log(10_000.0) / D)) 14 | x1 = jnp.sin(position * div_term[None, :]) 15 | x2 = jnp.cos(position * div_term[None, :]) 16 | x_pos = jnp.concatenate((x1, x2), axis=-1) 17 | return x + x_pos[None, :, :] 18 | 19 | 20 | class LM(pax.Module): 21 | """A Transformer language model.""" 22 | 23 | transformer: Transformer 24 | embed: pax.Module 25 | output: pax.Module 26 | 27 | vocab_size: int 28 | hidden_dim: int 29 | 30 | def __init__( 31 | self, vocab_size: int, hidden_dim: int, num_layers: int, dropout: float = 0.1 32 | ): 33 | """ 34 | Arguments: 35 | vocab_size: int, size of the alphabet. 36 | hidden_dim: int, hidden dim. 37 | num_layers: int, num transformer blocks. 
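dropout: float, dropout rate used inside the transformer blocks.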
38 | """ 39 | super().__init__() 40 | self.vocab_size = vocab_size 41 | self.hidden_dim = hidden_dim 42 | self.embed = pax.Embed( 43 | vocab_size, 44 | hidden_dim, 45 | w_init=jax.nn.initializers.variance_scaling( 46 | 1.0, mode="fan_out", distribution="normal" 47 | ), 48 | ) 49 | self.transformer = Transformer( 50 | hidden_dim, hidden_dim // 64, num_layers, dropout_rate=dropout 51 | ) 52 | self.output = pax.Linear(hidden_dim, vocab_size) 53 | 54 | def __call__(self, x): 55 | x = self.embed(x) 56 | x = positional_encoding(x) 57 | x = self.transformer(x) 58 | logits = self.output(x) 59 | return logits 60 | 61 | @pax.pure 62 | def inference(self, prompt: Sequence[int] = (), length=1024, train_seq_len=256): 63 | def step(inputs, _): 64 | logits = self(inputs) 65 | x = jnp.argmax(logits[:, -1], axis=-1) 66 | next_inputs = jnp.concatenate((inputs[:, 1:], x[:, None]), axis=-1) 67 | return next_inputs, x 68 | 69 | if len(prompt) > train_seq_len: 70 | inputs = prompt[-train_seq_len:] 71 | else: 72 | inputs = prompt 73 | pad_len = train_seq_len - len(inputs) 74 | padded_inputs = [0] * pad_len + inputs 75 | x = jnp.array([padded_inputs], dtype=jnp.int32) 76 | L = length - len(prompt) 77 | _, out = pax.scan(step, x, None, length=L, time_major=False) 78 | return prompt + out[0].tolist() 79 | -------------------------------------------------------------------------------- /examples/transformer/train.py: -------------------------------------------------------------------------------- 1 | """Train a transformer language model on TPU (if available).""" 2 | 3 | import inspect 4 | import os 5 | from functools import partial 6 | from typing import Tuple 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import jax.tools.colab_tpu 11 | import opax 12 | import pax 13 | from opax import GradientTransformation 14 | from tqdm.auto import tqdm 15 | 16 | from data import detokenize, make_data_loader, tokenize 17 | from model import LM 18 | 19 | 20 | def setup_tpu_device(): 21 | print("Setting up TPU cores") 22 | jax.tools.colab_tpu.setup_tpu() 23 | print(jax.devices()) 24 | 25 | 26 | # shared config 27 | dropout = 0.1 28 | learning_rate = 1e-4 29 | vocab_size = 256 30 | pax.seed_rng_key(42) 31 | 32 | if "COLAB_TPU_ADDR" in os.environ: 33 | # TPU config 34 | # need to config TPU cores _before_ calling `jax.device_count`. 
35 | setup_tpu_device() 36 | steps_per_update = 50 37 | num_devices = jax.device_count() 38 | batch_size = 32 * num_devices * steps_per_update 39 | seq_len = 256 40 | hidden_dim = 512 41 | num_steps = 1_000 42 | num_layers = 6 43 | else: 44 | # CPU/GPU config 45 | steps_per_update = 1 46 | num_devices = jax.device_count() 47 | batch_size = 8 * num_devices * steps_per_update 48 | seq_len = 64 49 | hidden_dim = 256 50 | num_steps = 20_000 51 | num_layers = 2 52 | 53 | 54 | def loss_fn(model: LM, batch: jnp.ndarray): 55 | inputs = batch[:, :-1] 56 | targets = batch[:, 1:] 57 | 58 | model, logits = pax.purecall(model, inputs) 59 | log_pr = jax.nn.log_softmax(logits, axis=-1) 60 | targets = jax.nn.one_hot(targets, num_classes=model.vocab_size) 61 | loss = -jnp.mean(jnp.sum(targets * log_pr, axis=-1)) 62 | return loss, model 63 | 64 | 65 | def update_step(model_and_optim: Tuple[LM, GradientTransformation], batch: jnp.ndarray): 66 | model, optimizer = model_and_optim 67 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, batch) 68 | grads = jax.lax.pmean(grads, axis_name="i") 69 | params = model.parameters() 70 | optimizer, updates = pax.purecall(optimizer, grads, params) 71 | params = params.map(jax.lax.sub, updates) 72 | model = model.update_parameters(params) 73 | return (model, optimizer), loss 74 | 75 | 76 | @partial(jax.pmap, axis_name="i") 77 | def update_fn(model: LM, optimizer: GradientTransformation, multi_batch: jnp.ndarray): 78 | (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch) 79 | return model, optimizer, jnp.sum(losses) 80 | 81 | 82 | def train(): 83 | net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers) 84 | print(net.summary()) 85 | optimizer = opax.chain( 86 | opax.clip_by_global_norm(1.0), 87 | opax.adam(learning_rate), 88 | ).init(net.parameters()) 89 | 90 | data = inspect.getsource(LM) # a _true_ AGI learns about itself. 
91 | test_prompt = data[:20] 92 | data_iter = make_data_loader( 93 | data, 94 | seq_len=seq_len, 95 | batch_size=batch_size, 96 | num_devices=num_devices, 97 | steps_per_update=steps_per_update, 98 | ) 99 | 100 | # replicate on multiple devices 101 | net = jax.device_put_replicated(net, jax.devices()) 102 | optimizer = jax.device_put_replicated(optimizer, jax.devices()) 103 | 104 | total_losses = 0.0 105 | tr = tqdm(range(0, 1 + num_steps, steps_per_update), desc="training") 106 | for step in tr: 107 | batch = next(data_iter) 108 | # (num_devices,) is for jax.pmap, (steps_per_update,) is for pax.scan 109 | net, optimizer, loss = update_fn(net, optimizer, batch) 110 | total_losses = total_losses + loss 111 | if step % 1000 == 0: 112 | loss = jnp.mean(total_losses) / (1000 if step > 0 else steps_per_update) 113 | total_losses = jnp.zeros_like(total_losses) 114 | # eval on a single device 115 | eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval()) 116 | out = eval_net.inference( 117 | prompt=tokenize(test_prompt), 118 | length=(128 if step < num_steps else 1024), 119 | train_seq_len=seq_len, 120 | ) 121 | text = detokenize(out) 122 | tr.write( 123 | f"[step {step}] loss {loss:.3f}\n" 124 | f"Prompt: {test_prompt}\n" 125 | f"========\n" 126 | f"{text}\n" 127 | f"========" 128 | ) 129 | 130 | 131 | if __name__ == "__main__": 132 | train() 133 | -------------------------------------------------------------------------------- /examples/wave_gru/README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This example is an implementation of [Lyra](https://github.com/google/lyra) WaveGRU network. 4 | However, we predict the 8-bit mu-compressed waveform instead of the raw 16-bit waveform. 5 | 6 | 7 | ## Data preparation 8 | 9 | We use `ffmpeg` and `sox` to do audio conversion and silence trimming. 
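The training targets are the 8-bit mu-law codes of the waveform. A minimal sketch of
the encoding round trip (the exact settings used for training live in `data_loader.py`
and `train.py`):

    import librosa

    mu = 2 ** 8 - 1  # 8-bit mu-law, 256 levels
    wav, _ = librosa.load("/tmp/wave_gru_clip.wav", sr=16000)
    codes = librosa.mu_compress(wav, mu=mu, quantize=True) + mu // 2  # integer targets
    audio = librosa.mu_expand(codes - mu // 2, mu=mu, quantize=True)  # back to float audio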
10 | 11 | 12 | To prepare audio clip: 13 | 14 | pip install -r requirements.txt 15 | bash prepare_data.sh 16 | 17 | ## Train WaveGRU 18 | 19 | python3 train.py # 1 hour on a Tesla T4 20 | -------------------------------------------------------------------------------- /examples/wave_gru/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import librosa 5 | import numpy as np 6 | 7 | 8 | def data_loader( 9 | batch_size: int, 10 | n_mels: int, 11 | n_fft: int, 12 | hop_length: int, 13 | win_length: int, 14 | sample_rate: int, 15 | fmin: int, 16 | fmax: int, 17 | mu: int, 18 | n_frames: int, 19 | split="train", 20 | pad: int = 31, 21 | ): 22 | if not os.path.exists("/tmp/wave_gru_clip.wav"): 23 | os.system("bash /tmp/prepare_clip.sh") 24 | 25 | wav, _ = librosa.load("/tmp/wave_gru_clip.wav", sr=sample_rate) 26 | 27 | L = len(wav) * 9 // 10 28 | if split == "train": 29 | wav = wav[:L] 30 | else: 31 | wav = wav[L:] 32 | 33 | mel = librosa.feature.melspectrogram( 34 | n_mels=n_mels, 35 | y=wav, 36 | sr=sample_rate, 37 | n_fft=n_fft, 38 | hop_length=hop_length, 39 | win_length=win_length, 40 | fmin=fmin, 41 | fmax=fmax, 42 | center=False, 43 | ) 44 | 45 | mel = mel.T 46 | 47 | logmel = np.log(1e-3 + mel) 48 | mu_wav = librosa.mu_compress(wav, mu=mu, quantize=True) + mu // 2 49 | 50 | if split == "test": 51 | yield (logmel, mu_wav) 52 | return 53 | 54 | batch = [] 55 | while True: 56 | left = random.randint(0, logmel.shape[0] - n_frames - pad * 2) 57 | right = left + pad + n_frames + pad 58 | cond = logmel[left:right] # included padding 59 | x = mu_wav[(left + pad) * hop_length : (right - pad) * hop_length + 1] 60 | batch.append((cond, x)) 61 | if len(batch) == batch_size: 62 | conds, xs = zip(*batch) 63 | conds = np.array(conds) 64 | xs = np.array(xs, dtype=np.int16) 65 | yield (conds, xs) 66 | batch = [] 67 | -------------------------------------------------------------------------------- /examples/wave_gru/model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import pax 6 | 7 | 8 | class UpsampleNet(pax.Module): 9 | """Upsampling melspectrogram.""" 10 | 11 | def __init__(self, n_mels, num_output_channels): 12 | super().__init__() 13 | self.input_conv = pax.Conv1D(n_mels, 512, 1, padding="VALID") 14 | self.dilated_convs = [] 15 | self.bns = [] 16 | for i in range(5): 17 | conv = pax.Conv1D(512, 512, 3, rate=2 ** i, padding="VALID") 18 | self.dilated_convs.append(conv) 19 | self.bns.append(pax.BatchNorm1D(512, True, True, 0.99)) 20 | self.upsample_conv_1 = pax.Conv1DTranspose(512, 512, 4, stride=4) 21 | self.upsample_bn1 = pax.BatchNorm1D(512, True, True, 0.99) 22 | self.upsample_conv_2 = pax.Conv1DTranspose(512, 512, 4, stride=4) 23 | self.upsample_bn2 = pax.BatchNorm1D(512, True, True, 0.99) 24 | self.output_conv = pax.Conv1D(512, num_output_channels, 1, padding="VALID") 25 | 26 | def __call__(self, mel): 27 | x = self.input_conv(mel) 28 | 29 | # Large receptive fields 30 | for conv, batch_norm in zip(self.dilated_convs, self.bns): 31 | residual = jax.nn.relu(batch_norm(conv(x))) 32 | pad = (x.shape[1] - residual.shape[1]) // 2 33 | x = x[:, pad:-pad] + residual 34 | 35 | # upsample 36 | x = jax.nn.relu(self.upsample_bn1(self.upsample_conv_1(x))) 37 | x = jax.nn.relu(self.upsample_bn2(self.upsample_conv_2(x))) 38 | 39 | x = self.output_conv(x) 40 | 41 | # tile x16 42 | N, L, D = x.shape 43 | 
x = jnp.tile(x[:, :, None, :], (1, 1, 16, 1)) 44 | x = jnp.reshape(x, (N, -1, D)) 45 | 46 | return x 47 | 48 | 49 | class WaveGRU(pax.Module): 50 | def __init__(self, n_mels, hidden_dim, n_mu_bits=8): 51 | super().__init__() 52 | self.n_mu_bits = n_mu_bits 53 | self.hidden_dim = hidden_dim 54 | 55 | self.upsampling = UpsampleNet(n_mels, hidden_dim) 56 | self.gru = pax.GRU(hidden_dim, hidden_dim) 57 | self.logits = pax.Linear(hidden_dim, 2 ** n_mu_bits) 58 | self.embed = pax.Embed(2 ** n_mu_bits, hidden_dim) 59 | 60 | def __call__(self, inputs): 61 | logmel, wav = inputs 62 | x = self.upsampling(logmel) 63 | hx = self.gru.initial_state(x.shape[0]) 64 | wav = self.embed(wav) 65 | assert x.shape == wav.shape 66 | x = x + wav 67 | _, x = pax.scan(self.gru, hx, x, time_major=False) 68 | x = self.logits(x) 69 | return x 70 | 71 | def inference(self, logmel, rng_key=None): 72 | if rng_key is None: 73 | rng_key = pax.next_rng_key() 74 | 75 | x = jnp.array([2 ** (self.n_mu_bits - 1)], dtype=jnp.int32) 76 | hx = self.gru.initial_state(1) 77 | 78 | conds = self.upsampling(logmel) 79 | 80 | def loop(prev_state, inputs): 81 | x, hx, rng_key = prev_state 82 | rng_key, next_rng_key = jax.random.split(rng_key) 83 | 84 | x = self.embed(x) + inputs 85 | hx, x = self.gru(hx, x) 86 | x = self.logits(x) 87 | x = jax.random.categorical(rng_key, x) 88 | return (x, hx, next_rng_key), x 89 | 90 | _, x = pax.scan(loop, (x, hx, rng_key), conds, time_major=False) 91 | return x 92 | -------------------------------------------------------------------------------- /examples/wave_gru/prepare_data.sh: -------------------------------------------------------------------------------- 1 | # "Yoshua Bengio: Deep Learning Cognition | Full Keynote - AI in 2020 & Beyond" 2 | youtube-dl -f 139 https://www.youtube.com/watch?v=GibjI5FoZsE --output /tmp/wave_gru_clip.m4a 3 | # convert m4a to wav 4 | ffmpeg -i /tmp/wave_gru_clip.m4a -ac 1 -ar 16000 -acodec pcm_s16le /tmp/wave_gru_clip_.wav 5 | # trim silences 6 | sox /tmp/wave_gru_clip_.wav /tmp/wave_gru_clip.wav silence -l 1 0.1 1% -1 1.0 1% 7 | -------------------------------------------------------------------------------- /examples/wave_gru/requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | librosa 3 | opax 4 | soundfile 5 | tqdm 6 | youtube-dl -------------------------------------------------------------------------------- /examples/wave_gru/train.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import librosa 6 | import opax 7 | import pax 8 | import soundfile 9 | from tqdm.auto import tqdm 10 | 11 | from data_loader import data_loader 12 | from model import WaveGRU 13 | 14 | 15 | def loss_fn(model: WaveGRU, inputs): 16 | logmel, wav = inputs 17 | input_wav = wav[:, :-1] 18 | target_wav = wav[:, 1:] 19 | model, logits = pax.purecall(model, (logmel, input_wav)) 20 | log_pr = jax.nn.log_softmax(logits, axis=-1) 21 | target_wave = jax.nn.one_hot(target_wav, num_classes=logits.shape[-1]) 22 | log_pr = jnp.sum(log_pr * target_wave, axis=-1) 23 | loss = -jnp.mean(log_pr) 24 | return loss, (loss, model) 25 | 26 | 27 | def generate_test_sample(step, test_logmel, wave_gru, length, sample_rate, mu): 28 | generated_mu = wave_gru.eval().inference(test_logmel[None, :length, :]) 29 | generated_mu = jax.device_get(generated_mu) 30 | synthesized_clip = librosa.mu_expand( 31 | generated_mu[0] - mu // 2, mu=mu, 
quantize=True 32 | ) 33 | file_name = f"/tmp/wave_gru_sample_{step:05d}.wav" 34 | soundfile.write( 35 | file_name, 36 | synthesized_clip, 37 | samplerate=sample_rate, 38 | ) 39 | return file_name 40 | 41 | 42 | def train( 43 | hidden_dim: int = 512, 44 | num_training_steps: int = 5_000, 45 | batch_size: int = 128, 46 | learning_rate: float = 5e-4, 47 | sample_rate: int = 16_000, 48 | max_global_norm: float = 1.0, 49 | n_fft=1024, 50 | hop_length=256, 51 | win_length=1024, 52 | n_mels=80, 53 | fmin=0, 54 | fmax=8000, 55 | seq_len=2 ** 10, 56 | n_mu_bits=8, 57 | log_freq: int = 1000, 58 | random_seed=42, 59 | ): 60 | pax.seed_rng_key(random_seed) 61 | mu = 2 ** n_mu_bits - 1 62 | n_frames = seq_len // hop_length 63 | wave_gru = WaveGRU(n_mels, hidden_dim) 64 | print(wave_gru.summary()) 65 | 66 | optimizer = opax.chain( 67 | opax.clip_by_global_norm(max_global_norm), 68 | opax.adam(learning_rate), 69 | ).init(wave_gru.parameters()) 70 | 71 | split_loader = partial( 72 | data_loader, 73 | batch_size=batch_size, 74 | n_mels=n_mels, 75 | n_fft=n_fft, 76 | hop_length=hop_length, 77 | win_length=win_length, 78 | sample_rate=sample_rate, 79 | mu=mu, 80 | n_frames=n_frames, 81 | fmin=fmin, 82 | fmax=fmax, 83 | ) 84 | data_iter = split_loader(split="train") 85 | test_iter = split_loader(split="test") 86 | test_logmel, _ = next(test_iter) 87 | 88 | update_fn = jax.jit(pax.utils.build_update_fn(loss_fn)) 89 | total_loss = 0.0 90 | tr = tqdm(range(1, 1 + num_training_steps)) 91 | for step in tr: 92 | batch = next(data_iter) 93 | wave_gru, optimizer, loss = update_fn(wave_gru, optimizer, batch) 94 | total_loss = total_loss + loss 95 | 96 | if step % log_freq == 0: 97 | loss = total_loss / log_freq 98 | total_loss = 0.0 99 | file_name = generate_test_sample( 100 | step, test_logmel, wave_gru, 1000, sample_rate, mu 101 | ) 102 | tr.write( 103 | f"[step {step}] train loss {loss:.3f} synthesized clip {file_name}" 104 | ) 105 | 106 | 107 | if __name__ == "__main__": 108 | import fire 109 | 110 | fire.Fire(train) 111 | -------------------------------------------------------------------------------- /images/pax_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTT123/pax/13916cb86ede38c56750cf1bde3ac37c63674014/images/pax_logo.png -------------------------------------------------------------------------------- /pax/__init__.py: -------------------------------------------------------------------------------- 1 | """PAX package.""" 2 | 3 | from pax import experimental, nets, utils 4 | from pax._src.core import ( 5 | EmptyNode, 6 | Module, 7 | ParameterModule, 8 | StateModule, 9 | apply_mp_policy, 10 | assert_structure_equal, 11 | enable_eval_mode, 12 | enable_train_mode, 13 | freeze_parameters, 14 | module_and_value, 15 | parameters_method, 16 | pure, 17 | purecall, 18 | select_parameters, 19 | unfreeze_parameters, 20 | unwrap_mp_policy, 21 | update_parameters, 22 | ) 23 | from pax._src.core.rng import next_rng_key, seed_rng_key 24 | from pax._src.nn import ( 25 | EMA, 26 | GRU, 27 | LSTM, 28 | BatchNorm1D, 29 | BatchNorm2D, 30 | Conv1D, 31 | Conv1DTranspose, 32 | Conv2D, 33 | Conv2DTranspose, 34 | Dropout, 35 | Embed, 36 | GroupNorm, 37 | GRUState, 38 | Identity, 39 | Lambda, 40 | LayerNorm, 41 | Linear, 42 | LSTMState, 43 | MultiHeadAttention, 44 | RngSeq, 45 | Sequential, 46 | VanillaRNN, 47 | VanillaRNNState, 48 | avg_pool, 49 | max_pool, 50 | ) 51 | from pax._src.nn.dropout import dropout 52 | from pax._src.utils import 
build_update_fn, grad, scan, value_and_grad 53 | 54 | __version__ = "0.5.9" 55 | 56 | __all__ = ( 57 | "apply_mp_policy", 58 | "assert_structure_equal", 59 | "avg_pool", 60 | "BatchNorm1D", 61 | "BatchNorm2D", 62 | "build_update_fn", 63 | "Conv1D", 64 | "Conv1DTranspose", 65 | "Conv2D", 66 | "Conv2DTranspose", 67 | "dropout", 68 | "Dropout", 69 | "EMA", 70 | "Embed", 71 | "EmptyNode", 72 | "enable_eval_mode", 73 | "enable_train_mode", 74 | "experimental", 75 | "freeze_parameters", 76 | "grad", 77 | "GroupNorm", 78 | "GRU", 79 | "GRUState", 80 | "Identity", 81 | "Lambda", 82 | "LayerNorm", 83 | "Linear", 84 | "LSTM", 85 | "LSTMState", 86 | "max_pool", 87 | "module_and_value", 88 | "Module", 89 | "MultiHeadAttention", 90 | "nets", 91 | "next_rng_key", 92 | "ParameterModule", 93 | "parameters_method", 94 | "pure", 95 | "purecall", 96 | "RngSeq", 97 | "scan", 98 | "seed_rng_key", 99 | "select_parameters", 100 | "Sequential", 101 | "StateModule", 102 | "unfreeze_parameters", 103 | "unwrap_mp_policy", 104 | "update_parameters", 105 | "utils", 106 | "value_and_grad", 107 | "VanillaRNN", 108 | "VanillaRNNState", 109 | ) 110 | 111 | 112 | try: 113 | del _src # pylint: disable=undefined-variable 114 | except NameError: 115 | pass 116 | -------------------------------------------------------------------------------- /pax/_src/__init__.py: -------------------------------------------------------------------------------- 1 | ### 2 | ### Empty init 3 | ### 4 | -------------------------------------------------------------------------------- /pax/_src/core/__init__.py: -------------------------------------------------------------------------------- 1 | """PAX Module""" 2 | 3 | from .graph_module import GraphModule, InputNode, build_graph_module 4 | from .mixed_precision import apply_mp_policy, unwrap_mp_policy 5 | from .module import EmptyNode, Module, parameters_method 6 | from .module_and_value import module_and_value 7 | from .mutable import mutable 8 | from .pure import pure, purecall 9 | from .transforms import ( 10 | enable_eval_mode, 11 | enable_train_mode, 12 | freeze_parameters, 13 | select_parameters, 14 | unfreeze_parameters, 15 | update_parameters, 16 | ) 17 | from .utility_modules import Flattener, LazyModule, ParameterModule, StateModule 18 | from .utils import assert_structure_equal 19 | -------------------------------------------------------------------------------- /pax/_src/core/base.py: -------------------------------------------------------------------------------- 1 | """PAX BaseModule.""" 2 | 3 | # Note: This file is originated from 4 | # https://raw.githubusercontent.com/cgarciae/treex/32e4cce5ca0cc991cda8076903853621d0aa4ab9/treex/module.py 5 | # which is under MIT License. 6 | 7 | from typing import Any, List, Mapping, Optional, Tuple, TypeVar 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.tree_util 12 | import numpy as np 13 | 14 | T = TypeVar("T", bound="BaseModule") 15 | M = TypeVar("M") 16 | 17 | 18 | class BaseModule: 19 | """BaseModule manages all information related to the pytree. 20 | 21 | There are two important methods: 22 | 23 | - ``tree_flatten`` converts a module to ``(leaves, treedef)`` 24 | - ``tree_unflatten`` restores the module. 25 | 26 | BaseModule maintains a ``pytree_attributes`` tuple that lists all subtree attribute names. 27 | """ 28 | 29 | _pytree_attributes: Tuple[str, ...] 
= () 30 | _mixed_pytree_attributes: Optional[Tuple[str, ...]] = None 31 | 32 | @property 33 | def pytree_attributes(self): 34 | if self._mixed_pytree_attributes is not None: 35 | return self._pytree_attributes + self._mixed_pytree_attributes 36 | else: 37 | return self._pytree_attributes 38 | 39 | def find_and_register_pytree_attributes(self: T): 40 | """Find and register ndarrays and submodules.""" 41 | is_mod_or_node = lambda x: isinstance(x, (BaseModule, EmptyNode)) 42 | is_pytree = lambda x: isinstance(x, pytree_cls) 43 | 44 | pytree_attributes = [] 45 | mixed_pytree_attributes = [] 46 | for name, value in self.__dict__.items(): 47 | leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_mod_or_node) 48 | pytree_cls = (jnp.ndarray, np.ndarray, BaseModule, EmptyNode) 49 | any_pytree = any(map(is_pytree, leaves)) 50 | all_pytree = all(map(is_pytree, leaves)) 51 | if any_pytree and all_pytree: 52 | pytree_attributes.append(name) 53 | elif any_pytree: 54 | mixed_pytree_attributes.append(name) 55 | self._pytree_attributes = tuple(pytree_attributes) 56 | if len(mixed_pytree_attributes) > 0: 57 | self._mixed_pytree_attributes = tuple(mixed_pytree_attributes) 58 | else: 59 | self._mixed_pytree_attributes = None 60 | 61 | def tree_flatten(self) -> Tuple[List[jnp.ndarray], Mapping[str, Any]]: 62 | """Convert a module to ``(children, treedef)``.""" 63 | aux = dict(self.__dict__) 64 | children = [aux.pop(name) for name in self._pytree_attributes] 65 | if self._mixed_pytree_attributes is not None: 66 | is_module = lambda x: isinstance(x, BaseModule) 67 | array_mod_cls = (jnp.ndarray, np.ndarray, BaseModule) 68 | is_array_mod = lambda x: isinstance(x, array_mod_cls) 69 | for name in self._mixed_pytree_attributes: 70 | value = aux.pop(name) 71 | leaves, treedef = jax.tree_util.tree_flatten(value, is_leaf=is_module) 72 | leaves = (v if is_array_mod(v) else ValueNode(v) for v in leaves) 73 | value = jax.tree_util.tree_unflatten(treedef, leaves) 74 | children.append(value) 75 | return children, aux 76 | 77 | @classmethod 78 | def tree_unflatten(cls, aux, children): 79 | """Recreate a module from its ``(children, treedef)``.""" 80 | module = object.__new__(cls) 81 | module_dict = module.__dict__ 82 | module_dict.update(aux) 83 | module_dict.update(zip(module._pytree_attributes, children)) 84 | if module._mixed_pytree_attributes is not None: 85 | L = len(module._pytree_attributes) 86 | is_leaf = lambda x: isinstance(x, (ValueNode, BaseModule)) 87 | unwrap = lambda x: x.value if isinstance(x, ValueNode) else x 88 | for name, value in zip(module._mixed_pytree_attributes, children[L:]): 89 | module_dict[name] = jax.tree_util.tree_map( 90 | unwrap, value, is_leaf=is_leaf 91 | ) 92 | return module 93 | 94 | def __init_subclass__(cls): 95 | """Any subclass of ``Module`` is also registered as pytree.""" 96 | jax.tree_util.register_pytree_node_class(cls) 97 | 98 | def __eq__(self, o: object) -> bool: 99 | """Compare two modules.""" 100 | if id(self) == id(o): 101 | return True 102 | 103 | if type(self) is not type(o): 104 | return False 105 | 106 | self_leaves, self_treedef = jax.tree_util.tree_flatten(self) 107 | o_leaves, o_treedef = jax.tree_util.tree_flatten(o) 108 | 109 | if len(self_leaves) != len(o_leaves): 110 | return False 111 | 112 | if self_treedef != o_treedef: 113 | return False 114 | 115 | leaves_equal = jax.tree_util.tree_map( 116 | lambda a, b: a is b, self_leaves, o_leaves 117 | ) 118 | return all(leaves_equal) 119 | 120 | def __hash__(self) -> int: 121 | leaves, treedef = 
jax.tree_util.tree_flatten(self) 122 | leaves = jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), leaves) 123 | return hash((tuple(leaves), treedef)) 124 | 125 | 126 | # Note: this class is inspired by treex's `Nothing` class. 127 | @jax.tree_util.register_pytree_node_class 128 | class EmptyNode: 129 | """Mark an uninitialized or deleted pytree node.""" 130 | 131 | def tree_flatten(self): 132 | """Flatten empty node.""" 133 | return (), None 134 | 135 | @classmethod 136 | def tree_unflatten(cls, aux, children): 137 | """Unflatten empty node.""" 138 | del aux, children 139 | return EmptyNode() 140 | 141 | def __repr__(self) -> str: 142 | return "EmptyNode" 143 | 144 | def __eq__(self, o: object) -> bool: 145 | if isinstance(o, EmptyNode): 146 | return True 147 | return False 148 | 149 | 150 | @jax.tree_util.register_pytree_node_class 151 | class ValueNode: 152 | """We use this class to store a value in treedef.""" 153 | 154 | def __init__(self, value): 155 | super().__init__() 156 | self.value = value 157 | 158 | def tree_flatten(self): 159 | return (), self.value 160 | 161 | @classmethod 162 | def tree_unflatten(cls, value, children): 163 | return ValueNode(value) 164 | 165 | def __repr__(self) -> str: 166 | return f"ValueNode({self.value})" 167 | -------------------------------------------------------------------------------- /pax/_src/core/mixed_precision.py: -------------------------------------------------------------------------------- 1 | """Enforce mixed-precision policy.""" 2 | 3 | import functools 4 | from typing import TypeVar 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import jmp 9 | 10 | from .module import Module 11 | from .safe_module import find_descriptor 12 | 13 | T = TypeVar("T", bound=Module) 14 | 15 | 16 | def _wrap_method(func): 17 | """Wrap a class's method to enforce mixe-precision policy.""" 18 | 19 | @functools.wraps(func) 20 | def mp_method_wrapper(self, *args, **kwargs): 21 | """A mixed-precision method. 22 | 23 | - Convert all weights to compute dtype. 24 | - Cast all arguments to compute dtype. 25 | - Call the original method. 26 | - Convert all weights to param dtype. 27 | - Cast output to output dtype. 28 | 29 | We bypass PAX mutability checking to make mixed-precision 30 | policy transparent from the user's point of view. 
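(Concretely, the temporary compute-dtype weights are swapped in and out through
``self.__dict__``; only attributes that the wrapped method actually modified are
cast back to param dtype and assigned with ``setattr``.)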
31 | """ 32 | original_values = {} 33 | casted_original = {} 34 | # pylint: disable=protected-access 35 | 36 | # convert weights to compute dtype 37 | for name in self.pytree_attributes: 38 | value = getattr(self, name) 39 | if not _has_module(value): 40 | casted_value = self._pax_mp_policy.cast_to_compute(value) 41 | self.__dict__[name] = casted_value 42 | original_values[name] = value 43 | casted_original[name] = casted_value 44 | 45 | # cast arguments to compute dtype 46 | args, kwargs = self._pax_mp_policy.cast_to_compute((args, kwargs)) 47 | output = func.__get__(self, type(self))(*args, **kwargs) # type:ignore 48 | 49 | # convert weights to param dtype 50 | for name in self.pytree_attributes: 51 | value = getattr(self, name) 52 | if not _has_module(value): 53 | if value is not casted_original[name]: # modified 54 | casted_value = self._pax_mp_policy.cast_to_param(value) 55 | setattr(self, name, casted_value) 56 | else: 57 | # avoid casting operation 58 | self.__dict__[name] = original_values[name] 59 | 60 | # cast output to output dtype 61 | output = self._pax_mp_policy.cast_to_output(output) 62 | return output 63 | 64 | return mp_method_wrapper 65 | 66 | 67 | def _mp_repr(mp_policy): 68 | dtype_to_name = { 69 | jnp.bfloat16: "H", 70 | jnp.float16: "H", 71 | jnp.float32: "F", 72 | jnp.float64: "F", 73 | } 74 | 75 | return ( 76 | dtype_to_name[mp_policy.param_dtype] 77 | + dtype_to_name[mp_policy.compute_dtype] 78 | + dtype_to_name[mp_policy.output_dtype] 79 | ) 80 | 81 | 82 | def apply_mp_policy(module: T, mp_policy: jmp.Policy) -> T: 83 | """Create a mixed-precision module. 84 | 85 | Create a subclass on the fly to enforce the mixed-precision policy. 86 | 87 | >>> import jmp 88 | >>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32") 89 | >>> net = pax.Linear(3, 3) 90 | >>> net = pax.apply_mp_policy(net, mp_policy) 91 | >>> print(net.summary()) 92 | Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF) 93 | """ 94 | 95 | if hasattr(module, "_pax_mp_policy"): 96 | raise ValueError( 97 | "Cannot apply multiple mixed-precision policies on an object.\n" 98 | "Call `pax.unwrap_mp_policy(...)` to remove the policy first." 99 | ) 100 | 101 | # pylint: disable=protected-access 102 | cls_name = module.__class__.__name__ 103 | module_methods = dir(Module) 104 | base = module.__class__ 105 | 106 | methods = {} 107 | for name in dir(base): 108 | if name != "__call__" and name.startswith("__"): 109 | continue 110 | if name == "__call__" or name not in module_methods: 111 | value = getattr(base, name) 112 | if callable(value): 113 | value = find_descriptor(base, name) 114 | if value is None: 115 | continue 116 | if isinstance(value, (staticmethod, classmethod)): 117 | methods[name] = value 118 | else: 119 | methods[name] = _wrap_method(value) 120 | 121 | def _repr(self, info=None): 122 | if info is None: 123 | info = {} 124 | info["mp_policy"] = _mp_repr(self._pax_mp_policy) 125 | return super(base, self)._repr(info) # type: ignore 126 | 127 | methods["_repr"] = _repr 128 | 129 | cls = type(cls_name, (base,), methods) 130 | obj = object.__new__(cls) 131 | obj.__dict__.update(module.__dict__) 132 | obj.__dict__["_pax_mp_policy"] = mp_policy 133 | for name in obj.pytree_attributes: 134 | value = getattr(obj, name) 135 | if not _has_module(value): 136 | obj.__dict__[name] = mp_policy.cast_to_param(obj.__dict__[name]) 137 | return obj 138 | 139 | 140 | def unwrap_mp_policy(module: T) -> T: 141 | """Unwrap a mixed-precision module to recreate the original module. 
142 | 143 | >>> import jmp 144 | >>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32") 145 | >>> net = pax.Linear(3, 3) 146 | >>> net = pax.apply_mp_policy(net, mp_policy) 147 | >>> print(net.summary()) 148 | Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF) 149 | >>> net = pax.unwrap_mp_policy(net) 150 | >>> print(net.summary()) 151 | Linear(in_dim=3, out_dim=3, with_bias=True) 152 | """ 153 | if not hasattr(module, "_pax_mp_policy"): 154 | raise ValueError("Expected a mixed-precision module.") 155 | 156 | base = module.__class__.__base__ 157 | original = object.__new__(base) 158 | original.__dict__.update(module.__dict__) 159 | del original.__dict__["_pax_mp_policy"] 160 | return original 161 | 162 | 163 | def _has_module(mod): 164 | is_mod = lambda x: x is not mod 165 | leaves, _ = jax.tree_util.tree_flatten(mod, is_leaf=is_mod) 166 | return any(map(is_mod, leaves)) 167 | -------------------------------------------------------------------------------- /pax/_src/core/module_and_value.py: -------------------------------------------------------------------------------- 1 | """PAX mechanisms to make PAX method pure.""" 2 | 3 | from functools import partial 4 | from types import MethodType 5 | from typing import Callable, Tuple, TypeVar 6 | 7 | from .base import BaseModule 8 | from .pure import pure 9 | 10 | O = TypeVar("O") 11 | T = TypeVar("T", bound=BaseModule) 12 | 13 | 14 | def module_and_value(module_or_method: Callable[..., O]) -> Callable[..., Tuple[T, O]]: 15 | """Return a pure function that executes a module's method. 16 | 17 | This pure function also returns the updated input module in the output. 18 | 19 | Example: 20 | 21 | >>> net = pax.Linear(1, 1) 22 | >>> x = jnp.ones((32, 1)) 23 | >>> net, y = pax.module_and_value(net)(x) # note: `net` is also returned. 24 | 25 | 26 | Arguments: 27 | module_or_method: Either a PAX module or a method of a PAX module. 28 | 29 | Returns: 30 | A pure function. 31 | """ 32 | is_bound_method = True 33 | if isinstance(module_or_method, MethodType): # a method 34 | mod = module_or_method.__self__ 35 | func = module_or_method.__func__ 36 | elif isinstance(module_or_method, BaseModule): # a module 37 | mod = module_or_method 38 | assert hasattr(mod, "__call__"), "Expecting a callable module." 39 | func = module_or_method.__call__.__func__ 40 | elif callable(module_or_method): 41 | is_bound_method = False 42 | func = module_or_method 43 | else: 44 | raise ValueError("Expecting a module or a module's method.") 45 | 46 | @pure 47 | def _run(mod, *args, **kwargs): 48 | assert isinstance(mod, BaseModule), "Expecting a PAX module." 49 | out = func(mod, *args, **kwargs) 50 | return mod, out 51 | 52 | if is_bound_method: 53 | return partial(_run, mod) 54 | else: 55 | return _run 56 | -------------------------------------------------------------------------------- /pax/_src/core/mutable.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | from .module import Module 4 | from .pure import get_all_submodules 5 | from .threading_local import allow_mutation 6 | 7 | 8 | @contextmanager 9 | def mutable(module: Module): 10 | """A context manager that allows a copy module to be mutable inside the context. 11 | 12 | >>> net = pax.Linear(1, 2) 13 | >>> with pax.experimental.mutable(net) as net: 14 | ... net.bias = jnp.array(0.) 15 | >>> assert net.bias.item() == 0. 
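Note that the context manager yields a *copy* of ``module``; the original module
passed in is left unchanged.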
16 | """ 17 | 18 | copy = module.copy() 19 | all_submodules = get_all_submodules(copy) 20 | 21 | with allow_mutation(all_submodules): 22 | try: 23 | yield copy 24 | finally: 25 | copy.find_and_register_pytree_attributes() 26 | copy.scan_bugs() 27 | -------------------------------------------------------------------------------- /pax/_src/core/pure.py: -------------------------------------------------------------------------------- 1 | """PAX mechanisms to make PAX functions pure.""" 2 | 3 | import functools 4 | from types import MethodType 5 | from typing import Any, Callable, Tuple, TypeVar 6 | 7 | import jax 8 | 9 | from .base import BaseModule 10 | from .threading_local import allow_mutation 11 | 12 | T = TypeVar("T") 13 | O = TypeVar("O") 14 | 15 | 16 | def pure(func: Callable): 17 | """Make a function pure by copying the inputs. 18 | 19 | Any modification on the copy will not affect the original inputs. 20 | 21 | **Note**: only functions that are wrapped by `pax.pure` are allowed to modify PAX's Modules. 22 | 23 | Example: 24 | 25 | >>> f = pax.Linear(3,3) 26 | >>> f.a_list = [] 27 | Traceback (most recent call last): 28 | ... 29 | ValueError: Cannot modify a module in immutable mode. 30 | Please do this computation inside a function decorated by `pax.pure`. 31 | >>> 32 | >>> @pax.pure 33 | ... def add_list(m): 34 | ... m.a_list = [] 35 | ... return m 36 | ... 37 | >>> f = add_list(f) 38 | >>> print(f.a_list) 39 | [] 40 | 41 | Arguments: 42 | func: A function. 43 | 44 | Returns: 45 | A pure function. 46 | """ 47 | 48 | @functools.wraps(func) 49 | def wrapper(*args, **kwargs): 50 | for m in _get_modules((func, args, kwargs)): 51 | m.scan_bugs() 52 | 53 | # support calling method 54 | if isinstance(func, MethodType): 55 | args = (func.__self__, *args) 56 | unbound_func = func.__func__ 57 | # or calling a module 58 | elif isinstance(func, BaseModule) and callable(func): 59 | args = (func, *args) 60 | unbound_func = func.__call__.__func__ 61 | elif callable(func): 62 | unbound_func = func 63 | else: 64 | raise ValueError("Not supported") 65 | 66 | args, kwargs = _copy((args, kwargs)) 67 | modules = get_all_submodules((args, kwargs)) 68 | with allow_mutation(modules): 69 | out = unbound_func(*args, **kwargs) 70 | 71 | for m in modules: 72 | m.find_and_register_pytree_attributes() 73 | m.scan_bugs() 74 | return out 75 | 76 | return wrapper 77 | 78 | 79 | @pure 80 | def purecall(module: Callable[..., O], *args, **kwargs) -> Tuple[Any, O]: 81 | """Call a module and return the updated module. 82 | 83 | A shortcut for `pax.pure(lambda f, x: [f, f(x)])`. 84 | """ 85 | assert isinstance(module, BaseModule) 86 | assert callable(module) 87 | return module, module(*args, **kwargs) 88 | 89 | 90 | def _get_modules(tree): 91 | "Return a list of modules in the pytree `tree`." 
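# Flattening with is_leaf=... stops the traversal at module boundaries, so
# submodules appear directly in the flattened list; the filter below then drops
# any non-module leaves (e.g., plain ndarrays).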
92 | modules = jax.tree_util.tree_flatten( 93 | tree, is_leaf=lambda x: isinstance(x, BaseModule) 94 | )[0] 95 | modules = [m for m in modules if isinstance(m, BaseModule)] 96 | return modules 97 | 98 | 99 | def get_all_submodules(value): 100 | submods = _get_modules(value) 101 | out = list(submods) 102 | for mod in submods: 103 | out.extend(get_all_submodules(mod.submodules())) 104 | return out 105 | 106 | 107 | def _copy(value: T) -> T: 108 | leaves, treedef = jax.tree_util.tree_flatten(value) 109 | return jax.tree_util.tree_unflatten(treedef, leaves) 110 | -------------------------------------------------------------------------------- /pax/_src/core/rng.py: -------------------------------------------------------------------------------- 1 | """Random Number Generator.""" 2 | 3 | from .threading_local import KeyArray, next_rng_key, seed_rng_key 4 | 5 | __all__ = ( 6 | "KeyArray", 7 | "next_rng_key", 8 | "seed_rng_key", 9 | ) 10 | -------------------------------------------------------------------------------- /pax/_src/core/safe_module.py: -------------------------------------------------------------------------------- 1 | """Safeguards to prevent potential bugs.""" 2 | 3 | import inspect 4 | from typing import Iterable, List, Type, TypeVar 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | from .base import BaseModule 11 | from .threading_local import allow_mutation, is_mutable 12 | 13 | T = TypeVar("T") 14 | 15 | 16 | class SafeBaseModuleMetaclass(type): 17 | """Metaclass for `SafeBaseModule`.""" 18 | 19 | def __call__(cls: Type[T], *args, **kwargs) -> T: 20 | module = cls.__new__(cls, *args, **kwargs) # type: ignore 21 | 22 | with allow_mutation(module): 23 | cls.__init__(module, *args, **kwargs) 24 | module.find_and_register_pytree_attributes() 25 | 26 | # scan module after initialization for potential bugs 27 | if hasattr(module, "__slots__"): 28 | raise ValueError("`__slots__` is not supported by PAX modules.") 29 | module._assert_not_shared_module() 30 | module._assert_not_shared_weight() 31 | module._scan_fields(module._class_fields()) 32 | return module 33 | 34 | 35 | class SafeBaseModule(BaseModule, metaclass=SafeBaseModuleMetaclass): 36 | """Adding safe guards to BaseModule to prevent bugs.""" 37 | 38 | def _class_fields(self): 39 | for name, value in inspect.getmembers(self): 40 | if name.startswith("__") or inspect.ismethod(value): 41 | continue 42 | 43 | if name in self.__dict__: 44 | continue 45 | 46 | if find_descriptor(self.__class__, name) is not None: 47 | # ignore descriptors 48 | continue 49 | 50 | yield name 51 | 52 | def _assert_mutability(self): 53 | if not is_mutable(self): 54 | raise ValueError( 55 | "Cannot modify a module in immutable mode.\n" 56 | "Please do this computation inside a function decorated by `pax.pure`." 57 | ) 58 | 59 | def _assert_not_shared_module(self): 60 | """Shared module is not allowed.""" 61 | shared_module = _find_shared_module(self) 62 | if shared_module is not None: 63 | raise ValueError( 64 | f"The module `{shared_module}` is shared between two nodes of the pytree.\n" 65 | f"This is not allowed to prevent potential silence bugs." 66 | ) 67 | 68 | def _assert_not_shared_weight(self): 69 | """Shared weight is not allowed.""" 70 | leaves = jax.tree_util.tree_leaves(self) 71 | leaf_ids = set() 72 | for leaf in leaves: 73 | if id(leaf) in leaf_ids: 74 | raise ValueError( 75 | f"Detected a shared ndarray. 
This is not allowed.\n" 76 | f"Shape={leaf.shape}\n" 77 | f"Dtype={leaf.dtype}\n" 78 | f"Value={leaf}", 79 | ) 80 | leaf_ids.add(id(leaf)) 81 | 82 | def _scan_fields(self, fields: Iterable[str]): 83 | """Scan fields for *potential* bugs.""" 84 | 85 | for name in fields: 86 | if name in self.pytree_attributes: 87 | continue 88 | 89 | value = getattr(self, name) 90 | is_mod = lambda x: isinstance(x, BaseModule) 91 | is_ndarray = lambda x: isinstance(x, (jnp.ndarray, np.ndarray)) 92 | mods, _ = jax.tree_util.tree_flatten(value, is_leaf=is_mod) 93 | leaves = jax.tree_util.tree_leaves(value) 94 | has_mods = any(map(is_mod, mods)) 95 | has_arrays = any(map(is_ndarray, mods)) 96 | 97 | if has_mods: 98 | raise ValueError( 99 | f"\n" 100 | f"Unregistered field `{self}.{name}`:\n" 101 | f" value={value}\n" 102 | f"contains a module leaf.\n" 103 | ) 104 | 105 | if has_arrays: 106 | raise ValueError( 107 | f"\n" 108 | f"Unregistered field `{self}.{name}`:\n" 109 | f" value={value}\n" 110 | f"contains a ndarray leaf.\n" 111 | ) 112 | 113 | 114 | def _find_shared_module(module: BaseModule): 115 | """Find shared module. 116 | 117 | - Return the first module that is shared by two nodes of the pytree. 118 | - Return `None` if there is no shared module. 119 | """ 120 | 121 | def _get_all_modules(mod: BaseModule, lst: List): 122 | lst.append(mod) 123 | is_mod = lambda x: isinstance(x, BaseModule) and x is not mod 124 | submodules, _ = jax.tree_util.tree_flatten(mod, is_leaf=is_mod) 125 | submodules = (m for m in submodules if is_mod(m)) 126 | for m in submodules: 127 | _get_all_modules(m, lst) 128 | 129 | mods = [] 130 | _get_all_modules(module, mods) 131 | module_ids = set() 132 | for m in mods: 133 | if id(m) in module_ids: 134 | return m 135 | module_ids.add(id(m)) 136 | 137 | return None 138 | 139 | 140 | # source: https://stackoverflow.com/a/21963090 141 | def find_descriptor(cls, attrname): 142 | """Find the descriptor of an attribute.""" 143 | 144 | def hasspecialmethod(obj, name): 145 | return any(name in klass.__dict__ for klass in type(obj).__mro__) 146 | 147 | for klass in cls.__mro__: 148 | if attrname in klass.__dict__: 149 | descriptor = klass.__dict__[attrname] 150 | if not hasspecialmethod(descriptor, "__get__"): 151 | return None 152 | return descriptor 153 | return None 154 | -------------------------------------------------------------------------------- /pax/_src/core/threading_local.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manage thread local states 3 | """ 4 | 5 | import random 6 | import threading 7 | import weakref 8 | from contextlib import contextmanager 9 | from typing import Any, Optional, Tuple, Union 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import jax.tree_util 14 | 15 | KeyArray = Union[Any, jnp.ndarray] 16 | 17 | 18 | class PaxThreadingLocalState(threading.local): 19 | """Manage all thread local states used by PAX""" 20 | 21 | _mutable_module_ref_list: Tuple[weakref.ReferenceType, ...] 
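# (the tuple above holds weak references to the modules that are currently
#  allowed to be mutated; the level below records the JAX trace level at which
#  mutation was enabled)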
22 | _mutable_module_level: int 23 | _rng: Optional[random.Random] 24 | 25 | def __init__(self): 26 | super().__init__() 27 | self._mutable_module_ref_list = () 28 | self._mutable_module_level = _jax_cur_level() 29 | self._rng = random.Random(42) 30 | 31 | def add_mutable_module(self, module): 32 | """add `module` to mutable list""" 33 | self._mutable_module_ref_list = ( 34 | weakref.ref(module), 35 | *self._mutable_module_ref_list, 36 | ) 37 | 38 | def is_mutable(self, module): 39 | """Is `module` mutable?""" 40 | 41 | # cannot modify a module whose level of abstraction 42 | # is lower than the current level 43 | if self._mutable_module_level < _jax_cur_level(): 44 | return False 45 | 46 | for ref in self._mutable_module_ref_list: 47 | if module is ref(): 48 | return True 49 | 50 | return False 51 | 52 | @contextmanager 53 | def allow_mutation(self, modules): 54 | """A context manager that turns on mutability.""" 55 | 56 | if not isinstance(modules, (tuple, list)): 57 | modules = (modules,) 58 | modules = tuple(weakref.ref(mod) for mod in modules) 59 | 60 | prev = self._mutable_module_ref_list 61 | prev_abstraction_level = self._mutable_module_level 62 | try: 63 | self._mutable_module_ref_list = modules 64 | self._mutable_module_level = _jax_cur_level() 65 | yield 66 | finally: 67 | self._mutable_module_ref_list = prev 68 | self._mutable_module_level = prev_abstraction_level 69 | 70 | def seed_rng_key(self, seed: int) -> None: 71 | """Set ``self._rng = random.Random(seed)``. 72 | 73 | Arguments: 74 | seed: an integer seed. 75 | """ 76 | assert isinstance(seed, int) 77 | self._rng = random.Random(seed) 78 | 79 | def next_rng_key(self) -> KeyArray: 80 | """Return a random rng key. Renew the global random state.""" 81 | seed = self._rng.randint(1, 999999999) 82 | return jax.random.PRNGKey(seed) 83 | 84 | def get_rng_state(self): 85 | """Return internal random states.""" 86 | return self._rng.getstate() 87 | 88 | def set_rng_state(self, state): 89 | """Set internal random states.""" 90 | self._rng.setstate(state) 91 | 92 | 93 | def _jax_cur_level(): 94 | """ 95 | Return the level of current jax trace. 96 | 97 | If it is an eval_trace, return -1. 
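PAX records this level when mutation is enabled and refuses to mutate a module
from a deeper trace (e.g., from inside ``jit`` or ``grad``) than the one in
which it was made mutable.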
98 | """ 99 | trace = jax.core.thread_local_state.trace_state.trace_stack.stack[-1] 100 | if trace.trace_type == jax.core.EvalTrace: 101 | return -1 102 | else: 103 | return trace.level 104 | 105 | 106 | PAX_STATE = PaxThreadingLocalState() 107 | add_mutable_module = PAX_STATE.add_mutable_module 108 | allow_mutation = PAX_STATE.allow_mutation 109 | get_rng_state = PAX_STATE.get_rng_state 110 | is_mutable = PAX_STATE.is_mutable 111 | next_rng_key = PAX_STATE.next_rng_key 112 | seed_rng_key = PAX_STATE.seed_rng_key 113 | set_rng_state = PAX_STATE.set_rng_state 114 | -------------------------------------------------------------------------------- /pax/_src/core/transforms.py: -------------------------------------------------------------------------------- 1 | """Transform a module to a new one.""" 2 | from typing import Any, TypeVar 3 | 4 | import jax 5 | 6 | from .module import Module, parameters_method, update_pytree 7 | 8 | TreeDef = Any 9 | 10 | T = TypeVar("T", bound=Module) 11 | K = TypeVar("K", bound=Module) 12 | O = TypeVar("O", bound=Module) 13 | 14 | 15 | def enable_train_mode(mod: T) -> T: 16 | """Return a module in training mode.""" 17 | return mod.train() 18 | 19 | 20 | def enable_eval_mode(mod: T) -> T: 21 | """Return a module in evaluation mode.""" 22 | return mod.eval() 23 | 24 | 25 | def freeze_parameters(mod: T) -> T: 26 | """Return a copy module with all trainable parameters are converted to non-trainable states.""" 27 | 28 | def _freeze_apply_fn(mod: T) -> T: 29 | return mod.replace_method(parameters=parameters_method()) 30 | 31 | return mod.apply(_freeze_apply_fn) 32 | 33 | 34 | def unfreeze_parameters(mod: T, *, origin: T) -> T: 35 | """Return a copy module with all trainable parameters are converted to non-trainable states.""" 36 | tree_def = jax.tree_util.tree_structure(origin) 37 | leaves = jax.tree_util.tree_leaves(mod) 38 | return jax.tree_util.tree_unflatten(tree_def, leaves) 39 | 40 | 41 | def select_parameters(mod: T) -> T: 42 | """Select `PARAMETER` leaves only.""" 43 | return mod.parameters() 44 | 45 | 46 | def update_parameters(mod: T, *, params: T) -> T: 47 | """Return a module that uses trainable parameters in `params`.""" 48 | return update_pytree(mod, other=params.parameters()) 49 | -------------------------------------------------------------------------------- /pax/_src/core/utility_modules.py: -------------------------------------------------------------------------------- 1 | """Utility Modules.""" 2 | 3 | 4 | from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from .module import Module, parameters_method 10 | 11 | T = TypeVar("T", bound=Module) 12 | O = TypeVar("O") 13 | 14 | 15 | class ParameterModule(Module): 16 | """A PAX module that registers attributes as parameters by default.""" 17 | 18 | def parameters(self): 19 | return self.apply_submodules(lambda x: x.parameters()) 20 | 21 | 22 | class StateModule(Module): 23 | """A PAX module that registers attributes as states by default.""" 24 | 25 | parameters = parameters_method() 26 | 27 | 28 | class LazyModule(Module): 29 | """A lazy module is a module that only creates submodules when needed. 30 | 31 | 32 | Example: 33 | 34 | >>> from dataclasses import dataclass 35 | >>> @dataclass 36 | ... class MLP(pax.experimental.LazyModule): 37 | ... features: list 38 | ... 39 | ... def __call__(self, x): 40 | ... sizes = zip(self.features[:-1], self.features[1:]) 41 | ... for i, (in_dim, out_dim) in enumerate(sizes): 42 | ... 
fc = self.get_or_create(f"fc_{i}", lambda: pax.Linear(in_dim, out_dim)) 43 | ... x = jax.nn.relu(fc(x)) 44 | ... return x 45 | ... 46 | ... 47 | >>> mlp, _ = MLP([1, 2, 3]) % jnp.ones((1, 1)) 48 | >>> print(mlp.summary()) 49 | MLP(features=[1, 2, 3]) 50 | ├── Linear(in_dim=1, out_dim=2, with_bias=True) 51 | └── Linear(in_dim=2, out_dim=3, with_bias=True) 52 | """ 53 | 54 | def get_or_create(self, name, create_fn: Callable[[], T]) -> T: 55 | """Create and register a new attribute when it is not exist. 56 | 57 | Return the attribute. 58 | """ 59 | if hasattr(self, name): 60 | value = getattr(self, name) 61 | else: 62 | assert callable(create_fn), "Expect a callable function" 63 | value = create_fn() 64 | setattr(self, name, value) 65 | return value 66 | 67 | 68 | class Lambda(Module): 69 | """Convert a function to a module. 70 | 71 | Example: 72 | 73 | >>> net = pax.Lambda(jax.nn.relu) 74 | >>> print(net.summary()) 75 | x => relu(x) 76 | >>> y = net(jnp.array(-1)) 77 | >>> y 78 | DeviceArray(0, dtype=int32, weak_type=True) 79 | """ 80 | 81 | func: Callable 82 | 83 | def __init__(self, func: Callable, name: Optional[str] = None): 84 | super().__init__(name=name) 85 | self.func = func 86 | 87 | def __call__(self, *args, **kwargs): 88 | return self.func(*args, **kwargs) 89 | 90 | def __repr__(self) -> str: 91 | if self.name is not None: 92 | return super().__repr__() 93 | else: 94 | return f"{self.__class__.__qualname__}({self.func.__name__})" 95 | 96 | def summary(self, return_list: bool = False) -> Union[str, List[str]]: 97 | if self.name is not None: 98 | name = self.name 99 | elif isinstance(self.func, jax.custom_jvp) and hasattr(self.func, "fun"): 100 | if hasattr(self.func.fun, "__name__"): 101 | name = self.func.fun.__name__ 102 | else: 103 | name = f"{self.func.fun}" 104 | elif hasattr(self.func, "__name__"): 105 | name = self.func.__name__ 106 | else: 107 | name = f"{self.func}" 108 | output = f"x => {name}(x)" 109 | return [output] if return_list else output 110 | 111 | 112 | class Flattener(Module): 113 | """Flatten PAX modules for better performance. 114 | 115 | Example: 116 | 117 | >>> net = pax.Linear(3, 3) 118 | >>> opt = opax.adam(1e-3)(net.parameters()) 119 | >>> flat_mods = pax.experimental.Flattener(model=net, optimizer=opt) 120 | >>> net, opt = flat_mods.model, flat_mods.optimizer 121 | >>> print(net.summary()) 122 | Linear(in_dim=3, out_dim=3, with_bias=True) 123 | >>> print(opt.summary()) 124 | chain..Chain 125 | ├── scale_by_adam..ScaleByAdam 126 | │ ├── Linear(in_dim=3, out_dim=3, with_bias=True) 127 | │ └── Linear(in_dim=3, out_dim=3, with_bias=True) 128 | └── scale..Scale 129 | """ 130 | 131 | treedef_dict: Dict[str, Any] 132 | leaves_dict: Dict[str, Sequence[jnp.ndarray]] 133 | 134 | def __init__(self, **kwargs): 135 | """Create a new flattener.""" 136 | super().__init__() 137 | self.treedef_dict = {} 138 | self.leaves_dict = {} 139 | 140 | for name, value in kwargs.items(): 141 | leaves, treedef = jax.tree_util.tree_flatten(value) 142 | self.treedef_dict[name] = treedef 143 | self.leaves_dict[name] = leaves 144 | 145 | def __getattr__(self, name: str) -> Any: 146 | if name in self.treedef_dict: 147 | treedef = self.treedef_dict[name] 148 | leaves = self.leaves_dict[name] 149 | value = jax.tree_util.tree_unflatten(treedef, leaves) 150 | return value 151 | else: 152 | raise AttributeError() 153 | 154 | def update(self: T, **kwargs) -> T: 155 | """Update the flattener. 
156 | 157 | Example: 158 | 159 | >>> net = pax.Linear(3, 3) 160 | >>> flats = pax.experimental.Flattener(net=net) 161 | >>> flats = flats.update(net=pax.Linear(4, 4)) 162 | >>> print(flats.net.summary()) 163 | Linear(in_dim=4, out_dim=4, with_bias=True) 164 | """ 165 | new_self = self.copy() 166 | for name, value in kwargs.items(): 167 | leaves, treedef = jax.tree_util.tree_flatten(value) 168 | new_self.treedef_dict[name] = treedef 169 | new_self.leaves_dict[name] = leaves 170 | return new_self 171 | 172 | def parameters(self: T) -> T: 173 | """Raise an error. 174 | 175 | Need to reconstruct the original module before getting parameters. 176 | """ 177 | 178 | raise ValueError( 179 | "A flattener only stores ndarray leaves as non-trainable states.\n" 180 | "Reconstruct the original module before getting parameters." 181 | ) 182 | -------------------------------------------------------------------------------- /pax/_src/core/utils.py: -------------------------------------------------------------------------------- 1 | """Useful functions.""" 2 | 3 | from typing import TypeVar 4 | from unittest import TestCase 5 | 6 | import jax 7 | 8 | from .module import Module 9 | 10 | T = TypeVar("T", bound=Module) 11 | 12 | 13 | def assert_structure_equal(tree_a: T, tree_b: T): 14 | """Assert that the two pytrees are structurally the same. 15 | 16 | Print out the difference. 17 | """ 18 | if jax.tree_util.tree_structure(tree_a) == jax.tree_util.tree_structure(tree_b): 19 | return True 20 | 21 | def check(subtree_a, subtree_b): 22 | if isinstance(subtree_a, Module) and isinstance(subtree_b, Module): 23 | assert_structure_equal(subtree_a, subtree_b) 24 | 25 | has_error = False 26 | try: 27 | jax.tree_util.tree_map( 28 | check, 29 | tree_a, 30 | tree_b, 31 | is_leaf=lambda x: isinstance(x, Module) 32 | and x is not tree_a 33 | and x is not tree_b, 34 | ) 35 | except ValueError: 36 | has_error = True 37 | 38 | if has_error: 39 | test_case = TestCase() 40 | test_case.maxDiff = None 41 | # do not compare weights 42 | tree_a_w_none_leaves = jax.tree_util.tree_map(lambda _: None, tree_a) 43 | tree_b_w_none_leaves = jax.tree_util.tree_map(lambda _: None, tree_b) 44 | test_case.assertDictEqual( 45 | vars(tree_a_w_none_leaves), vars(tree_b_w_none_leaves) 46 | ) 47 | 48 | return has_error 49 | -------------------------------------------------------------------------------- /pax/_src/nets/__init__.py: -------------------------------------------------------------------------------- 1 | """Public nets.""" 2 | 3 | from .resnet import ( 4 | ResNet, 5 | ResNet18, 6 | ResNet34, 7 | ResNet50, 8 | ResNet101, 9 | ResNet152, 10 | ResNet200, 11 | ) 12 | from .transformer import Transformer 13 | -------------------------------------------------------------------------------- /pax/_src/nets/transformer.py: -------------------------------------------------------------------------------- 1 | """Transformer Decoder Stack.""" 2 | 3 | from typing import Dict, Optional, Sequence 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from ..core import Module 10 | from ..nn import LayerNorm, Linear, MultiHeadAttention, RngSeq 11 | from ..nn.dropout import dropout 12 | 13 | 14 | class CausalSelfAttention(MultiHeadAttention): 15 | """Self attention with a causal mask applied.""" 16 | 17 | def __call__( 18 | self, 19 | query: jnp.ndarray, 20 | key: Optional[jnp.ndarray] = None, 21 | value: Optional[jnp.ndarray] = None, 22 | mask: Optional[jnp.ndarray] = None, 23 | ) -> jnp.ndarray: 24 | key = key if key is not None 
else query 25 | value = value if value is not None else query 26 | 27 | seq_len = query.shape[1] 28 | causal_mask = np.tril(np.ones((seq_len, seq_len))) 29 | mask = mask * causal_mask if mask is not None else causal_mask 30 | 31 | return super().__call__(query, key, value, mask) 32 | 33 | 34 | class DenseBlock(Module): 35 | """A 2-layer MLP which widens then narrows the input.""" 36 | 37 | def __init__(self, in_dim: int, init_scale: float, widening_factor: int = 4): 38 | super().__init__() 39 | self._init_scale = init_scale 40 | initializer = jax.nn.initializers.variance_scaling( 41 | self._init_scale, mode="fan_in", distribution="normal" 42 | ) 43 | self._widening_factor = widening_factor 44 | self.fc1 = Linear(in_dim, in_dim * widening_factor, w_init=initializer) 45 | self.fc2 = Linear(in_dim * widening_factor, in_dim, w_init=initializer) 46 | 47 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 48 | x = self.fc1(x) 49 | x = jax.nn.gelu(x) 50 | return self.fc2(x) 51 | 52 | 53 | class Transformer(Module): 54 | """A transformer stack.""" 55 | 56 | layers: Sequence[Dict[str, Module]] 57 | 58 | def __init__(self, dim: int, num_heads: int, num_layers: int, dropout_rate: float): 59 | super().__init__() 60 | assert dim % num_heads == 0 61 | self._num_layers = num_layers 62 | self._num_heads = num_heads 63 | self._dropout_rate = dropout_rate 64 | 65 | self.rng_seq = RngSeq() 66 | 67 | init_scale = 2.0 / self._num_layers 68 | layers = [] 69 | for _ in range(num_layers): 70 | layers.append( 71 | { 72 | "attention": CausalSelfAttention( 73 | num_heads=self._num_heads, 74 | key_size=dim // num_heads, 75 | w_init_scale=init_scale, 76 | ), 77 | "attn_layer_norm": LayerNorm(dim, -1, True, True), 78 | "dense_layer_norm": LayerNorm(dim, -1, True, True), 79 | "dense_block": DenseBlock(dim, init_scale), 80 | } 81 | ) 82 | self.layers = layers 83 | self.layer_norm_output = LayerNorm(dim, -1, True, True) 84 | 85 | def __call__( 86 | self, h: jnp.ndarray, mask: Optional[jnp.ndarray] = None 87 | ) -> jnp.ndarray: 88 | """Connects the transformer. 89 | Args: 90 | h: Inputs, [B, T, H]. 91 | mask: Padding mask, [B, T]. 92 | is_training: Whether we're training or not. 93 | Returns: 94 | Array of shape [B, T, H]. 95 | """ 96 | 97 | dropout_rate = self._dropout_rate if self.training else 0.0 98 | if mask is not None: 99 | mask = mask[:, None, None, :] 100 | 101 | # Note: names chosen to approximately match those used in the GPT-2 code; 102 | # see https://github.com/openai/gpt-2/blob/master/src/model.py. 
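# (Added explanatory comments; not part of the original file.) Each iteration of
# the loop below is a standard pre-norm residual block:
#   h = h + dropout(attention(layer_norm(h)))
#   h = h + dropout(dense_block(layer_norm(h)))
# Two RNG keys are consumed per layer, one for each dropout call, which is why
# `num_layers * 2` keys are drawn up front.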
103 | rngs = self.rng_seq.next_rng_key(self._num_layers * 2) 104 | for i in range(self._num_layers): 105 | h_norm = self.layers[i]["attn_layer_norm"](h) 106 | h_attn = self.layers[i]["attention"](h_norm, mask=mask) 107 | h_attn = dropout(rngs[i * 2 + 0], dropout_rate, h_attn) 108 | h = h + h_attn 109 | h_norm = self.layers[i]["dense_layer_norm"](h) 110 | h_dense = self.layers[i]["dense_block"](h_norm) 111 | h_dense = dropout(rngs[i * 2 + 1], dropout_rate, h_dense) 112 | h = h + h_dense 113 | h = self.layer_norm_output(h) 114 | 115 | return h 116 | -------------------------------------------------------------------------------- /pax/_src/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Modules.""" 2 | 3 | from .attention import MultiHeadAttention 4 | from .batch_norm import BatchNorm1D, BatchNorm2D 5 | from .conv import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose 6 | from .dropout import Dropout 7 | from .ema import EMA 8 | from .embed import Embed 9 | from .group_norm import GroupNorm 10 | from .identity import Identity 11 | from .lambda_module import Lambda 12 | from .layer_norm import LayerNorm 13 | from .linear import Linear 14 | from .pool import avg_pool, max_pool 15 | from .recurrent import GRU, LSTM, GRUState, LSTMState, VanillaRNN, VanillaRNNState 16 | from .rng_seq import RngSeq 17 | from .sequential import Sequential 18 | -------------------------------------------------------------------------------- /pax/_src/nn/attention.py: -------------------------------------------------------------------------------- 1 | """Transformer self-attention module.""" 2 | 3 | # Source: https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py 4 | from typing import Optional 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | from ..core import Module 11 | from .linear import Linear 12 | 13 | 14 | class MultiHeadAttention(Module): 15 | """Multi-headed attention mechanism. 
16 | As described in the vanilla Transformer paper: 17 | "Attention is all you need" https://arxiv.org/abs/1706.03762 18 | """ 19 | 20 | num_heads: int 21 | key_size: int 22 | value_size: int 23 | model_size: int 24 | 25 | def __init__( 26 | self, 27 | num_heads: int, 28 | key_size: int, 29 | w_init_scale: float, 30 | ): 31 | super().__init__() 32 | self.num_heads = num_heads 33 | self.key_size = key_size 34 | self.value_size = key_size 35 | self.model_size = key_size * num_heads 36 | w_init = jax.nn.initializers.variance_scaling( 37 | w_init_scale, mode="fan_in", distribution="normal" 38 | ) 39 | self.query_projection = Linear( 40 | self.model_size, self.model_size, w_init=w_init, name="qry_proj" 41 | ) 42 | self.key_projection = Linear( 43 | self.model_size, self.model_size, w_init=w_init, name="key_proj" 44 | ) 45 | self.value_projection = Linear( 46 | self.model_size, self.model_size, w_init=w_init, name="val_proj" 47 | ) 48 | self.output_projection = Linear( 49 | self.model_size, self.model_size, w_init=w_init, name="out_proj" 50 | ) 51 | 52 | def __call__( 53 | self, 54 | query: jnp.ndarray, 55 | key: jnp.ndarray, 56 | value: jnp.ndarray, 57 | mask: Optional[jnp.ndarray] = None, 58 | ) -> jnp.ndarray: 59 | """Compute (optionally masked) MHA with queries, keys & values.""" 60 | 61 | query_heads = self.query_projection(query) 62 | key_heads = self.key_projection(key) 63 | value_heads = self.value_projection(value) 64 | (query_heads, key_heads, value_heads) = jax.tree_util.tree_map( 65 | lambda x, y: x.reshape(*y.shape[:-1], self.num_heads, self.key_size), 66 | (query_heads, key_heads, value_heads), 67 | (query, key, value), 68 | ) 69 | 70 | attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads) 71 | sqrt_key_size = np.sqrt(self.key_size).astype(key.dtype) 72 | attn_logits = attn_logits / sqrt_key_size 73 | if mask is not None: 74 | # assert mask.shape == attn_logits.shape 75 | attn_logits = jnp.where(mask, attn_logits, -1e30) 76 | 77 | attn_weights = jax.nn.softmax(attn_logits) 78 | attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads) 79 | # Concatenate attention matrix of all heads into a single vector. 80 | attn_vec = jnp.reshape(attn, (*query.shape[:-1], -1)) 81 | return self.output_projection(attn_vec) 82 | 83 | def __repr__(self, info=None) -> str: 84 | info = {"num_heads": self.num_heads, "key_size": self.key_size} 85 | return self._repr(info) 86 | -------------------------------------------------------------------------------- /pax/_src/nn/batch_norm.py: -------------------------------------------------------------------------------- 1 | """BatchNorm modules.""" 2 | 3 | from typing import Optional, Sequence 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from ..core import Module, parameters_method 9 | from .ema import EMA 10 | 11 | 12 | class BatchNorm(Module): 13 | """A Generic BatchNorm Module. 14 | 15 | Normalize a mini-batch of data by subtracting its mean and dividing by its standard deviation. 16 | 17 | Use EMA modules to track the averaged mean and averaged variance for later uses in `eval` mode. 
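Example (an added usage sketch, not in the original docstring; it assumes the
``BatchNorm2D`` subclass defined below is exported as ``pax.BatchNorm2D`` and
that the stateful call is routed through ``pax.module_and_value``):

>>> bn = pax.BatchNorm2D(3)                  # NHWC layout, 3 channels
>>> x = jnp.ones((2, 4, 4, 3))
>>> bn, y = pax.module_and_value(bn)(x)      # also updates the EMA statistics
>>> y.shape
(2, 4, 4, 3)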
18 | """ 19 | 20 | scale: Optional[jnp.ndarray] 21 | offset: Optional[jnp.ndarray] 22 | 23 | parameters = parameters_method("scale", "offset") 24 | 25 | ema_mean: EMA 26 | ema_var: EMA 27 | 28 | reduced_axes: Sequence[int] 29 | create_offset: bool 30 | create_scale: bool 31 | eps: float 32 | data_format: Optional[str] 33 | 34 | def __init__( 35 | self, 36 | num_channels: int, 37 | create_scale: bool = True, 38 | create_offset: bool = True, 39 | decay_rate: float = 0.9, 40 | eps: float = 1e-5, 41 | data_format: Optional[str] = None, 42 | reduced_axes=None, 43 | param_shape=None, 44 | *, 45 | name: Optional[str] = None, 46 | ): 47 | """Create a new BatchNorm module. 48 | 49 | Arguments: 50 | num_channels: the number of filters. 51 | create_scale: create a trainable `scale` parameter. 52 | create_offset: create a trainable `offset` parameter. 53 | decay_rate: the decay rate for tracking the averaged mean and the averaged variance. 54 | eps: a small positive number to avoid divided by zero. 55 | data_format: the data format ["NHWC", NCHW", "NWC", "NCW"]. 56 | reduced_axes: list of axes that will be reduced in the `jnp.mean` computation. 57 | param_shape: the shape of parameters. 58 | """ 59 | super().__init__(name=name) 60 | assert 0 <= decay_rate <= 1 61 | 62 | self.num_channels = num_channels 63 | self.data_format = data_format 64 | self.create_scale = create_scale 65 | self.create_offset = create_offset 66 | self.eps = eps 67 | self.decay_rate = decay_rate 68 | 69 | self.reduced_axes = tuple(reduced_axes) 70 | 71 | if create_scale: 72 | self.scale = jnp.ones(param_shape, dtype=jnp.float32) 73 | else: 74 | self.scale = None 75 | if create_offset: 76 | self.offset = jnp.zeros(param_shape, dtype=jnp.float32) 77 | else: 78 | self.offset = None 79 | 80 | # initial values do not matter because debias=True 81 | initial_mean = jnp.zeros(param_shape, dtype=jnp.float32) 82 | self.ema_mean = EMA(initial_mean, decay_rate, debias=True) 83 | initial_var = jnp.ones(param_shape, dtype=jnp.float32) 84 | self.ema_var = EMA(initial_var, decay_rate, debias=True) 85 | 86 | def __call__(self, x): 87 | if self.training: 88 | batch_mean = jnp.mean(x, axis=self.reduced_axes, keepdims=True) 89 | batch_mean_of_squares = jnp.mean( 90 | jnp.square(x), axis=self.reduced_axes, keepdims=True 91 | ) 92 | batch_var = batch_mean_of_squares - jnp.square(batch_mean) 93 | self.ema_mean(batch_mean) 94 | self.ema_var(batch_var) 95 | else: 96 | batch_mean = self.ema_mean.averages 97 | batch_var = self.ema_var.averages 98 | 99 | if self.create_scale: 100 | scale = self.scale 101 | else: 102 | scale = 1.0 103 | 104 | if self.create_offset: 105 | offset = self.offset 106 | else: 107 | offset = 0.0 108 | 109 | inv = scale * jax.lax.rsqrt(batch_var + self.eps) 110 | x = (x - batch_mean) * inv + offset 111 | return x 112 | 113 | def __repr__(self): 114 | info = { 115 | "num_channels": self.num_channels, 116 | "create_scale": self.create_scale, 117 | "create_offset": self.create_offset, 118 | "data_format": self.data_format, 119 | "decay_rate": self.decay_rate, 120 | } 121 | return self._repr(info) 122 | 123 | def summary(self, return_list: bool = False): 124 | lines = super().summary(return_list=True) 125 | if return_list: 126 | return lines[:1] 127 | else: 128 | return lines[0] 129 | 130 | 131 | class BatchNorm1D(BatchNorm): 132 | """The 1D version of BatchNorm.""" 133 | 134 | def __init__( 135 | self, 136 | num_channels: int, 137 | create_scale: bool = True, 138 | create_offset: bool = True, 139 | decay_rate: float = 0.9, 140 | eps: 
float = 1e-5, 141 | data_format: str = "NWC", 142 | *, 143 | name: Optional[str] = None, 144 | ): 145 | assert data_format in ["NWC", "NCW"], "expecting a correct `data_format`" 146 | 147 | param_shape = [1, 1, 1] 148 | if data_format == "NWC": 149 | axis = -1 150 | reduced_axes = [0, 1] 151 | else: 152 | axis = 1 153 | reduced_axes = [0, 2] 154 | param_shape[axis] = num_channels 155 | 156 | super().__init__( 157 | num_channels=num_channels, 158 | create_scale=create_scale, 159 | create_offset=create_offset, 160 | decay_rate=decay_rate, 161 | eps=eps, 162 | data_format=data_format, 163 | param_shape=param_shape, 164 | reduced_axes=reduced_axes, 165 | name=name, 166 | ) 167 | 168 | 169 | class BatchNorm2D(BatchNorm): 170 | """The 2D version of BatchNorm.""" 171 | 172 | def __init__( 173 | self, 174 | num_channels: int, 175 | create_scale: bool = True, 176 | create_offset: bool = True, 177 | decay_rate: float = 0.9, 178 | eps: float = 1e-5, 179 | data_format: str = "NHWC", 180 | *, 181 | name: Optional[str] = None, 182 | ): 183 | assert data_format in ["NHWC", "NCHW"], "expecting a correct `data_format`" 184 | 185 | param_shape = [1, 1, 1, 1] 186 | if data_format == "NHWC": 187 | axis = -1 188 | reduced_axes = [0, 1, 2] 189 | else: 190 | axis = 1 191 | reduced_axes = [0, 2, 3] 192 | param_shape[axis] = num_channels 193 | 194 | super().__init__( 195 | num_channels=num_channels, 196 | create_scale=create_scale, 197 | create_offset=create_offset, 198 | decay_rate=decay_rate, 199 | eps=eps, 200 | data_format=data_format, 201 | param_shape=param_shape, 202 | reduced_axes=reduced_axes, 203 | name=name, 204 | ) 205 | -------------------------------------------------------------------------------- /pax/_src/nn/dropout.py: -------------------------------------------------------------------------------- 1 | """Dropout module.""" 2 | 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from ..core import StateModule 9 | from ..core.rng import KeyArray, next_rng_key 10 | 11 | 12 | def dropout(rng_key: KeyArray, dropout_rate: float, x: jnp.ndarray) -> jnp.ndarray: 13 | """Dropout input `x` randomly. 14 | 15 | Scaling the input by ``1 / (1-dropout_rate)`` makes ``E[output] = input``. 16 | """ 17 | assert 0 <= dropout_rate < 1.0 18 | 19 | if dropout_rate == 0.0: 20 | return x 21 | else: 22 | mask = jax.random.bernoulli(rng_key, dropout_rate, shape=x.shape) 23 | x = jnp.where(mask, 0.0, x / (1.0 - dropout_rate)) 24 | return x 25 | 26 | 27 | class Dropout(StateModule): 28 | """A Dropout Module. 29 | 30 | Dropout module stores an internal state ``rng_key``. 31 | It refreshes ``rng_key`` whenever a forward pass is executed. 32 | """ 33 | 34 | rng_key: KeyArray 35 | dropout_rate: float 36 | 37 | def __init__(self, dropout_rate: float, *, name: Optional[str] = None): 38 | """Create a dropout module. 39 | 40 | Arguments: 41 | dropout_rate: the probability of dropping an element. 42 | name: the module name. 43 | """ 44 | super().__init__(name=name) 45 | assert 0 <= dropout_rate < 1.0 46 | 47 | self.dropout_rate = dropout_rate 48 | self.rng_key = next_rng_key() 49 | 50 | def __call__(self, x): 51 | """Dropout `x` randomly. 52 | 53 | Return the input `x` if in `eval` mode or `dropout_rate=0`. 
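Example (an added sketch, not in the original docstring; it assumes ``Dropout``
is exported as ``pax.Dropout`` and that the state update goes through
``pax.module_and_value``):

>>> pax.seed_rng_key(42)
>>> drop = pax.Dropout(0.5)
>>> drop, y = pax.module_and_value(drop)(jnp.ones((2, 4)))  # refreshes rng_key
>>> y.shape
(2, 4)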
54 | """ 55 | 56 | if self.training and self.dropout_rate > 0: 57 | self.rng_key, rng_key = jax.random.split(self.rng_key) 58 | return dropout(rng_key, self.dropout_rate, x) 59 | else: 60 | return x 61 | 62 | def __repr__(self): 63 | return super()._repr({"dropout_rate": self.dropout_rate}) 64 | -------------------------------------------------------------------------------- /pax/_src/nn/ema.py: -------------------------------------------------------------------------------- 1 | """EMA module.""" 2 | 3 | from typing import Any, Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from ..core import StateModule 9 | 10 | 11 | def _has_integer_leaves(x): 12 | """Check if there are any integer/bool leaves.""" 13 | leaves = jax.tree_util.tree_leaves(x) 14 | return not all(jnp.issubdtype(leaf, jnp.floating) for leaf in leaves) 15 | 16 | 17 | class EMA(StateModule): 18 | """Exponential Moving Average (EMA) Module""" 19 | 20 | averages: Any 21 | decay_rate: float 22 | debias: Optional[jnp.ndarray] 23 | allow_int: bool 24 | 25 | def __init__( 26 | self, 27 | initial_value, 28 | decay_rate: float, 29 | debias: bool = False, 30 | allow_int: bool = False, 31 | ): 32 | """Create a new EMA module. 33 | 34 | If allow_int=True, integer leaves are updated to 35 | the newest values instead of averaging. 36 | 37 | Arguments: 38 | initial_value: the initial value. 39 | decay_rate: the decay rate. 40 | debias: ignore the initial value to avoid biased estimates. 41 | allow_int: allow integer values. 42 | """ 43 | if not allow_int: 44 | if _has_integer_leaves(initial_value): 45 | raise ValueError( 46 | "There are integer arrays in the initial value.\n" 47 | "Use `allow_int=True` to allow this." 48 | ) 49 | 50 | super().__init__() 51 | self.averages = initial_value 52 | self.decay_rate = decay_rate 53 | self.allow_int = allow_int 54 | if debias: 55 | # avoid integer ndarray for `jax.grad` convenience, 56 | # e.g., no need to pass `allow_int=True` to `jax.grad`. 57 | self.debias = jnp.array(0.0) 58 | else: 59 | self.debias = None 60 | 61 | def __call__(self, xs): 62 | """Return the ema of `xs`. Also, update internal states.""" 63 | if not self.allow_int: 64 | if _has_integer_leaves(xs): 65 | raise ValueError( 66 | "There are integer arrays in the new value.\n" 67 | "Use `allow_int=True` to allow this." 68 | ) 69 | 70 | if self.training: 71 | if self.debias is not None: 72 | cond = self.debias > 0 73 | debias_func = lambda a, x: jnp.where(cond, a, x) 74 | self.debias = jnp.array(1.0) 75 | else: 76 | debias_func = lambda a, _: a 77 | 78 | def update_fn(a, x): 79 | if jnp.issubdtype(a, jnp.floating): 80 | a = debias_func(a, x) 81 | return a * self.decay_rate + x * (1 - self.decay_rate) 82 | else: 83 | return x 84 | 85 | self.averages = jax.tree_util.tree_map(update_fn, self.averages, xs) 86 | 87 | return self.averages 88 | -------------------------------------------------------------------------------- /pax/_src/nn/embed.py: -------------------------------------------------------------------------------- 1 | """Embed module.""" 2 | 3 | from typing import Callable, Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from ..core import ParameterModule 9 | from ..core.rng import KeyArray, next_rng_key 10 | 11 | 12 | class Embed(ParameterModule): 13 | """Embed module maps integer values to real vectors. 14 | The embedded vectors are trainable.
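Example (an added sketch, not in the original docstring; it assumes ``Embed``
is exported as ``pax.Embed``):

>>> pax.seed_rng_key(42)
>>> embed = pax.Embed(vocab_size=10, embed_dim=8)
>>> tokens = jnp.array([[1, 2, 3]])
>>> embed(tokens).shape                      # one 8-dim vector per token id
(1, 3, 8)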
15 | """ 16 | 17 | weight: jnp.ndarray 18 | vocab_size: int 19 | embed_dim: int 20 | 21 | def __init__( 22 | self, 23 | vocab_size: int, 24 | embed_dim: int, 25 | w_init: Optional[Callable] = None, 26 | *, 27 | rng_key: Optional[KeyArray] = None, 28 | name: Optional[str] = None 29 | ): 30 | """ 31 | An embed module. 32 | 33 | Arguments: 34 | vocab_size: the number of embedded vectors. 35 | embed_dim: the size of embedded vectors. 36 | w_init: weight initializer. Default: `truncated_normal`. 37 | name: module name. 38 | """ 39 | 40 | super().__init__(name=name) 41 | 42 | self.vocab_size = vocab_size 43 | self.embed_dim = embed_dim 44 | shape = [vocab_size, embed_dim] 45 | 46 | if w_init is None: 47 | w_init = jax.nn.initializers.normal() 48 | 49 | if rng_key is None: 50 | rng_key = next_rng_key() 51 | 52 | self.weight = w_init(rng_key, shape) 53 | 54 | def __call__(self, x: jnp.ndarray): 55 | """Return embedded vectors indexed by ``x``.""" 56 | return self.weight[(x,)] 57 | 58 | def __repr__(self): 59 | info = {"vocab_size": self.vocab_size, "embed_dim": self.embed_dim} 60 | return self._repr(info) 61 | -------------------------------------------------------------------------------- /pax/_src/nn/identity.py: -------------------------------------------------------------------------------- 1 | """Identity module.""" 2 | 3 | from ..core import Module 4 | 5 | 6 | class Identity(Module): 7 | """Identity function as a module.""" 8 | 9 | def __call__(self, x): 10 | """return x""" 11 | return x 12 | -------------------------------------------------------------------------------- /pax/_src/nn/lambda_module.py: -------------------------------------------------------------------------------- 1 | """Lambda module.""" 2 | 3 | from ..core.utility_modules import Lambda 4 | 5 | __all__ = ("Lambda",) 6 | -------------------------------------------------------------------------------- /pax/_src/nn/layer_norm.py: -------------------------------------------------------------------------------- 1 | """LayerNorm Module.""" 2 | 3 | # The implementation is almost identical to dm-haiku LayerNorm at: 4 | # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/layer_norm.py 5 | # deepmind/dm-haiku is licensed under the Apache License 2.0 6 | # 7 | # Differences: 8 | # 1. We need to input ``num_channels``, the size of the last dimension, 9 | # to initialize scale/offset parameters. 10 | # 2. We can input `rng_key` to seed the value of scale/offset parameters. 11 | 12 | import collections 13 | from typing import Callable, Optional, Sequence, Union 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import numpy as np 18 | 19 | from ..core import ParameterModule 20 | from ..core.rng import KeyArray, next_rng_key 21 | 22 | 23 | class LayerNorm(ParameterModule): 24 | """LayerNorm module. 25 | See: https://arxiv.org/abs/1607.06450. 26 | """ 27 | 28 | scale: Optional[jnp.ndarray] 29 | offset: Optional[jnp.ndarray] 30 | 31 | def __init__( 32 | self, 33 | num_channels: int, 34 | axis: Union[int, Sequence[int], slice], 35 | create_scale: bool, 36 | create_offset: bool, 37 | eps: float = 1e-5, 38 | scale_init: Optional[Callable] = None, 39 | offset_init: Optional[Callable] = None, 40 | *, 41 | rng_key: Optional[KeyArray] = None, 42 | name: Optional[str] = None, 43 | ): 44 | jax.nn.initializers 45 | """Constructs a LayerNorm module. 46 | 47 | Arguments: 48 | num_channels: Integer, size of the last dimension. The data format is ``[N, ..., C]``. 
49 | axis: Integer, list of integers, or slice indicating which axes to normalize over. 50 | create_scale: Bool, defines whether to create a trainable scale 51 | per channel applied after the normalization. 52 | create_offset: Bool, defines whether to create a trainable offset 53 | per channel applied after normalization and scaling. 54 | eps: Small epsilon to avoid division by zero variance. 55 | Defaults ``1e-5``, as in the paper and Sonnet. 56 | scale_init: Optional initializer for gain (aka scale). By default, one. 57 | offset_init: Optional initializer for bias (aka offset). By default, zero. 58 | rng_key: RNG key. 59 | name: module name. 60 | """ 61 | super().__init__(name=name) 62 | if not create_scale and scale_init is not None: 63 | raise ValueError("Cannot set `scale_init` if `create_scale=False`.") 64 | if not create_offset and offset_init is not None: 65 | raise ValueError("Cannot set `offset_init` if `create_offset=False`.") 66 | 67 | if isinstance(axis, slice): 68 | self.axis = axis 69 | elif isinstance(axis, int): 70 | self.axis = (axis,) 71 | elif isinstance(axis, collections.abc.Iterable) and all( 72 | isinstance(ax, int) for ax in axis 73 | ): 74 | self.axis = tuple(axis) 75 | else: 76 | raise ValueError("`axis` should be an int, slice or iterable of ints.") 77 | 78 | self.eps = eps 79 | self.create_scale = create_scale 80 | self.create_offset = create_offset 81 | self.scale_init = scale_init or jax.nn.initializers.ones 82 | self.offset_init = offset_init or jax.nn.initializers.zeros 83 | self.num_channels = num_channels 84 | 85 | param_shape = [num_channels] 86 | rng_key = next_rng_key() if rng_key is None else rng_key 87 | rng1, rng2 = jax.random.split(rng_key) 88 | if create_scale: 89 | self.scale = self.scale_init(rng1, param_shape) 90 | else: 91 | self.scale = None 92 | if create_offset: 93 | self.offset = self.offset_init(rng2, param_shape) 94 | else: 95 | self.offset = None 96 | 97 | def __call__( 98 | self, 99 | inputs: jnp.ndarray, 100 | scale: Optional[jnp.ndarray] = None, 101 | offset: Optional[jnp.ndarray] = None, 102 | ) -> jnp.ndarray: 103 | """Returns normalized inputs. 104 | 105 | Arguments: 106 | inputs: An array, where the data format is ``[N, ..., C]``. 107 | scale: An array up to n-D. The shape of this tensor must be broadcastable 108 | to the shape of ``inputs``. This is the scale applied to the normalized 109 | inputs. This cannot be passed in if the module was constructed with 110 | ``create_scale=True``. 111 | offset: An array up to n-D. The shape of this tensor must be broadcastable 112 | to the shape of ``inputs``. This is the offset applied to the normalized 113 | inputs. This cannot be passed in if the module was constructed with 114 | ``create_offset=True``. 115 | 116 | Returns: 117 | The array, normalized. 118 | """ 119 | if self.create_scale and scale is not None: 120 | raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.") 121 | if self.create_offset and offset is not None: 122 | raise ValueError( 123 | "Cannot pass `offset` at call time if `create_offset=True`." 
124 | ) 125 | 126 | axis = self.axis 127 | if isinstance(axis, slice): 128 | axis = tuple(range(inputs.ndim)[axis]) 129 | 130 | mean = jnp.mean(inputs, axis=axis, keepdims=True) 131 | variance = jnp.var(inputs, axis=axis, keepdims=True) 132 | 133 | # param_shape = inputs.shape[-1:] 134 | if self.create_scale: 135 | scale = self.scale 136 | elif scale is None: 137 | scale = np.array(1.0, dtype=inputs.dtype) 138 | 139 | if self.create_offset: 140 | offset = self.offset 141 | elif offset is None: 142 | offset = np.array(0.0, dtype=inputs.dtype) 143 | 144 | scale = jnp.broadcast_to(scale, inputs.shape) 145 | offset = jnp.broadcast_to(offset, inputs.shape) 146 | mean = jnp.broadcast_to(mean, inputs.shape) 147 | 148 | eps = jax.lax.convert_element_type(self.eps, variance.dtype) 149 | inv = scale * jax.lax.rsqrt(variance + eps) 150 | return inv * (inputs - mean) + offset 151 | 152 | def __repr__(self, info=None) -> str: 153 | info = { 154 | "num_channels": self.num_channels, 155 | "axis": self.axis, 156 | "create_scale": self.create_scale, 157 | "create_offset": self.create_offset, 158 | } 159 | return self._repr(info) 160 | -------------------------------------------------------------------------------- /pax/_src/nn/linear.py: -------------------------------------------------------------------------------- 1 | """Linear module.""" 2 | 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from ..core import ParameterModule 10 | from ..core.rng import KeyArray, next_rng_key 11 | 12 | 13 | class Linear(ParameterModule): 14 | """A linear transformation is applied over the last dimension of the input.""" 15 | 16 | weight: jnp.ndarray 17 | bias: jnp.ndarray 18 | 19 | in_dim: int 20 | out_dim: int 21 | with_bias: bool 22 | 23 | def __init__( 24 | self, 25 | in_dim: int, 26 | out_dim: int, 27 | with_bias: bool = True, 28 | w_init=None, 29 | b_init=None, 30 | *, 31 | rng_key: KeyArray = None, 32 | name: Optional[str] = None, 33 | ): 34 | """ 35 | Arguments: 36 | in_dim: the number of input features. 37 | out_dim: the number of output features. 38 | with_bias: whether to add a bias to the output (default: True). 39 | w_init: initializer function for the weight matrix. 40 | b_init: initializer function for the bias. 41 | rng_key: the key to generate initial parameters. 42 | name: module name. 43 | """ 44 | super().__init__(name=name) 45 | self.in_dim = in_dim 46 | self.out_dim = out_dim 47 | self.with_bias = with_bias 48 | 49 | rng_key = next_rng_key() if rng_key is None else rng_key 50 | if w_init is None: 51 | w_init = jax.nn.initializers.normal(stddev=1.0 / np.sqrt(self.in_dim)) 52 | if b_init is None: 53 | b_init = jax.nn.initializers.normal(stddev=1.0 / np.sqrt(self.in_dim)) 54 | rng_key_w, rng_key_b = jax.random.split(rng_key) 55 | self.weight = w_init(rng_key_w, (in_dim, out_dim)) 56 | if self.with_bias: 57 | self.bias = b_init(rng_key_b, (out_dim,)) 58 | 59 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 60 | """Applies a linear transformation to the inputs along the last dimension. 61 | 62 | Arguments: 63 | x: The nd-array to be transformed. 
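Example (an added sketch of the transformation, not in the original docstring):

>>> pax.seed_rng_key(42)
>>> fc = pax.Linear(3, 5)
>>> fc(jnp.ones((2, 3))).shape               # last dimension: 3 -> 5
(2, 5)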
64 | """ 65 | assert len(x.shape) >= 2, "expecting an input of shape `N...C`" 66 | x = jnp.dot(x, self.weight) 67 | if self.with_bias: 68 | x = x + self.bias 69 | return x 70 | 71 | def __repr__(self): 72 | info = { 73 | "in_dim": self.in_dim, 74 | "out_dim": self.out_dim, 75 | "with_bias": self.with_bias, 76 | } 77 | return self._repr(info) 78 | -------------------------------------------------------------------------------- /pax/_src/nn/pool.py: -------------------------------------------------------------------------------- 1 | # Source: https://raw.githubusercontent.com/deepmind/dm-haiku/main/haiku/_src/pool.py 2 | # 3 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """Pooling Haiku modules.""" 18 | 19 | import warnings 20 | from typing import Optional, Sequence, Tuple, Union 21 | 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from jax import lax 25 | 26 | 27 | def _infer_shape( 28 | x: jnp.ndarray, 29 | size: Union[int, Sequence[int]], 30 | channel_axis: Optional[int] = -1, 31 | ) -> Tuple[int, ...]: 32 | """Infer shape for pooling window or strides.""" 33 | if isinstance(size, int): 34 | if channel_axis and not 0 <= abs(channel_axis) < x.ndim: 35 | raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}") 36 | if channel_axis and channel_axis < 0: 37 | channel_axis = x.ndim + channel_axis 38 | return (1,) + tuple(size if d != channel_axis else 1 for d in range(1, x.ndim)) 39 | elif len(size) < x.ndim: 40 | # Assume additional dimensions are batch dimensions. 41 | return (1,) * (x.ndim - len(size)) + tuple(size) 42 | else: 43 | assert x.ndim == len(size) 44 | return tuple(size) 45 | 46 | 47 | _VMAP_SHAPE_INFERENCE_WARNING = ( 48 | "When running under vmap, passing an `int` (except for `1`) for " 49 | "`window_shape` or `strides` will result in the wrong shape being inferred " 50 | "because the batch dimension is not visible to Haiku. Please update your " 51 | "code to specify a full unbatched size. " 52 | "" 53 | "For example if you had `pool(x, window_shape=3, strides=1)` before, you " 54 | "should now pass `pool(x, window_shape=(3, 3, 1), strides=1)`. " 55 | "" 56 | "Haiku will assume that any additional dimensions in your input are " 57 | "batch dimensions, and will pad `window_shape` and `strides` accordingly " 58 | "making your module support both batched and per-example inputs." 59 | ) 60 | 61 | 62 | def _warn_if_unsafe(window_shape, strides): 63 | unsafe = lambda size: isinstance(size, int) and size != 1 64 | if unsafe(window_shape) or unsafe(strides): 65 | warnings.warn(_VMAP_SHAPE_INFERENCE_WARNING, DeprecationWarning) 66 | 67 | 68 | def max_pool( 69 | value: jnp.ndarray, 70 | window_shape: Union[int, Sequence[int]], 71 | strides: Union[int, Sequence[int]], 72 | padding: str, 73 | channel_axis: Optional[int] = -1, 74 | ) -> jnp.ndarray: 75 | """Max pool. 
76 | 77 | Args: 78 | value: Value to pool. 79 | window_shape: Shape of the pooling window, an int or same rank as value. 80 | strides: Strides of the pooling window, an int or same rank as value. 81 | padding: Padding algorithm. Either ``VALID`` or ``SAME``. 82 | channel_axis: Axis of the spatial channels for which pooling is skipped, 83 | used to infer ``window_shape`` or ``strides`` if they are an integer. 84 | 85 | Returns: 86 | Pooled result. Same rank as value. 87 | """ 88 | if padding not in ("SAME", "VALID"): 89 | raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") 90 | 91 | _warn_if_unsafe(window_shape, strides) 92 | window_shape = _infer_shape(value, window_shape, channel_axis) 93 | strides = _infer_shape(value, strides, channel_axis) 94 | 95 | return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides, padding) 96 | 97 | 98 | def avg_pool( 99 | value: jnp.ndarray, 100 | window_shape: Union[int, Sequence[int]], 101 | strides: Union[int, Sequence[int]], 102 | padding: str, 103 | channel_axis: Optional[int] = -1, 104 | ) -> jnp.ndarray: 105 | """Average pool. 106 | 107 | Args: 108 | value: Value to pool. 109 | window_shape: Shape of the pooling window, an int or same rank as value. 110 | strides: Strides of the pooling window, an int or same rank as value. 111 | padding: Padding algorithm. Either ``VALID`` or ``SAME``. 112 | channel_axis: Axis of the spatial channels for which pooling is skipped, 113 | used to infer ``window_shape`` or ``strides`` if they are an integer. 114 | 115 | Returns: 116 | Pooled result. Same rank as value. 117 | 118 | Raises: 119 | ValueError: If the padding is not valid. 120 | """ 121 | if padding not in ("SAME", "VALID"): 122 | raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") 123 | 124 | _warn_if_unsafe(window_shape, strides) 125 | window_shape = _infer_shape(value, window_shape, channel_axis) 126 | strides = _infer_shape(value, strides, channel_axis) 127 | 128 | reduce_window_args = (0.0, lax.add, window_shape, strides, padding) 129 | pooled = lax.reduce_window(value, *reduce_window_args) 130 | if padding == "VALID": 131 | # Avoid the extra reduce_window. 132 | return pooled / np.prod(window_shape) 133 | else: 134 | # Count the number of valid entries at each input point, then use that for 135 | # computing average. Assumes that any two arrays of same shape will be 136 | # padded the same. 
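# (Added note; not part of the original file.) With "SAME" padding the divisor
# varies per position: e.g. a length-3 window centred on a border element only
# covers 2 real inputs, so dividing by the per-position count (2) instead of the
# full window size (3) keeps the border averages unbiased.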
137 | window_counts = lax.reduce_window(jnp.ones_like(value), *reduce_window_args) 138 | assert pooled.shape == window_counts.shape 139 | return pooled / window_counts 140 | -------------------------------------------------------------------------------- /pax/_src/nn/recurrent.py: -------------------------------------------------------------------------------- 1 | """Recurrent modules.""" 2 | 3 | from typing import Callable, NamedTuple, Optional, Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from ..core import Module 9 | from ..core.rng import KeyArray, next_rng_key 10 | from .linear import Linear 11 | 12 | 13 | class LSTMState(NamedTuple): 14 | """LSTMState.""" 15 | 16 | hidden: jnp.ndarray 17 | cell: jnp.ndarray 18 | 19 | 20 | class GRUState(NamedTuple): 21 | """GRUState.""" 22 | 23 | hidden: jnp.ndarray 24 | 25 | 26 | class VanillaRNNState(NamedTuple): 27 | """VanillaRNNState.""" 28 | 29 | hidden: jnp.ndarray 30 | 31 | 32 | class RNN(Module): 33 | """Base class for all recurrent modules.""" 34 | 35 | def __init__(self, name: Optional[str] = None): 36 | super().__init__(name=name) 37 | 38 | def initial_state(self, batch_size): 39 | raise NotImplementedError() 40 | 41 | 42 | class VanillaRNN(RNN): 43 | """Basic recurrent neural network.""" 44 | 45 | input_dim: int 46 | hidden_dim: int 47 | fc: Linear 48 | 49 | def __init__( 50 | self, 51 | input_dim: int, 52 | hidden_dim: int, 53 | *, 54 | rng_key: KeyArray = None, 55 | name: Optional[str] = None 56 | ): 57 | """Create a vanilla RNN module. 58 | 59 | Arguments: 60 | input_dim: input dimension. 61 | hidden_dim: hidden dimension. 62 | rng_key: random key. 63 | name: module name. 64 | """ 65 | super().__init__(name=name) 66 | self.input_dim = input_dim 67 | self.hidden_dim = hidden_dim 68 | self.fc = Linear( 69 | input_dim + hidden_dim, 70 | hidden_dim, 71 | rng_key=rng_key, 72 | name="vanilla_rnn_fc", 73 | ) 74 | 75 | def __call__( 76 | self, state: VanillaRNNState, x: jnp.ndarray 77 | ) -> Tuple[VanillaRNNState, jnp.ndarray]: 78 | """A single rnn step.""" 79 | xh = jnp.concatenate((x, state.hidden), axis=-1) 80 | hidden = jnp.tanh(self.fc(xh)) 81 | return VanillaRNNState(hidden), hidden 82 | 83 | def __repr__(self): 84 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim} 85 | return self._repr(info) 86 | 87 | def initial_state(self, batch_size) -> VanillaRNNState: 88 | shape = (batch_size, self.hidden_dim) 89 | hidden = jnp.zeros(shape=shape, dtype=jnp.float32) 90 | return VanillaRNNState(hidden=hidden) 91 | 92 | 93 | class LSTM(RNN): 94 | """Long Short Term Memory (LSTM) RNN module.""" 95 | 96 | input_dim: int 97 | hidden_dim: int 98 | 99 | weight: jnp.ndarray 100 | bias: jnp.ndarray 101 | 102 | def __init__( 103 | self, 104 | input_dim: int, 105 | hidden_dim: int, 106 | w_init: Optional[Callable] = None, 107 | forget_gate_bias: float = 0.0, 108 | *, 109 | rng_key: KeyArray = None, 110 | name: Optional[str] = None 111 | ): 112 | """Create a LSTM module. 113 | 114 | Arguments: 115 | input_dim: The input dimension. 116 | hidden_dim: The number of LSTM cells. 117 | w_init: weight initializer. 118 | forget_gate_bias: Prefer forget. Default `0`. 119 | rng_key: random key. 120 | name: module name. 
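Example (an added sketch, not in the original docstring; it assumes ``LSTM``
is exported as ``pax.LSTM``):

>>> pax.seed_rng_key(42)
>>> lstm = pax.LSTM(input_dim=3, hidden_dim=5)
>>> state = lstm.initial_state(batch_size=2)
>>> state, h = lstm(state, jnp.ones((2, 3)))     # one step
>>> h.shape
(2, 5)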
121 | """ 122 | 123 | super().__init__(name=name) 124 | self.input_dim = input_dim 125 | self.hidden_dim = hidden_dim 126 | self.forget_gate_bias = forget_gate_bias 127 | 128 | self.fc = Linear( 129 | (input_dim + hidden_dim), 130 | 4 * hidden_dim, 131 | rng_key=rng_key, 132 | name="lstm_fc", 133 | w_init=w_init, 134 | ) 135 | 136 | def __call__( 137 | self, 138 | state: LSTMState, 139 | x: jnp.ndarray, 140 | ) -> Tuple[LSTMState, jnp.ndarray]: 141 | """Do a single lstm step. 142 | 143 | 144 | Arguments: 145 | state: The current LSTM state. 146 | x: The input. 147 | """ 148 | xh = jnp.concatenate((x, state.hidden), axis=-1) 149 | gated = self.fc(xh) 150 | i, g, f, o = jnp.split(gated, 4, axis=-1) 151 | f = jax.nn.sigmoid(f + self.forget_gate_bias) 152 | c = f * state.cell + jax.nn.sigmoid(i) * jnp.tanh(g) 153 | h = jax.nn.sigmoid(o) * jnp.tanh(c) 154 | return LSTMState(h, c), h 155 | 156 | def __repr__(self): 157 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim} 158 | return self._repr(info) 159 | 160 | def initial_state(self, batch_size) -> LSTMState: 161 | shape = (batch_size, self.hidden_dim) 162 | hidden = jnp.zeros(shape=shape, dtype=jnp.float32) 163 | cell = jnp.zeros(shape=shape, dtype=jnp.float32) 164 | return LSTMState(hidden=hidden, cell=cell) 165 | 166 | 167 | class GRU(RNN): 168 | """This class implements the "fully gated unit" GRU. 169 | 170 | Reference: https://en.wikipedia.org/wiki/Gated_recurrent_unit 171 | """ 172 | 173 | input_dim: int 174 | hidden_dim: int 175 | 176 | def __init__( 177 | self, 178 | input_dim: int, 179 | hidden_dim: int, 180 | *, 181 | rng_key: Optional[KeyArray] = None, 182 | name: Optional[str] = None 183 | ): 184 | """Create a GRU module. 185 | 186 | Arguments: 187 | input_dim: the input size. 188 | hidden_dim: the number of GRU cells. 189 | """ 190 | super().__init__(name=name) 191 | 192 | self.input_dim = input_dim 193 | self.hidden_dim = hidden_dim 194 | 195 | if rng_key is None: 196 | rng_key = next_rng_key() 197 | rng_key_1, rng_key_2 = jax.random.split(rng_key, 2) 198 | self.xh_zr_fc = Linear( 199 | (input_dim + hidden_dim), hidden_dim * 2, name="xh_to_zr", rng_key=rng_key_1 200 | ) 201 | 202 | self.xh_h_fc = Linear( 203 | (input_dim + hidden_dim), hidden_dim, name="xh_to_h", rng_key=rng_key_2 204 | ) 205 | 206 | def initial_state(self, batch_size: int) -> GRUState: 207 | """Create an all zeros initial state.""" 208 | return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32)) 209 | 210 | def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]: 211 | """Do a single gru step. 212 | 213 | Arguments: 214 | state: The current GRU state. 215 | x: The input. 
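Example (an added sketch, not in the original docstring; it assumes ``GRU``
is exported as ``pax.GRU``; unrolling over time uses ``jax.lax.scan``):

>>> pax.seed_rng_key(42)
>>> gru = pax.GRU(input_dim=3, hidden_dim=4)
>>> state = gru.initial_state(batch_size=2)
>>> xs = jnp.ones((7, 2, 3))                     # [time, batch, features]
>>> state, hs = jax.lax.scan(gru, state, xs)     # one step per time slice
>>> hs.shape
(7, 2, 4)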
216 | """ 217 | hidden = state.hidden 218 | xh = jnp.concatenate((x, hidden), axis=-1) 219 | zr = jax.nn.sigmoid(self.xh_zr_fc(xh)) 220 | z, r = jnp.split(zr, 2, axis=-1) 221 | 222 | xrh = jnp.concatenate((x, r * hidden), axis=-1) 223 | h_hat = jnp.tanh(self.xh_h_fc(xrh)) 224 | h = (1 - z) * hidden + z * h_hat 225 | return GRUState(h), h 226 | 227 | def __repr__(self): 228 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim} 229 | return self._repr(info) 230 | -------------------------------------------------------------------------------- /pax/_src/nn/rng_seq.py: -------------------------------------------------------------------------------- 1 | """RngSeq module.""" 2 | 3 | from typing import Optional, Sequence, Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from ..core import StateModule, rng 10 | 11 | 12 | class RngSeq(StateModule): 13 | """A module which generates an infinite sequence of rng keys.""" 14 | 15 | _rng_key: rng.KeyArray 16 | 17 | def __init__( 18 | self, seed: Optional[int] = None, rng_key: Optional[rng.KeyArray] = None 19 | ): 20 | """Initialize a random key sequence. 21 | 22 | **Note**: ``rng_key`` has a higher priority than ``seed``. 23 | 24 | Arguments: 25 | seed: an integer seed. 26 | rng_key: a jax random key. 27 | """ 28 | super().__init__() 29 | if rng_key is not None: 30 | rng_key_ = rng_key 31 | elif seed is not None: 32 | rng_key_ = jax.random.PRNGKey(seed) 33 | else: 34 | rng_key_ = rng.next_rng_key() 35 | 36 | if isinstance(rng_key_, (np.ndarray, jnp.ndarray)): 37 | self._rng_key = rng_key_ 38 | else: 39 | raise ValueError("Impossible") 40 | 41 | def next_rng_key( 42 | self, num_keys: int = 1 43 | ) -> Union[rng.KeyArray, Sequence[rng.KeyArray]]: 44 | """Return the next random key of the sequence. 45 | 46 | **Note**: 47 | 48 | * Return a key if ``num_keys`` is ``1``, 49 | * Return a list of keys if ``num_keys`` is greater than ``1``. 50 | * This is not a deterministic sequence if values of ``num_keys`` are mixed randomly. 51 | 52 | Arguments: 53 | num_keys: return more than one key. 54 | """ 55 | self._rng_key, *rng_keys = jax.random.split(self._rng_key, num_keys + 1) 56 | return rng_keys[0] if num_keys == 1 else rng_keys 57 | -------------------------------------------------------------------------------- /pax/_src/nn/sequential.py: -------------------------------------------------------------------------------- 1 | """Sequential module.""" 2 | 3 | from typing import Optional, Tuple, TypeVar 4 | 5 | from ..core import Module 6 | from .lambda_module import Lambda 7 | 8 | T = TypeVar("T", bound=Module) 9 | 10 | 11 | class Sequential(Module): 12 | """Execute layers in order. 13 | 14 | Support pax.Module (callable pytree) and any jax functions. 15 | 16 | For example: 17 | 18 | >>> net = pax.Sequential( 19 | ... pax.Linear(2, 32), 20 | ... jax.nn.relu, 21 | ... pax.Linear(32, 3) 22 | ... ) 23 | >>> print(net.summary()) 24 | Sequential 25 | ├── Linear(in_dim=2, out_dim=32, with_bias=True) 26 | ├── x => relu(x) 27 | └── Linear(in_dim=32, out_dim=3, with_bias=True) 28 | >>> x = jnp.empty((3, 2)) 29 | >>> y = net(x) 30 | >>> y.shape 31 | (3, 3) 32 | """ 33 | 34 | # Note: we cannot mix pax.Module and jax functions (e.g., jax.nn.relu) in the same list. 35 | # therefore, we have to convert a jax function to ``Lambda`` module first. 36 | modules: Tuple[Module, ...] 
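# Added usage sketch (comments only, not part of the original file): the
# `__rshift__` operator defined at the end of this class appends a layer,
# so an existing network can be extended with
#   net = pax.Sequential(pax.Linear(2, 8), jax.nn.relu)
#   net = net >> pax.Linear(8, 1)
# which returns a new Sequential rather than mutating `net` in place.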
37 | 38 | def __init__(self, *layers, name: Optional[str] = None): 39 | """Create a Sequential module.""" 40 | super().__init__(name=name) 41 | self.modules = tuple( 42 | (f if isinstance(f, Module) else Lambda(f)) for f in layers 43 | ) 44 | 45 | def __call__(self, x): 46 | """Call layers in order.""" 47 | for f in self.modules: 48 | x = f(x) 49 | return x 50 | 51 | def __getitem__(self, index: int) -> T: 52 | """Get an item from the `modules` list.""" 53 | return self.modules[index] 54 | 55 | def set(self: T, index: int, value) -> T: 56 | """Set an item to the `modules` list.""" 57 | if not isinstance(value, Module): 58 | value = Lambda(value) 59 | 60 | modules = list(self.modules) 61 | modules[index] = value 62 | return super().replace(modules=tuple(modules)) 63 | 64 | def __rshift__(self, other: Module): 65 | return Sequential(*self.modules, other) 66 | -------------------------------------------------------------------------------- /pax/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental API""" 2 | 3 | 4 | from pax._src.core import Flattener, LazyModule, mutable 5 | from pax._src.utils import ( 6 | apply_scaled_gradients, 7 | default_mp_policy, 8 | load_weights_from_dict, 9 | save_weights_to_dict, 10 | ) 11 | 12 | from . import graph 13 | 14 | __all__ = ( 15 | "apply_scaled_gradients", 16 | "default_mp_policy", 17 | "Flattener", 18 | "graph", 19 | "LazyModule", 20 | "load_weights_from_dict", 21 | "mutable", 22 | "save_weights_to_dict", 23 | ) 24 | -------------------------------------------------------------------------------- /pax/experimental/graph.py: -------------------------------------------------------------------------------- 1 | """Experimental graph API""" 2 | 3 | 4 | from pax._src.core.graph_module import GraphModule, InputNode, Node, build_graph_module 5 | 6 | __all__ = ( 7 | "build_graph_module", 8 | "GraphModule", 9 | "InputNode", 10 | "Node", 11 | ) 12 | -------------------------------------------------------------------------------- /pax/nets.py: -------------------------------------------------------------------------------- 1 | """Public nets.""" 2 | 3 | from pax._src.nets import ( 4 | ResNet18, 5 | ResNet34, 6 | ResNet50, 7 | ResNet101, 8 | ResNet152, 9 | ResNet200, 10 | Transformer, 11 | ) 12 | 13 | __all__ = ( 14 | "ResNet18", 15 | "ResNet34", 16 | "ResNet50", 17 | "ResNet101", 18 | "ResNet152", 19 | "ResNet200", 20 | "Transformer", 21 | ) 22 | -------------------------------------------------------------------------------- /pax/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTT123/pax/13916cb86ede38c56750cf1bde3ac37c63674014/pax/py.typed -------------------------------------------------------------------------------- /pax/utils.py: -------------------------------------------------------------------------------- 1 | """Public utility functions.""" 2 | 3 | from pax._src.utils import build_update_fn, grad, scan 4 | 5 | __all__ = ( 6 | "build_update_fn", 7 | "grad", 8 | "scan", 9 | ) 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup PAX3 package.""" 2 | 3 | from setuptools import find_namespace_packages, setup 4 | 5 | 6 | def _get_version(): 7 | with open("pax/__init__.py", encoding="utf-8") as file: 8 | for line in file: 9 | if line.startswith("__version__"): 10 | _globals = {} 
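# (Added note; not part of the original file.) Executing only the
# `__version__ = ...` line in a scratch namespace reads the version string
# without importing the `pax` package itself, which would require jax to be
# installed at build time.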
11 | exec(line, _globals) # pylint: disable=exec-used 12 | return _globals["__version__"] 13 | raise ValueError("`__version__` not defined in `pax/__init__.py`") 14 | 15 | 16 | __version__ = _get_version() 17 | URL = "https://github.com/ntt123/pax" 18 | 19 | install_requires = ["jax>=0.2.21", "jmp>=0.0.2"] 20 | setup_requires = [] 21 | tests_requires = [ 22 | "chex", 23 | "dm-haiku", 24 | "fire", 25 | "opax", 26 | "pytest", 27 | "pytype", 28 | "tqdm", 29 | ] 30 | 31 | setup( 32 | name="pax3", 33 | version=__version__, 34 | description="A stateful pytree library for training neural networks.", 35 | long_description=open("README.md", encoding="utf-8").read(), 36 | long_description_content_type="text/markdown", 37 | author="Thông Nguyễn", 38 | url=URL, 39 | keywords=[ 40 | "deep-learning", 41 | "jax", 42 | ], 43 | install_requires=install_requires, 44 | setup_requires=setup_requires, 45 | tests_require=tests_requires, 46 | packages=find_namespace_packages(exclude=["examples", "tests"]), 47 | extras_require={"test": tests_requires}, 48 | python_requires=">=3.7", 49 | include_package_data=True, 50 | zip_safe=False, 51 | ) 52 | -------------------------------------------------------------------------------- /tests/test_auto_modules.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pax 4 | 5 | # import pytest 6 | 7 | 8 | def test_scan_bug_param_module(): 9 | class M(pax.ParameterModule): 10 | def __init__(self): 11 | super().__init__() 12 | self.a = jnp.array(0.0) 13 | 14 | # with pytest.raises(ValueError): 15 | _ = M() 16 | 17 | 18 | def test_scan_bug_state_module(): 19 | class M(pax.StateModule): 20 | def __init__(self): 21 | super().__init__() 22 | self.a = jnp.array(0.0) 23 | 24 | # with pytest.raises(ValueError): 25 | _ = M() 26 | 27 | 28 | def test_auto_module(): 29 | class M(pax.experimental.LazyModule): 30 | def __call__(self, x): 31 | x = self.get_or_create("fc", lambda: pax.Linear(1, 1))(x) 32 | x = jax.nn.relu(x) 33 | return x 34 | 35 | m = M() 36 | x = jnp.ones((2, 1)) 37 | m, _ = pax.module_and_value(m)(x) 38 | print(m.summary()) 39 | -------------------------------------------------------------------------------- /tests/test_counter.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pax 4 | 5 | 6 | def test_counter(): 7 | class Counter(pax.Module): 8 | counter: jnp.ndarray 9 | bias: jnp.ndarray 10 | parameters = pax.parameters_method("counter") 11 | 12 | def __init__(self, start_value: int = 0): 13 | super().__init__() 14 | 15 | self.counter = jnp.array(start_value, dtype=jnp.int32) 16 | self.bias = jnp.array(0.0) 17 | 18 | def __call__(self, x): 19 | self.counter = self.counter + 1 20 | return self.counter * x + self.bias 21 | 22 | @pax.pure 23 | def loss_fn(model: Counter, x: jnp.ndarray): 24 | y = model(x) 25 | loss = jnp.mean(jnp.square(x - y)) 26 | return loss, (loss, model) 27 | 28 | grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True) 29 | 30 | net = Counter(3) 31 | x = jnp.array(10.0) 32 | grads, (loss, net) = grad_fn(net, x) 33 | assert grads.counter.dtype is jax.float0 34 | assert grads.bias.item() == 60.0 35 | -------------------------------------------------------------------------------- /tests/test_deepscan.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pax 4 | import pytest 5 | 6 | 7 | def 
test_list_of_mod(): 8 | class M(pax.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.a = [pax.Linear(3, 3)] 12 | 13 | m = M() 14 | # m.pax.name_to_kind["a"] == pax.PaxKind.MODULE 15 | 16 | 17 | def test_assigned_field_an_array(): 18 | class M(pax.ParameterModule): 19 | def __init__(self): 20 | super().__init__() 21 | self.a = np.array([3.0, 1.0], dtype=np.float32) 22 | 23 | # no error because we will automatically assign `a` to kind PARAMETER 24 | m = M() 25 | # assert m.pax.name_to_kind["a"] == pax.PaxKind.PARAMETER 26 | 27 | class N(pax.Module): 28 | def __init__(self): 29 | super().__init__() 30 | 31 | n = N() 32 | 33 | n.scan_bugs() 34 | # no error because we will automatically assign `a` to kind PARAMETER 35 | def mutate(n: N) -> N: 36 | n.b = jnp.array([1, 2, 3], dtype=jnp.float32) 37 | return n 38 | 39 | n = pax.pure(mutate)(n) 40 | assert "b" in n.pytree_attributes 41 | 42 | # assert n.pax.name_to_kind["b"] == pax.PaxKind.PARAMETER 43 | 44 | 45 | def test_assign_int_to_param(): 46 | class M(pax.ParameterModule): 47 | def __init__(self): 48 | super().__init__() 49 | self.a = np.array([3, 1], dtype=np.int32) 50 | 51 | _ = M() 52 | 53 | 54 | def test_assign_int_to_param_deepscan(): 55 | class M(pax.Module): 56 | def __init__(self): 57 | super().__init__() 58 | self.a = np.array([3, 1], dtype=np.int32) 59 | 60 | _ = M() 61 | # m = pax.freeze_parameters(m) 62 | # d = OrderedDict(m.name_to_kind) 63 | # d["a"] = pax.module.PaxKind.PARAMETER 64 | # m.__dict__["name_to_kind"] = MappingProxyType(d) 65 | # m = pax.scan_bugs(m) 66 | 67 | 68 | # def test_jit_(): 69 | # class M(pax.Module): 70 | # def __init__(self): 71 | # super().__init__() 72 | # self.a_list = [pax.Linear(2, 2)] 73 | 74 | # def __call__(self, x): 75 | # self.a_list.append(0) 76 | # return x 77 | 78 | # m = M() 79 | 80 | # @pax.jit_ 81 | # def fwd(m, x): 82 | # return m(x) 83 | 84 | # with pytest.raises(ValueError): 85 | # x = fwd(m, jnp.zeros((2, 2))) 86 | -------------------------------------------------------------------------------- /tests/test_finetune.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import opax 7 | import pax 8 | 9 | 10 | def test_finetune(): 11 | pax.seed_rng_key(42) 12 | 13 | class MLP(pax.Module): 14 | layers: List[pax.Linear] 15 | 16 | def __init__(self, dims: List[int]): 17 | super().__init__() 18 | layers = [] 19 | for in_dim, out_dim in zip(dims[:-1], dims[1:]): 20 | layers.append(pax.Linear(in_dim, out_dim)) 21 | self.layers = layers 22 | 23 | def __call__(self, x): 24 | for f in self.layers: 25 | x = f(x) 26 | x = jax.nn.sigmoid(x) 27 | return x 28 | 29 | net = MLP([10, 2, 2, 2, 10]) 30 | 31 | @pax.pure 32 | def loss_fn(params: MLP, model: MLP, x): 33 | model = pax.update_parameters(model, params=params) 34 | y = model(x) 35 | loss = jnp.mean(jnp.square(x - y)) 36 | return loss, (loss, model) 37 | 38 | x = jax.random.normal(pax.next_rng_key(), (1, 10)) 39 | 40 | # make all layers non-trainable except the last layer. 
41 | for i in range(len(net.layers) - 1): 42 | net.layers[i] = pax.freeze_parameters(net.layers[i]) 43 | 44 | # net.layers[-1] = pax.Linear(2, 10) 45 | optimizer = opax.adam(1e-2)(net.parameters()) 46 | 47 | @jax.jit 48 | def update_fn(model, optimizer, x): 49 | params = model.parameters() 50 | grads, (loss, model) = jax.grad(loss_fn, has_aux=True)(params, model, x) 51 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads) 52 | return model, optimizer, loss 53 | 54 | old_layers = net.layers 55 | for i in range(100): 56 | net, optimizer, loss = update_fn(net, optimizer, x) 57 | if i % 10 == 0: 58 | print(f"[step {i:03d}] loss {loss:.3f}") 59 | new_layers = net.layers 60 | 61 | for i in range(len(net.layers) - 1): 62 | np.testing.assert_array_equal(old_layers[i].weight, new_layers[i].weight) 63 | 64 | np.testing.assert_raises( 65 | AssertionError, 66 | np.testing.assert_array_equal, 67 | old_layers[-1].weight, 68 | new_layers[-1].weight, 69 | ) 70 | -------------------------------------------------------------------------------- /tests/test_freeze_unfreeze.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import pax 3 | 4 | 5 | def test_freeze_really_working(): 6 | a = pax.Sequential( 7 | pax.Linear(3, 3), 8 | pax.Linear(5, 5), 9 | ) 10 | b = pax.freeze_parameters(a) 11 | # assert b[0].pax.name_to_kind["weight"] == pax.PaxKind.STATE 12 | # assert a[0].pax.name_to_kind["weight"] == pax.PaxKind.PARAMETER 13 | 14 | 15 | def test_freeze_mapping_proxy(): 16 | a = pax.Sequential( 17 | pax.Linear(3, 3), 18 | pax.Linear(5, 5), 19 | ) 20 | b = pax.freeze_parameters(a) 21 | # assert isinstance(b.pax.name_to_kind, MappingProxyType), "expecting a proxy map" 22 | 23 | 24 | def test_freeze_twice(): 25 | a = pax.Linear(2, 2) 26 | # with pytest.raises(ValueError): 27 | _ = pax.freeze_parameters(pax.freeze_parameters(a)) 28 | 29 | 30 | # def test_freeze_unfreeze(): 31 | # a = pax.Sequential( 32 | # pax.Linear(2, 2), 33 | # pax.Linear(3, 3), 34 | # pax.Linear(4, 4), 35 | # pax.Linear(5, 5), 36 | # ) 37 | 38 | # b = pax.freeze_parameters(a) 39 | # c = pax.unfreeze_parameters(b, origin=a) 40 | # # pylint: disable=-access 41 | # # assert a[0].pax.name_to_kind is c[0].pax.name_to_kind 42 | 43 | 44 | def test_copy(): 45 | a = pax.Linear(1, 1, with_bias=False) 46 | b = pax.enable_eval_mode(a) 47 | assert jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b) 48 | c = pax.enable_train_mode(b) 49 | assert jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(c) 50 | -------------------------------------------------------------------------------- /tests/test_graph_module.py: -------------------------------------------------------------------------------- 1 | """Test graph module""" 2 | 3 | import copy 4 | from functools import partial 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import pax 9 | import pytest 10 | from pax.experimental.graph import GraphModule, InputNode, build_graph_module 11 | 12 | 13 | def test_simple_graph(): 14 | x = InputNode(jnp.zeros((3, 3))) 15 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 16 | assert y.value.shape == (3, 4) 17 | 18 | 19 | def test_cat_graph(): 20 | x = InputNode(jnp.zeros((3, 3))) 21 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 22 | z = x & y 23 | t = z >> partial(jnp.concatenate, axis=-1) 24 | assert t.value.shape == (3, 7) 25 | 26 | 27 | def test_cat_merge_left(): 28 | x = InputNode(jnp.zeros((3, 3))) 29 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 30 | q = y & y 31 | z = q & x 32 | 
assert z.parents == (y, y, x) 33 | 34 | 35 | def test_cat_merge_right(): 36 | x = InputNode(jnp.zeros((3, 3))) 37 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 38 | q = y & y 39 | z = x & q 40 | assert z.parents == (x, y, y) 41 | 42 | 43 | def test_merge_2_cat(): 44 | x = InputNode(jnp.zeros((3, 3))) 45 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 46 | q = y & y 47 | t = x & x 48 | k = q & t 49 | assert k.parents == (y, y, x, x) 50 | 51 | 52 | def test_3_cat_graph(): 53 | x = InputNode(jnp.zeros((3, 3))) 54 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 55 | z = x & y & x 56 | t = z >> partial(jnp.concatenate, axis=-1) 57 | assert t.value.shape == (3, 10) 58 | 59 | 60 | def test_3_cat_graph_module(): 61 | x = InputNode(jnp.zeros((3, 3))) 62 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 63 | z = x & y & y 64 | t = z >> partial(jnp.concatenate, axis=-1) 65 | _ = GraphModule((x,), t) 66 | 67 | 68 | def test_or_graph(): 69 | x = InputNode(jnp.zeros((3, 3))) 70 | y = x >> pax.Linear(3, 3) >> jax.nn.relu 71 | z = (x | y) >> jax.lax.add 72 | assert z.value.shape == (3, 3) 73 | 74 | 75 | def test_merge_2_or(): 76 | x = InputNode(jnp.zeros((3, 3))) 77 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 78 | q = y | y 79 | t = x | x 80 | k = t | q 81 | assert k.parents == (x, x, y, y) 82 | 83 | 84 | def test_or_merge_left(): 85 | x = InputNode(jnp.zeros((3, 3))) 86 | y = x >> pax.Linear(3, 3) >> jax.nn.relu 87 | z = x | y 88 | t = z | x 89 | assert t.parents == (x, y, x) 90 | 91 | 92 | def test_or_merge_right(): 93 | x = InputNode(jnp.zeros((3, 3))) 94 | y = x >> pax.Linear(3, 3) >> jax.nn.relu 95 | z = x | y 96 | t = x | z 97 | assert t.parents == (x, x, y) 98 | 99 | 100 | def test_cat_graph_merge(): 101 | x = InputNode(jnp.zeros((3, 3))) 102 | y = x >> pax.Linear(3, 4) >> jax.nn.relu 103 | q = y | y 104 | z = x | q 105 | assert z.parents == (x, y, y) 106 | 107 | 108 | def test_binops(): 109 | x = InputNode(jnp.ones((3, 3))) 110 | y = x.binary_ops(jax.lax.add, x) 111 | assert y.parents == (x, x) 112 | assert jnp.array_equal(y.fx((x.value, x.value)), jnp.ones((3, 3)) * 2) 113 | assert jnp.array_equal(y.value, jnp.ones((3, 3)) * 2) 114 | 115 | 116 | def test_type_shape(): 117 | x = InputNode(jnp.ones((3, 3), dtype=jnp.int32)) 118 | assert x.shape == (3, 3) 119 | assert x.dtype == jnp.int32 120 | 121 | 122 | def test_build_residual_net(): 123 | def residual(x): 124 | y = x >> pax.Linear(3, 3) >> jax.nn.relu 125 | t = x >> pax.Linear(3, 3) >> jax.nn.tanh 126 | z = (y | t) >> jax.lax.add 127 | return z 128 | 129 | x = jnp.empty((1, 3)) 130 | net = build_graph_module(residual)(x) 131 | y = net(x) 132 | assert y.shape == (1, 3) 133 | 134 | 135 | def test_reuse_module_error(): 136 | def reuse(x): 137 | mod = pax.Linear(3, 3) 138 | y = x >> mod >> jax.nn.relu 139 | t = x >> mod 140 | z = (y | t) >> jax.lax.add 141 | return z 142 | 143 | x = jnp.empty((1, 3)) 144 | with pytest.raises(ValueError): 145 | _ = build_graph_module(reuse)(x) 146 | 147 | 148 | def test_copy_error(): 149 | x = InputNode(jnp.empty((3, 3))) 150 | with pytest.raises(TypeError): 151 | _ = copy.copy(x) 152 | 153 | 154 | def test_deepcopy_error(): 155 | x = InputNode(jnp.empty((3, 3))) 156 | with pytest.raises(TypeError): 157 | _ = copy.deepcopy(x) 158 | -------------------------------------------------------------------------------- /tests/test_immutability.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pax 3 | import pytest 4 | 5 | 6 | def test_immutability(): 7 | f = pax.Linear(3, 
3) 8 | with pytest.raises(ValueError): 9 | f.c = 123 10 | g = pax.freeze_parameters(f) 11 | # k = pax.unfreeze_parameters(g, origin=f) 12 | 13 | 14 | def test_new_empty_attribute(): 15 | class M(pax.Module): 16 | a = [] 17 | 18 | m = M() 19 | 20 | 21 | def test_new_unregistered_array(): 22 | class M(pax.Module): 23 | a = [jnp.zeros((3, 3))] 24 | 25 | with pytest.raises(ValueError): 26 | m = M() 27 | 28 | 29 | def test_new_unregistered_module(): 30 | class M(pax.Module): 31 | a = pax.Linear(3, 3) 32 | 33 | with pytest.raises(ValueError): 34 | m = M() 35 | -------------------------------------------------------------------------------- /tests/test_jax_transform.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pax 4 | import pytest 5 | 6 | 7 | def test_jit_immutability(): 8 | class M(pax.Module): 9 | def __init__(self): 10 | self.x = pax.Linear(2, 2) 11 | self.counter = 2 12 | 13 | def __call__(self, x): 14 | self.counter = self.counter + 1 15 | return x 16 | 17 | m = M() 18 | x = jnp.zeros((1, 1)) 19 | with pytest.raises(ValueError): 20 | y = jax.jit(lambda y: m(y))(x) 21 | 22 | 23 | def test_grad_deepscan(): 24 | class M(pax.Module): 25 | def __init__(self): 26 | self.fc = pax.Linear(2, 2) 27 | 28 | def __call__(self, x): 29 | return self.fc(x) 30 | 31 | def loss_fn(params, model, inputs): 32 | model = pax.update_parameters(model, params=params) 33 | loss = jnp.mean(model(inputs)) 34 | return loss, (loss, model) 35 | 36 | m = M() 37 | x = jnp.zeros((1, 2)) 38 | m.set_attribute("fc1", pax.Linear(2, 2)) 39 | y = jax.grad(loss_fn, has_aux=True)(pax.select_parameters(m), m, x) 40 | 41 | 42 | def test_loss_fn_no_return_model(): 43 | def loss_fn(params, model, inputs): 44 | model = pax.update_parameters(model, params=params) 45 | y = model(inputs) 46 | return jnp.sum(y) 47 | 48 | grad_fn = jax.grad(loss_fn) 49 | x = jnp.zeros((3, 3)) 50 | net = pax.Linear(3, 3) 51 | y = grad_fn(net.parameters(), net, x) 52 | 53 | 54 | def test_jit__call__(): 55 | class M(pax.Module): 56 | @jax.jit 57 | def __call__(self, x): 58 | return x, self 59 | 60 | x = jnp.zeros((3, 3)) 61 | net = M() 62 | y = net(x) 63 | 64 | class M(pax.Module): 65 | @jax.jit 66 | def __call__(self, x): 67 | return x 68 | 69 | net = M() 70 | y = net(x) 71 | -------------------------------------------------------------------------------- /tests/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jmp 4 | import pax 5 | import pytest 6 | from pax import apply_mp_policy 7 | 8 | half = jmp.half_dtype() 9 | full = jnp.float32 10 | 11 | 12 | def test_wrap_unwrap_mp_policy(): 13 | f = pax.Linear(3, 3) 14 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 15 | 16 | ff = pax.apply_mp_policy(f, mp_policy=my_policy) 17 | fff = pax.unwrap_mp_policy(ff) 18 | assert hasattr(ff, "_pax_mp_policy") 19 | assert not hasattr(fff, "_pax_mp_policy") 20 | 21 | x = jax.numpy.ones((3, 3)) 22 | assert f(x).dtype == full 23 | assert ff(x).dtype == half 24 | assert fff(x).dtype == full # type: ignore 25 | 26 | 27 | def test_sequential_mixed_precision(): 28 | f = pax.Sequential( 29 | pax.Linear(3, 3), 30 | pax.BatchNorm2D(3, True, True, 0.9), 31 | pax.Linear(3, 3), 32 | pax.BatchNorm2D(3, True, True, 0.9), 33 | ) 34 | linear_policy = jmp.Policy(compute_dtype=half, param_dtype=half, output_dtype=half) 35 | batchnorm_policy = jmp.Policy( 36 | 
compute_dtype=full, param_dtype=full, output_dtype=half 37 | ) 38 | 39 | def policy_fn(mod): 40 | if isinstance(mod, pax.Linear): 41 | return pax.apply_mp_policy(mod, mp_policy=linear_policy) 42 | elif isinstance(mod, pax.BatchNorm2D): 43 | return pax.apply_mp_policy(mod, mp_policy=batchnorm_policy) 44 | else: 45 | # unchanged 46 | return mod 47 | 48 | f_mp = f.apply(policy_fn) 49 | x = jnp.zeros((32, 5, 5, 3)) 50 | 51 | @pax.pure 52 | def run(f_mp): 53 | return f_mp(x) 54 | 55 | y = run(f_mp) 56 | assert y.dtype == half 57 | 58 | 59 | def test_change_internal_state(): 60 | class M(pax.Module): 61 | counter: jnp.ndarray 62 | 63 | def __init__(self): 64 | super().__init__() 65 | self.counter = jnp.array(0) 66 | 67 | def __call__(self, x): 68 | self.counter = self.counter + 1 69 | return x * self.counter 70 | 71 | m = M() 72 | mp = jmp.Policy( 73 | compute_dtype=jnp.float16, param_dtype=jnp.float32, output_dtype=jnp.float16 74 | ) 75 | mm = m.apply( 76 | lambda x: (pax.apply_mp_policy(x, mp_policy=mp) if isinstance(x, M) else x) 77 | ) 78 | x = jnp.array(0.0) 79 | assert mm.counter.item() == 0 80 | mm, y = pax.module_and_value(mm)(x) 81 | assert mm.counter.item() == 1 82 | assert m.counter.item() == 0 83 | 84 | 85 | def test_change_tree_def(): 86 | class M(pax.Module): 87 | counter: jnp.ndarray 88 | count: int 89 | 90 | def __init__(self): 91 | super().__init__() 92 | self.counter = jnp.array(0) 93 | self.count = 0 94 | 95 | def __call__(self, x): 96 | self.counter = self.counter + 1 97 | self.count = self.count + 1 98 | return x * self.counter 99 | 100 | m = M() 101 | mp = jmp.Policy( 102 | compute_dtype=jnp.float16, param_dtype=jnp.float32, output_dtype=jnp.float16 103 | ) 104 | mm = m.apply( 105 | lambda x: (pax.apply_mp_policy(x, mp_policy=mp) if isinstance(x, M) else x) 106 | ) 107 | x = jnp.array(0.0) 108 | assert mm.counter.item() == 0 109 | with pytest.raises(ValueError): 110 | y = mm(x) 111 | assert mm.counter.item() == 0 112 | assert m.counter.item() == 0 113 | 114 | 115 | def test_wrap_wrap_mixed_precision(): 116 | f = pax.Linear(3, 3) 117 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 118 | 119 | f = pax.apply_mp_policy(f, mp_policy=my_policy) 120 | with pytest.raises(ValueError): 121 | f = pax.apply_mp_policy(f, mp_policy=my_policy) 122 | 123 | f = pax.unwrap_mp_policy(f) 124 | f = pax.apply_mp_policy(f, mp_policy=my_policy) 125 | 126 | with pytest.raises(ValueError): 127 | f = pax.apply_mp_policy(f, mp_policy=my_policy) 128 | 129 | 130 | def test_mixed_precision_clone(): 131 | f = pax.BatchNorm1D(3) 132 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 133 | 134 | ff = pax.apply_mp_policy(f, mp_policy=my_policy) 135 | 136 | f = f.set_attribute("new_fc", pax.Linear(1, 1)) 137 | # assert "new_fc" not in ff.pax.name_to_kind 138 | 139 | 140 | def test_mixed_precision_unwrap_clone(): 141 | f = pax.BatchNorm1D(3) 142 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 143 | 144 | ff = pax.apply_mp_policy(f, mp_policy=my_policy) 145 | f = pax.unwrap_mp_policy(ff) 146 | f = f.set_attribute("new_fc", pax.Linear(1, 1)) 147 | # assert "new_fc" not in ff.pax.name_to_kind 148 | 149 | 150 | def test_mixed_precision_no_method_name(): 151 | f = pax.Linear(3, 3) 152 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 153 | 154 | # with pytest.raises(TypeError): 155 | _ = pax.apply_mp_policy(f, mp_policy=my_policy) 156 | 157 | 158 | def test_mp_call_classmethod(): 159 | 
class M(pax.Module): 160 | def __init__(self): 161 | super().__init__() 162 | self.fc = pax.Linear(3, 3) 163 | 164 | @classmethod 165 | def t(cls, y): 166 | return y 167 | 168 | m = M() 169 | x = jnp.zeros((3, 3)) 170 | y = m.t(x) 171 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 172 | m = apply_mp_policy(m, mp_policy=my_policy) 173 | # with pytest.raises(ValueError): 174 | y = m.t(x) 175 | 176 | 177 | def test_mp_call_staticmethod(): 178 | class M(pax.Module): 179 | def __init__(self): 180 | super().__init__() 181 | self.fc = pax.Linear(3, 3) 182 | 183 | @staticmethod 184 | def t(_, y): 185 | return y 186 | 187 | m = M() 188 | x = jnp.zeros((3, 3)) 189 | y = m.t(x, x) 190 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 191 | m = apply_mp_policy(m, mp_policy=my_policy) 192 | # with pytest.raises(ValueError): 193 | y = m.t(x, x) 194 | 195 | 196 | @pax.pure 197 | def test_mp_call_function(): 198 | class M(pax.Module): 199 | def __init__(self): 200 | super().__init__() 201 | self.fc = pax.Linear(3, 3) 202 | 203 | m = M() 204 | x = jnp.zeros((3, 3)) 205 | 206 | def mutate(m): 207 | m.q = lambda x: x 208 | return m 209 | 210 | m = pax.pure(mutate)(m) 211 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half) 212 | m = apply_mp_policy(m, mp_policy=my_policy) 213 | # with pytest.raises(ValueError): 214 | m.q(x) 215 | -------------------------------------------------------------------------------- /tests/test_multithread.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test PAX in multithread environment. 3 | """ 4 | import queue 5 | import threading 6 | import time 7 | 8 | import jax.numpy as jnp 9 | import pax 10 | 11 | 12 | class DelayedCounter(pax.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.counter = jnp.array(0) 16 | 17 | def __call__(self): 18 | time.sleep(1) 19 | self.counter += 1 20 | time.sleep(1) 21 | return self.counter 22 | 23 | 24 | def test_multithread(): 25 | @pax.pure 26 | def update(c: DelayedCounter, q): 27 | o = c() 28 | q.put(o) 29 | 30 | c1 = DelayedCounter() 31 | c2 = DelayedCounter() 32 | q = queue.Queue() 33 | x = threading.Thread(target=update, args=(c1, q)) 34 | y = threading.Thread(target=update, args=(c2, q)) 35 | x.start() 36 | y.start() 37 | x.join() 38 | y.join() 39 | q.get(timeout=1) 40 | q.get(timeout=1) 41 | -------------------------------------------------------------------------------- /tests/test_nets.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import pax 3 | 4 | 5 | def test_run_resnet(): 6 | resnet = pax.nets.ResNet18(3, 1) 7 | x = jax.numpy.zeros((1, 3, 18, 18)) 8 | y = pax.pure(resnet)(x) 9 | assert y.shape == (1, 1) 10 | 11 | 12 | def test_run_transformer(): 13 | transformer = pax.nets.Transformer(8, 2, 2, 0.1) 14 | x = jax.numpy.zeros((1, 15, 8)) 15 | y = pax.pure(transformer)(x) 16 | assert y.shape == (1, 15, 8) 17 | -------------------------------------------------------------------------------- /tests/test_optim.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import opax 4 | import pax 5 | import pytest 6 | 7 | 8 | def test_optim_model_update_state(): 9 | # a module updates it internal `count` value in the forward pass. 
10 | 11 | class MyModule(pax.Module): 12 | count: int = 0 13 | fc: pax.Module 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.fc = pax.Linear(2, 2) 18 | self.count = 0 19 | 20 | def __call__(self, x): 21 | self.count = self.count + 1 22 | x = self.fc(x) 23 | return x 24 | 25 | net = MyModule() 26 | 27 | def loss_fn(model: MyModule, x): 28 | y = model(x) 29 | loss = jnp.mean(jnp.square(x - y)) 30 | return loss, (loss, model) 31 | 32 | update_fn = pax.utils.build_update_fn(loss_fn=loss_fn) 33 | optimizer = opax.adamw()(net.parameters()) 34 | x = jnp.zeros((2, 2), dtype=jnp.float32) 35 | 36 | with pytest.raises(ValueError): 37 | net, optimizer, loss = update_fn(net, optimizer, x) 38 | 39 | 40 | def test_sgd(): 41 | class SGD(pax.Module): 42 | velocity: pax.Module 43 | learning_rate: float 44 | momentum: float 45 | 46 | def __init__(self, params, learning_rate: float = 1e-2, momentum: float = 0.9): 47 | super().__init__() 48 | self.momentum = momentum 49 | self.learning_rate = learning_rate 50 | self.velocity = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), params) 51 | 52 | def step(self, grads: pax.Module, params: pax.Module): 53 | self.velocity = jax.tree_util.tree_map( 54 | lambda v, g: v * self.momentum + g * self.learning_rate, 55 | self.velocity, 56 | grads, 57 | ) 58 | new_params = jax.tree_util.tree_map( 59 | lambda p, v: p - v, params, self.velocity 60 | ) 61 | return new_params 62 | 63 | f = pax.Linear(2, 2) 64 | sgd = SGD(f, 0.9, 1e-4) 65 | pax.pure(sgd.step)(f, f) 66 | -------------------------------------------------------------------------------- /tests/test_performance.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax 4 | import numpy as np 5 | import pax 6 | 7 | 8 | def test_perf_transformer_flatten_unflatten(): 9 | class MyTransformer(pax.Module): 10 | def __init__(self, num_layers: int): 11 | super().__init__() 12 | self.layers = [ 13 | pax.MultiHeadAttention(8, 512 // 8, 1.0) for i in range(num_layers) 14 | ] 15 | 16 | f = MyTransformer(16) 17 | 18 | start = time.perf_counter() 19 | n_iters = 100_000 20 | for _ in range(n_iters): 21 | leaves, treedef = jax.tree_util.tree_flatten(f) 22 | f = jax.tree_util.tree_unflatten(treedef, leaves) 23 | end = time.perf_counter() 24 | iters_per_second = n_iters / (end - start) 25 | print(iters_per_second, "iters/second") 26 | assert iters_per_second > 2500 27 | 28 | 29 | def test_perf_resnet200_flatten_unflatten(): 30 | 31 | f = pax.nets.ResNet200(3, 100) 32 | 33 | start = time.perf_counter() 34 | n_iters = 1000 35 | for _ in range(n_iters): 36 | leaves, treedef = jax.tree_util.tree_flatten(f) 37 | f = jax.tree_util.tree_unflatten(treedef, leaves) 38 | end = time.perf_counter() 39 | iters_per_second = n_iters / (end - start) 40 | print(iters_per_second, "iters/second") 41 | assert iters_per_second > 100 42 | 43 | 44 | def test_perf_flattenmodule_resnet200_flatten_unflatten(): 45 | 46 | x = jax.random.normal(jax.random.PRNGKey(42), (1, 3, 64, 64)) 47 | f = pax.nets.ResNet200(3, 100) 48 | y = f.eval()(x) 49 | f = pax.experimental.Flattener(net=f.eval()) 50 | y1 = pax.pure(f.net)(x) 51 | np.testing.assert_array_equal(y, y1) 52 | 53 | start = time.perf_counter() 54 | n_iters = 10000 55 | for _ in range(n_iters): 56 | leaves, treedef = jax.tree_util.tree_flatten(f) 57 | f = jax.tree_util.tree_unflatten(treedef, leaves) 58 | end = time.perf_counter() 59 | iters_per_second = n_iters / (end - start) 60 | print(iters_per_second, "iters/second") 61 | assert 
iters_per_second > 4000 62 | -------------------------------------------------------------------------------- /tests/test_pure.py: -------------------------------------------------------------------------------- 1 | import weakref 2 | from functools import partial 3 | from typing import Any 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import pax 8 | import pytest 9 | from numpy.testing import assert_array_equal 10 | 11 | 12 | def test_rng_unchanged(): 13 | pax.seed_rng_key(41) 14 | pax.next_rng_key() 15 | 16 | @jax.jit 17 | @pax.pure 18 | def fn(): 19 | return pax.next_rng_key() 20 | 21 | def f1(): 22 | pax.seed_rng_key(41) 23 | pax.next_rng_key() 24 | return pax.next_rng_key() 25 | 26 | def f2(): 27 | pax.seed_rng_key(41) 28 | pax.next_rng_key() 29 | fn() 30 | return pax.next_rng_key() 31 | 32 | r1 = f1() 33 | r2 = f2() 34 | assert not jnp.array_equal(r1, r2) 35 | 36 | r3 = fn() 37 | _ = pax.next_rng_key() 38 | r4 = fn() 39 | assert_array_equal(r3, r4) 40 | 41 | 42 | def test_deepcopy(): 43 | class C(object): 44 | c: int 45 | 46 | def __init__(self): 47 | self.c = 0 48 | 49 | @pax.pure 50 | def mutate(x): 51 | x.c.c += 1 52 | return x 53 | 54 | class M(pax.Module): 55 | c: C 56 | 57 | def __init__(self): 58 | self.c = C() 59 | 60 | m = M() 61 | assert m.c.c == 0 62 | m1 = mutate(m) 63 | assert m.c.c == 1 64 | assert m1.c.c == 1 65 | 66 | 67 | def test_deep_compare_1(): 68 | class C(object): 69 | c: int 70 | 71 | def __init__(self): 72 | self.c = 0 73 | 74 | @pax.pure 75 | def mutate(x): 76 | return x 77 | 78 | class M(pax.Module): 79 | c: C 80 | 81 | def __init__(self): 82 | self.c = C() 83 | 84 | m = M() 85 | m1 = mutate(m) 86 | # with pytest.raises(AssertionError): 87 | pax.assert_structure_equal(m, m1) 88 | 89 | 90 | def test_deep_compare_2(): 91 | class C(object): 92 | c: int 93 | 94 | def __init__(self): 95 | self.c = 0 96 | 97 | def __eq__(self, o) -> bool: 98 | return self.c == o.c 99 | 100 | @pax.pure 101 | def mutate(x): 102 | return x 103 | 104 | class M(pax.Module): 105 | f: Any 106 | g: Any 107 | j: Any 108 | c: C 109 | 110 | def __init__(self): 111 | self.f = jax.nn.relu 112 | self.g = jax.nn.sigmoid 113 | self.j = partial(jax.nn.leaky_relu, negative_slope=0.2) 114 | self.h = jnp.tanh 115 | self.c = C() 116 | 117 | m = M() 118 | m1 = mutate(m) 119 | pax.assert_structure_equal(m, m1) 120 | 121 | 122 | def test_module_weak_ref(): 123 | mod = pax.Linear(3, 3) 124 | mod_ref = weakref.ref(mod) 125 | assert mod_ref() is mod 126 | del mod 127 | assert mod_ref() is None 128 | 129 | 130 | def test_abstraction_level_checking(): 131 | def mutate(f): 132 | @jax.jit 133 | def g(): 134 | f.a = "hello" 135 | 136 | g() 137 | 138 | fc = pax.Linear(3, 3) 139 | with pytest.raises(ValueError): 140 | pax.pure(mutate)(fc) 141 | 142 | 143 | def test_decorate_method_with_module_and_value(): 144 | class M(pax.StateModule): 145 | def __init__(self): 146 | self.c = jnp.array(0) 147 | 148 | @pax.module_and_value 149 | def step(self): 150 | self.c += 1 151 | 152 | m = M() 153 | assert m.c.item() == 0 154 | m, _ = m.step() 155 | assert m.c.item() == 1 156 | -------------------------------------------------------------------------------- /tests/test_summary.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import pax 3 | 4 | 5 | def test_linear_summary(): 6 | fc = pax.Linear(3, 3) 7 | assert fc.summary() == "Linear(in_dim=3, out_dim=3, with_bias=True)" 8 | 9 | 10 | def test_sequential_summary(): 11 | f = pax.Sequential(pax.Linear(3, 32), 
jax.nn.sigmoid, pax.Linear(32, 64)) 12 | f1 = pax.BatchNorm1D(3) 13 | f1 = f1.set_attribute("T", f) 14 | print(f1.summary()) 15 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import opax 6 | import pax 7 | 8 | 9 | def test_train_linear_regression_1(): 10 | x = jax.random.normal(jax.random.PRNGKey(42), (32, 1), dtype=jnp.float32) 11 | noise = jax.random.normal(jax.random.PRNGKey(43), (32, 1), dtype=jnp.float32) * 0.2 12 | y = x * 2.5 - 3.1 + noise 13 | 14 | def loss_fn(model: pax.Linear, x, y): 15 | y_hat = model(x) 16 | loss = jnp.mean(jnp.square(y - y_hat)) 17 | return loss, (loss, model) 18 | 19 | update_fn = pax.utils.build_update_fn(loss_fn) 20 | net = pax.Linear(1, 1) 21 | optimizer = opax.adamw(1e-1)(net.parameters()) 22 | for step in range(100): 23 | net, optimizer, loss = update_fn(net, optimizer, x, y) 24 | print(f"[step {step}] loss {loss:.3f}") 25 | 26 | 27 | def test_train_linear_regression_2(): 28 | x = jax.random.normal(jax.random.PRNGKey(42), (32, 1), dtype=jnp.float32) 29 | noise = jax.random.normal(jax.random.PRNGKey(43), (32, 1), dtype=jnp.float32) * 0.2 30 | y = x * 2.5 - 3.1 + noise 31 | 32 | class M(pax.Module): 33 | def __init__(self): 34 | super().__init__() 35 | self.fc1 = pax.Linear(1, 32) 36 | self.fc2 = pax.Linear(32, 1) 37 | 38 | def __call__(self, x): 39 | x = self.fc1(x) 40 | x = jax.nn.relu(x) 41 | x = self.fc2(x) 42 | return x 43 | 44 | def loss_fn(model: M, x, y): 45 | y_hat = model(x) 46 | loss = jnp.mean(jnp.square(y - y_hat)) 47 | return loss, (loss, model) 48 | 49 | update_fn = pax.utils.build_update_fn(loss_fn) 50 | net = M() 51 | optimizer = opax.adamw(1e-1)(net.parameters()) 52 | for step in range(100): 53 | net, optimizer, loss = update_fn(net, optimizer, x, y) 54 | print(f"[step {step}] loss {loss:.3f}") 55 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jmp 3 | import pax 4 | 5 | 6 | def test_mutate_new_module_list(): 7 | a = pax.Linear(3, 3) 8 | b = a.copy() 9 | 10 | def mutate(b): 11 | b.lst = [pax.Linear(4, 4)] 12 | return b 13 | 14 | b = pax.pure(mutate)(b) 15 | # pylint: disable=protected-access 16 | # assert b.pax.name_to_kind["lst"] == pax.PaxKind.MODULE 17 | 18 | 19 | def test_mp_policy_method_name(): 20 | class M(pax.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.f = pax.Linear(3, 3) 24 | 25 | def __call__(self, x): 26 | return self.f(x) 27 | 28 | def inference(self, x): 29 | return self.f(x) + 1.0 30 | 31 | m = M() 32 | half = jmp.half_dtype() 33 | full = jnp.float32 34 | 35 | p = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=full) 36 | 37 | m = pax.apply_mp_policy(m, mp_policy=p) 38 | x = jnp.zeros((4, 3)) 39 | _ = m(x) # ok 40 | 41 | _ = m.inference(x) 42 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import opax 5 | import pax 6 | from pax import EMA, RngSeq 7 | 8 | 9 | def test_grad(): 10 | def loss_fn(model: pax.Linear, inputs): 11 | x, target = inputs 12 | y = model(x) 13 | loss = 
jnp.mean(jnp.square(y - target)) 14 | return loss, (loss, model) 15 | 16 | @jax.jit 17 | def update_fn(model, optimizer, inputs): 18 | grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model, inputs) 19 | model, optimizer = opax.apply_gradients(model, opt, grads=grads) 20 | return model, optimizer, loss 21 | 22 | net = pax.Linear(2, 1) 23 | opt = opax.adamw(learning_rate=1e-2)(net.parameters()) 24 | x = np.random.normal(size=(32, 2)) 25 | y = np.random.normal(size=(32, 1)) 26 | print() 27 | for step in range(5): 28 | net, opt, loss = update_fn(net, opt, (x, y)) 29 | print(f"step {step} loss {loss:.3f}") 30 | 31 | 32 | def test_value_and_grad(): 33 | def loss_fn(model: pax.Linear, inputs): 34 | x, target = inputs 35 | y = model(x) 36 | loss = jnp.mean(jnp.square(y - target)) 37 | return loss, model 38 | 39 | @jax.jit 40 | def update_fn(model, optimizer, inputs): 41 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, inputs) 42 | model, optimizer = opax.apply_gradients(model, opt, grads) 43 | return model, optimizer, loss 44 | 45 | net = pax.Linear(2, 1) 46 | opt = opax.adamw(learning_rate=1e-2)(net.parameters()) 47 | x = np.random.normal(size=(32, 2)) 48 | y = np.random.normal(size=(32, 1)) 49 | print() 50 | for step in range(5): 51 | net, opt, loss = update_fn(net, opt, (x, y)) 52 | print(f"step {step} loss {loss:.3f}") 53 | 54 | 55 | def test_util_update_fn(): 56 | def loss_fn(model: pax.Linear, x, target): 57 | y = model(x) 58 | loss = jnp.mean(jnp.square(y - target)) 59 | return loss, (loss, model) 60 | 61 | net = pax.Linear(2, 1) 62 | opt = opax.adamw(learning_rate=1e-1)(net.parameters()) 63 | update_fn = jax.jit(pax.utils.build_update_fn(loss_fn, scan_mode=True)) 64 | x = np.random.normal(size=(32, 2)) 65 | y = np.random.normal(size=(32, 1)) 66 | print() 67 | for step in range(3): 68 | (net, opt), loss = update_fn((net, opt), x, y) 69 | print(f"step {step} loss {loss:.3f}") 70 | 71 | 72 | def test_Rng_Seq(): 73 | rng_seq = RngSeq(seed=42) 74 | assert rng_seq._rng_key.tolist() == [0, 42] 75 | 76 | rng_seq, r1 = pax.module_and_value(rng_seq.next_rng_key)() 77 | assert r1.shape == (2,) 78 | h1 = rng_seq._rng_key 79 | rng_seq, rs = pax.module_and_value(rng_seq.next_rng_key)(2) 80 | h2 = rng_seq._rng_key 81 | assert len(rs) == 2 82 | assert r1.tolist() != rs[0].tolist() 83 | assert h1.tolist() != h2.tolist(), "update internal state in `train` mode" 84 | 85 | rng_seq = pax.enable_eval_mode(rng_seq) 86 | rng_seq, r3 = pax.module_and_value(rng_seq.next_rng_key)() 87 | rng_seq, r4 = pax.module_and_value(rng_seq.next_rng_key)() 88 | assert r3.tolist() != r4.tolist() 89 | h3 = rng_seq._rng_key 90 | assert h2.tolist() != h3.tolist(), "update internal state even in `eval` mode" 91 | 92 | 93 | def test_ema_debias(): 94 | ema = EMA(jnp.array(1.0), 0.9, True) 95 | assert ema.debias.item() == False 96 | assert ema.averages.item() == 1.0 97 | 98 | ema, _ = pax.purecall(ema, jnp.array(2.0)) 99 | assert ema.averages.item() == 2.0 100 | assert ema.debias.item() == True 101 | 102 | ema, _ = pax.purecall(ema, jnp.array(1.0)) 103 | np.testing.assert_almost_equal(ema.averages.item(), 0.9 * 2.0 + 0.1 * 1.0) 104 | 105 | 106 | def test_ema_bias(): 107 | ema = EMA(jnp.array(1.0), 0.9, False) 108 | assert ema.debias is None 109 | assert ema.averages.item() == 1.0 110 | 111 | ema, _ = pax.purecall(ema, jnp.array(2.0)) 112 | np.testing.assert_almost_equal(ema.averages.item(), 0.1 * 2.0 + 0.9 * 1.0) 113 | 114 | 115 | def test_scan_fn_not_time_major(): 116 | def loop(prev_state, x): 117 | 
next_state = prev_state + x 118 | return next_state, next_state 119 | 120 | h0 = jnp.zeros((1,)) 121 | xs = jnp.arange(0, 10).reshape((1, -1)) 122 | _, ys = pax.scan(loop, h0, xs, time_major=False) 123 | assert ys[0, -1].item() == 45 124 | 125 | 126 | def test_scan_fn_not_time_major_pytree(): 127 | def loop(prev_state, x): 128 | next_state = prev_state + x[0] + x[1] 129 | return next_state, (next_state, next_state) 130 | 131 | h0 = jnp.zeros((1,)) 132 | xs = jnp.arange(0, 10).reshape((1, -1)) 133 | _, (ys1, ys2) = pax.scan(loop, h0, (xs, xs), time_major=False) 134 | assert ys1[0, -1].item() == 90 135 | 136 | 137 | def test_scan_fn_time_major(): 138 | def loop(prev_state, x): 139 | next_state = prev_state + x 140 | return next_state, next_state 141 | 142 | h0 = jnp.zeros((1,)) 143 | xs = jnp.arange(0, 10).reshape((-1, 1)) 144 | _, ys = pax.scan(loop, h0, xs, time_major=True) 145 | assert ys[-1, 0].item() == 45 146 | --------------------------------------------------------------------------------