├── .github
│   └── workflows
│       ├── ci.yml
│       ├── docs.yml
│       └── pypi.yml
├── .gitignore
├── .pylintrc
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── api
│   │   ├── basics.rst
│   │   ├── experimental.rst
│   │   ├── modules.rst
│   │   ├── transformations.rst
│   │   └── utilities.rst
│   ├── conf.py
│   ├── index.rst
│   ├── notebooks
│   │   ├── basics.ipynb
│   │   ├── jax_transformations.ipynb
│   │   ├── operators.ipynb
│   │   ├── performance.ipynb
│   │   ├── training.ipynb
│   │   └── understanding.ipynb
│   └── requirements.txt
├── examples
│   ├── basics.py
│   ├── char_rnn.py
│   ├── dataclass_module.py
│   ├── denoising_diffusion
│   │   ├── README.md
│   │   ├── data_loader.py
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   └── train.py
│   ├── graph_module.py
│   ├── lazy_module.py
│   ├── mnist.py
│   ├── mnist_mixed_precision.py
│   ├── notebooks
│   │   ├── DCGAN.ipynb
│   │   ├── VAE.ipynb
│   │   ├── adversarial_examples.ipynb
│   │   ├── fine_tuning_resnet18.ipynb
│   │   ├── mixed_precision.ipynb
│   │   ├── pretrained_resnet18.py
│   │   └── test_pretrained_resnet18.ipynb
│   ├── transformer
│   │   ├── data.py
│   │   ├── model.py
│   │   └── train.py
│   └── wave_gru
│       ├── README.md
│       ├── data_loader.py
│       ├── model.py
│       ├── prepare_data.sh
│       ├── requirements.txt
│       └── train.py
├── images
│   └── pax_logo.png
├── pax
│   ├── __init__.py
│   ├── _src
│   │   ├── __init__.py
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── graph_module.py
│   │   │   ├── mixed_precision.py
│   │   │   ├── module.py
│   │   │   ├── module_and_value.py
│   │   │   ├── mutable.py
│   │   │   ├── pure.py
│   │   │   ├── rng.py
│   │   │   ├── safe_module.py
│   │   │   ├── threading_local.py
│   │   │   ├── transforms.py
│   │   │   ├── utility_modules.py
│   │   │   └── utils.py
│   │   ├── nets
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── transformer.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── batch_norm.py
│   │   │   ├── conv.py
│   │   │   ├── dropout.py
│   │   │   ├── ema.py
│   │   │   ├── embed.py
│   │   │   ├── group_norm.py
│   │   │   ├── identity.py
│   │   │   ├── lambda_module.py
│   │   │   ├── layer_norm.py
│   │   │   ├── linear.py
│   │   │   ├── pool.py
│   │   │   ├── recurrent.py
│   │   │   ├── rng_seq.py
│   │   │   └── sequential.py
│   │   └── utils.py
│   ├── experimental
│   │   ├── __init__.py
│   │   └── graph.py
│   ├── nets.py
│   ├── py.typed
│   └── utils.py
├── setup.py
└── tests
    ├── test_auto_modules.py
    ├── test_counter.py
    ├── test_deepscan.py
    ├── test_finetune.py
    ├── test_freeze_unfreeze.py
    ├── test_graph_module.py
    ├── test_immutability.py
    ├── test_jax_transform.py
    ├── test_mixed_precision.py
    ├── test_multithread.py
    ├── test_nets.py
    ├── test_nn.py
    ├── test_optim.py
    ├── test_pax.py
    ├── test_performance.py
    ├── test_pure.py
    ├── test_summary.py
    ├── test_training.py
    ├── test_transforms.py
    └── test_utils.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test-ubuntu:
13 | name: "Test on ${{ matrix.python-version }} on ${{ matrix.os }}"
14 | runs-on: "${{ matrix.os }}"
15 | strategy:
16 | matrix:
17 | python-version: [3.7, 3.8, 3.9]
18 | os: [ubuntu-latest]
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v1
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install .[test]
29 | - name: Test with pytest
30 | run: |
31 | pip install pytest pytest-xdist
32 | pytest -n auto -k "not perf" tests
33 | # pytest -n 1 -k "perf" tests
34 | - name: Test with pytype
35 | run: |
36 | pip install pytype
37 | pytype pax tests
38 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | # Source: https://raw.githubusercontent.com/deepmind/dm-haiku/0a28e731938ef932ed6c33555fb1051bea0b29bd/.github/workflows/docs.yml
2 | # Apache-2.0 License
3 |
4 | name: docs
5 |
6 | on:
7 | pull_request:
8 | branches:
9 | - main
10 | push:
11 | branches:
12 | - main
13 |
14 | jobs:
15 | test-ubuntu:
16 | name: "docs on ${{ matrix.python-version }} on ${{ matrix.os }}"
17 | runs-on: "${{ matrix.os }}"
18 | strategy:
19 | matrix:
20 | python-version: [3.7, 3.8, 3.9]
21 | os: [ubuntu-latest]
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python ${{ matrix.python-version }}
25 | uses: actions/setup-python@v1
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 | - name: Install dependencies
29 | run: |
30 | sudo apt install -y pandoc
31 | python -m pip install --upgrade pip
32 | pip install .[test]
33 | pip install -r docs/requirements.txt
34 | - name: Test doctests
35 | run: |
36 | cd docs
37 | make doctest
38 | - name: Test docs to HTML
39 | run: |
40 | cd docs
41 | make html
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python
13 | uses: actions/setup-python@v1
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Build and publish
21 | env:
22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
24 | run: |
25 | python setup.py sdist bdist_wheel
26 | twine upload dist/*
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | version: 2
5 |
6 | sphinx:
7 | builder: html
8 | configuration: docs/conf.py
9 | fail_on_warning: false
10 |
11 | python:
12 | version: 3.7
13 | install:
14 | - requirements: docs/requirements.txt
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Thông Nguyễn
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 | [**Introduction**](#introduction)
6 | | [**Getting started**](#gettingstarted)
7 | | [**Functional programming**](#functional)
8 | | [**Examples**](https://github.com/ntt123/pax/tree/main/examples/)
9 | | [**Modules**](#modules)
10 | | [**Fine-tuning**](#finetune)
11 |
12 | 
13 | 
14 | 
15 |
16 |
17 | ## Introduction
18 |
19 | PAX is a [JAX]-based library for training neural networks.
20 |
21 | PAX modules are registered as JAX [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html); therefore, they can be inputs or outputs of JAX transformations such as `jax.jit`, `jax.grad`, etc. This makes programming with modules very convenient and easy to understand.
22 |
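For instance, modules can be passed directly to jitted functions (a minimal sketch, using the built-in `pax.Linear` module described in the Modules section below):

```python
import jax
import jax.numpy as jnp
import pax

fc = pax.Linear(2, 2)

# a module is a pytree, so it can be an argument of a jitted function
@jax.jit
def forward(module, x):
    return module(x)

y = forward(fc, jnp.ones((3, 2)))
print(y.shape)  # (3, 2)
```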
23 | ## Installation
24 |
25 | Install from PyPI:
26 |
27 | ```bash
28 | pip install pax3
29 | ```
30 |
31 | Or install the latest version from Github:
32 |
33 | ```bash
34 | pip install git+https://github.com/ntt123/pax.git
35 |
36 | ## or install with test extras to run the tests and examples
37 | pip install git+https://github.com/ntt123/pax.git#egg=pax3[test]
38 | ```
39 |
40 |
41 | ## Getting started
42 |
43 |
44 | Below is a simple example of a `Linear` module.
45 |
46 | ```python
47 | import jax.numpy as jnp
48 | import pax
49 |
50 | class Linear(pax.Module):
51 | weight: jnp.ndarray
52 | bias: jnp.ndarray
53 | parameters = pax.parameters_method("weight", "bias")
54 |
55 | def __init__(self):
56 | super().__init__()
57 | self.weight = jnp.array(0.0)
58 | self.bias = jnp.array(0.0)
59 |
60 | def __call__(self, x):
61 | return self.weight * x + self.bias
62 | ```
63 |
64 | The implementation is very similar to a normal Python class. However, we need an additional line
65 |
66 | ```python
67 | parameters = pax.parameters_method("weight", "bias")
68 | ```
69 |
70 | to declare that `weight` and `bias` are *trainable parameters* of the Linear module.
71 |
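As a quick check, the module can be used like a normal callable (a small sketch; the printed value follows from the zero initialization above):

```python
fc = Linear()
y = fc(jnp.array(3.0))
print(y)  # 0.0, because weight and bias are initialized to 0.0
```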
72 | ## PAX functional programming
73 |
74 | ### `pax.pure`
75 |
76 | A PAX module can have internal states. For example, below is a simple `Counter` module with an internal counter.
77 |
78 | ```python
79 | class Counter(pax.Module):
80 | count : jnp.ndarray
81 |
82 | def __init__(self):
83 | super().__init__()
84 | self.count = jnp.array(0)
85 |
86 | def __call__(self):
87 | self.count = self.count + 1
88 | return self.count
89 | ```
90 |
91 | However, PAX *aims* to guarantee that modules have no side effects from an outside point of view.
92 | Therefore, modifications of these internal states are restricted. For example, we get an error when calling a `Counter` instance directly.
93 |
94 | ```python
95 | counter = Counter()
96 | count = counter()
97 | # ...
98 | # ----> 9 self.count = self.count + 1
99 | # ...
100 | # ValueError: Cannot modify a module in immutable mode.
101 | # Please do this computation inside a function decorated by `pax.pure`.
102 | ```
103 |
104 | Only functions decorated by `pax.pure` are allowed to modify the internal states of their input modules.
105 |
106 | ```python
107 | @pax.pure
108 | def update_counter(counter: Counter):
109 | count = counter()
110 | return counter, count
111 |
112 | counter, count = update_counter(counter)
113 | print(counter.count, count)
114 | # 1 1
115 | ```
116 |
117 | Note that we have to return `counter` in the output of `update_counter`; otherwise, the `counter` object will not be updated. This is because `pax.pure` only provides `update_counter` with a copy of the `counter` object.
118 |
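The copy semantics can be made explicit with a small sketch (the printed values follow from the example above): if the updated module is not returned, the original object stays unchanged.

```python
@pax.pure
def increment_only(counter: Counter):
    return counter()  # the updated copy is discarded

count = increment_only(counter)
print(count, counter.count)
# 2 1   (the outside `counter` object is unchanged)
```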
119 |
120 | ### `pax.purecall`
121 |
122 | For convenience, PAX provides the `pax.purecall` function.
123 | It is a shortcut for `pax.pure(lambda f, x: [f, f(x)])`.
124 |
125 | Instead of implementing an `update_counter` function, we can do the same thing with:
126 |
127 | ```python
128 | counter, count = pax.purecall(counter)
129 | print(counter.count, count)
130 | # 2 2
131 | ```
132 |
133 | ### Replacing parts
134 |
135 | PAX provides utility methods to modify a module in a functional way.
136 |
137 | The `replace` method creates a new module with attributes replaced.
138 | For example, to replace `weight` and `bias` of a `pax.Linear` module:
139 |
140 | ```python
141 | fc = pax.Linear(2, 2)
142 | fc = fc.replace(weight=jnp.ones((2,2)), bias=jnp.zeros((2,)))
143 | ```
144 |
145 | The `replace_node` method replaces a pytree node of a module:
146 |
147 | ```python
148 | f = pax.Sequential(
149 | pax.Linear(2, 3),
150 | pax.Linear(3, 4),
151 | )
152 |
153 | f = f.replace_node(f[-1], pax.Linear(3, 5))
154 | print(f.summary())
155 | # Sequential
156 | # ├── Linear(in_dim=2, out_dim=3, with_bias=True)
157 | # └── Linear(in_dim=3, out_dim=5, with_bias=True)
158 | ```
159 |
160 | ## PAX and other libraries
161 |
162 | PAX learns a lot from other libraries:
163 | - PAX borrows the idea that _a module is also a pytree_ from [treex] and [equinox].
164 | - PAX uses the concept of _trainable parameters_ and _non-trainable states_ from [dm-haiku].
165 | - PAX provides methods similar to PyTorch's, such as `model.apply()`, `model.parameters()`, `model.eval()`, etc.
166 | - PAX uses [objax]'s approach of implementing optimizers as modules.
167 | - PAX uses the [jmp] library to support mixed precision.
168 | - And of course, PAX is heavily influenced by [jax]'s functional programming approach.
169 |
170 |
171 | ## Examples
172 |
173 | A good way to learn about ``PAX`` is to look at the examples in the [examples/](./examples) directory.
174 |
175 |
176 |
177 | Click to expand
178 |
179 | | Path | Description |
180 | |----------|-----------------------|
181 | | ``char_rnn.py`` | train an RNN language model on TPU. |
182 | | ``transformer/`` | train a Transformer language model on TPU. |
183 | | ``mnist.py`` | train an image classifier on the `MNIST` dataset. |
184 | | ``notebooks/VAE.ipynb`` | train a variational autoencoder. |
185 | | ``notebooks/DCGAN.ipynb`` | train a DCGAN model on the `Celeb-A` dataset. |
186 | | ``notebooks/fine_tuning_resnet18.ipynb`` | fine-tune a pretrained ResNet18 model on the `cats vs dogs` dataset. |
187 | | ``notebooks/mixed_precision.ipynb`` | train a U-Net image segmentation model with mixed precision. |
188 | | ``mnist_mixed_precision.py`` | train an image classifier with mixed precision. |
189 | | ``wave_gru/`` | train a WaveGRU vocoder: convert mel-spectrograms to waveforms. |
190 | | ``denoising_diffusion/`` | train a denoising diffusion model on the `Celeb-A` dataset. |
191 |
192 |
193 |
194 |
195 |
196 |
197 | ## Modules
198 |
199 | At the moment, PAX includes:
200 |
201 | * ``pax.Embed``,
202 | * ``pax.Linear``,
203 | * ``pax.{GRU, LSTM}``,
204 | * ``pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm}``,
205 | * ``pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose}``,
206 | * ``pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}``.
207 |
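These modules compose naturally, for example with ``pax.Sequential`` (a small sketch mirroring the fine-tuning example below):

```python
net = pax.Sequential(
    pax.Linear(28 * 28, 64),
    jax.nn.relu,
    pax.Dropout(0.2),
    pax.Linear(64, 10),
)
print(net.summary())
```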
208 | ## Optimizers
209 |
210 | PAX's optimizers are implemented in a separate library, [opax](https://github.com/ntt123/opax). The `opax` library supports many common optimizers such as `adam`, `adamw`, `sgd`, and `rmsprop`. Visit opax's GitHub repository for more information.
211 |
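A typical training step with `opax` looks roughly like this (a condensed sketch of `examples/basics.py`, reusing the `Linear` module from the Getting started section):

```python
import jax
import opax

net = Linear()
opt = opax.adam(1e-2).init(net.parameters())  # optimizer state from the trainable parameters

def loss_fn(model, x, y):
    model, y_hat = pax.purecall(model, x)
    return jnp.mean(jnp.square(y_hat - y)), model

@jax.jit
def train_step(model, optimizer, x, y):
    (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    model, optimizer = opax.apply_gradients(model, optimizer, grads)
    return model, optimizer, loss

net, opt, loss = train_step(net, opt, jnp.ones((3,)), jnp.zeros((3,)))
```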
212 |
213 | ## Fine-tuning models
214 |
215 | PAX provides the ``pax.freeze_parameters`` transformation to convert all trainable parameters to non-trainable states.
216 |
217 | ```python
218 | net = pax.Sequential(
219 | pax.Linear(28*28, 64),
220 | jax.nn.relu,
221 | pax.Linear(64, 10),
222 | )
223 |
224 | net = pax.freeze_parameters(net)
225 | net = net.set(-1, pax.Linear(64, 2))
226 | ```
227 |
228 | After this, ``net.parameters()`` will only return the trainable parameters of the last layer.
229 |
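Consequently, an optimizer initialized from ``net.parameters()`` will only update that layer (a short sketch using [opax]):

```python
opt = opax.adam(1e-4).init(net.parameters())  # only the last Linear layer is optimized
```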
230 |
231 | [jax]: https://github.com/google/jax
232 | [objax]: https://github.com/google/objax
233 | [dm-haiku]: https://github.com/deepmind/dm-haiku
234 | [optax]: https://github.com/deepmind/optax
235 | [jmp]: https://github.com/deepmind/jmp
236 | [pytorch]: https://github.com/pytorch/pytorch
237 | [treex]: https://github.com/cgarciae/treex
238 | [equinox]: https://github.com/patrick-kidger/equinox
239 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/api/basics.rst:
--------------------------------------------------------------------------------
1 | PAX Basics
2 | ==========
3 |
4 | .. currentmodule:: pax
5 |
6 | .. autosummary::
7 | Module
8 | EmptyNode
9 | pure
10 | purecall
11 | seed_rng_key
12 | next_rng_key
13 |
14 |
15 |
16 | PAX's Module
17 | ------------
18 |
19 | .. currentmodule:: pax
20 |
21 | .. autoclass:: Module
22 | :members:
23 | __init__,
24 | parameters,
25 | training,
26 | train,
27 | eval,
28 | update_parameters,
29 | replace,
30 | replace_node,
31 | summary,
32 | apply,
33 | state_dict,
34 | load_state_dict,
35 | __or__,
36 | __mod__
37 |
38 |
39 |
40 | .. autoclass:: ParameterModule
41 | :members:
42 |
43 |
44 | .. autoclass:: StateModule
45 | :members:
46 |
47 |
48 | .. autoclass:: EmptyNode
49 | :members:
50 |
51 |
52 | Purify functions and methods
53 | ----------------------------
54 |
55 | .. currentmodule:: pax
56 |
57 | .. autofunction:: pure
58 |
59 | .. autofunction:: purecall
60 |
61 |
62 | Random Number Generator
63 | -----------------------
64 |
65 | .. autosummary::
66 |
67 | seed_rng_key
68 | next_rng_key
69 |
70 |
71 | seed_rng_key
72 | ~~~~~~~~~~~~
73 |
74 | .. autofunction:: seed_rng_key
75 |
76 |
77 | next_rng_key
78 | ~~~~~~~~~~~~
79 |
80 | .. autofunction:: next_rng_key
81 |
--------------------------------------------------------------------------------
/docs/api/experimental.rst:
--------------------------------------------------------------------------------
1 | Experimental
2 | ============
3 |
4 | .. currentmodule:: pax.experimental
5 |
6 |
7 | .. autosummary::
8 | mutable
9 | Flattener
10 | LazyModule
11 | graph.build_graph_module
12 | default_mp_policy
13 | apply_scaled_gradients
14 | save_weights_to_dict
15 | load_weights_from_dict
16 |
17 |
18 | Mutable
19 | -------
20 |
21 | .. autofunction:: mutable
22 |
23 |
24 | Flattener
25 | ---------
26 |
27 | .. autoclass:: Flattener
28 | :members:
29 |
30 |
31 | Graph API
32 | ---------
33 |
34 | .. currentmodule:: pax.experimental.graph
35 |
36 | .. autoclass:: Node
37 | :members:
38 |
39 | .. autoclass:: InputNode
40 | :members:
41 |
42 | .. autoclass:: GraphModule
43 | :members:
44 |
45 | .. autofunction:: build_graph_module
46 |
47 |
48 | Lazy Module
49 | -----------
50 |
51 | .. currentmodule:: pax.experimental
52 |
53 | .. autoclass:: LazyModule
54 | :members:
55 |
56 |
57 | Mixed Precision
58 | ---------------
59 |
60 | .. currentmodule:: pax.experimental
61 |
62 | .. autofunction:: default_mp_policy
63 | .. autofunction:: apply_scaled_gradients
64 |
65 |
66 | Save and load weights
67 | ---------------------
68 |
69 | .. currentmodule:: pax.experimental
70 |
71 | .. autofunction:: save_weights_to_dict
72 | .. autofunction:: load_weights_from_dict
73 |
--------------------------------------------------------------------------------
/docs/api/modules.rst:
--------------------------------------------------------------------------------
1 | Common Modules
2 | ==============
3 |
4 | .. currentmodule:: pax
5 |
6 | .. autosummary::
7 | Linear
8 | Conv1D
9 | Conv2D
10 | Conv1DTranspose
11 | Conv2DTranspose
12 | BatchNorm1D
13 | BatchNorm2D
14 | LayerNorm
15 | GroupNorm
16 | Sequential
17 | VanillaRNN
18 | LSTM
19 | GRU
20 | MultiHeadAttention
21 | Identity
22 | avg_pool
23 | max_pool
24 |
25 |
26 |
27 |
28 | Linear
29 | ------
30 |
31 |
32 | .. autoclass:: Linear
33 | :members:
34 |
35 |
36 | Dropout
37 | -------
38 |
39 | .. autoclass:: Dropout
40 | :members:
41 |
42 |
43 | Embed
44 | -----
45 |
46 | .. autoclass:: Embed
47 | :members:
48 |
49 |
50 | Convolution
51 | -----------
52 |
53 | Conv1D
54 | ~~~~~~
55 |
56 | .. autoclass:: Conv1D
57 | :members:
58 |
59 | Conv2D
60 | ~~~~~~
61 |
62 | .. autoclass:: Conv2D
63 | :members:
64 |
65 | Conv1DTranspose
66 | ~~~~~~~~~~~~~~~
67 |
68 | .. autoclass:: Conv1DTranspose
69 | :members:
70 |
71 | Conv2DTranspose
72 | ~~~~~~~~~~~~~~~
73 |
74 | .. autoclass:: Conv2DTranspose
75 | :members:
76 |
77 |
78 | Normalization
79 | -------------
80 |
81 |
82 | BatchNorm1D
83 | ~~~~~~~~~~~
84 |
85 | .. autoclass:: BatchNorm1D
86 | :members:
87 |
88 | BatchNorm2D
89 | ~~~~~~~~~~~
90 |
91 | .. autoclass:: BatchNorm2D
92 | :members:
93 |
94 |
95 |
96 | LayerNorm
97 | ~~~~~~~~~
98 |
99 |
100 | .. autoclass:: LayerNorm
101 | :members:
102 |
103 |
104 | GroupNorm
105 | ~~~~~~~~~
106 |
107 |
108 | .. autoclass:: GroupNorm
109 | :members:
110 |
111 |
112 |
113 | Recurrent
114 | ---------
115 |
116 |
117 | VanillaRNN
118 | ~~~~~~~~~~
119 |
120 | .. autoclass:: VanillaRNN
121 | :members:
122 |
123 |
124 | LSTM
125 | ~~~~
126 |
127 | .. autoclass:: LSTM
128 | :members:
129 |
130 |
131 | GRU
132 | ~~~
133 |
134 | .. autoclass:: GRU
135 | :members:
136 |
137 |
138 | Pool
139 | ----
140 |
141 | avg_pool
142 | ~~~~~~~~
143 |
144 | .. autofunction:: avg_pool
145 |
146 |
147 | max_pool
148 | ~~~~~~~~
149 |
150 | .. autofunction:: max_pool
151 |
152 |
153 |
154 |
155 | MultiHeadAttention
156 | ------------------
157 |
158 | .. autoclass:: MultiHeadAttention
159 | :members:
160 |
161 |
162 | Utilities
163 | ---------
164 |
165 | Sequential
166 | ~~~~~~~~~~
167 |
168 | .. autoclass:: Sequential
169 | :members:
170 |
171 |
172 | RngSeq
173 | ~~~~~~
174 |
175 | .. autoclass:: RngSeq
176 | :members:
177 |
178 |
179 | Lambda
180 | ~~~~~~
181 |
182 | .. autoclass:: Lambda
183 |
184 |
185 | Identity
186 | ~~~~~~~~
187 |
188 | .. autoclass:: Identity
189 | :members:
190 |
191 | EMA
192 | ~~~
193 |
194 | .. autoclass:: EMA
195 | :members:
196 |
--------------------------------------------------------------------------------
/docs/api/transformations.rst:
--------------------------------------------------------------------------------
1 | Module Transformations
2 | ======================
3 |
4 | .. currentmodule:: pax
5 |
6 | A module transformation is a pure function that takes PAX modules as inputs and returns PAX modules as outputs.
7 |
8 | .. autosummary::
9 |
10 | update_parameters
11 | enable_train_mode
12 | enable_eval_mode
13 | select_parameters
14 | freeze_parameters
15 | unfreeze_parameters
16 | apply_mp_policy
17 | unwrap_mp_policy
18 |
19 | update_parameters
20 | -----------------
21 |
22 | .. autofunction:: update_parameters
23 |
24 |
25 | enable_train_mode
26 | -----------------
27 |
28 | .. autofunction:: enable_train_mode
29 |
30 |
31 | enable_eval_mode
32 | ----------------
33 |
34 | .. autofunction:: enable_eval_mode
35 |
36 |
37 | select_parameters
38 | -----------------
39 |
40 | .. autofunction:: select_parameters
41 |
42 |
43 | freeze_parameters
44 | -----------------
45 |
46 | .. autofunction:: freeze_parameters
47 |
48 |
49 | unfreeze_parameters
50 | -------------------
51 |
52 | .. autofunction:: unfreeze_parameters
53 |
54 |
55 | apply_mp_policy
56 | ---------------
57 |
58 | .. autofunction:: apply_mp_policy
59 |
60 |
61 | unwrap_mp_policy
62 | ----------------
63 |
64 | .. autofunction:: unwrap_mp_policy
65 |
--------------------------------------------------------------------------------
/docs/api/utilities.rst:
--------------------------------------------------------------------------------
1 | Utilities
2 | =========
3 |
4 | .. currentmodule:: pax
5 |
6 |
7 | .. autosummary::
8 | parameters_method
9 | grad
10 | value_and_grad
11 | scan
12 | build_update_fn
13 |
14 |
15 | parameters_method
16 | -----------------
17 |
18 | .. autofunction:: parameters_method
19 |
20 |
21 | grad
22 | ----
23 |
24 | .. autofunction:: grad
25 |
26 |
27 | value_and_grad
28 | --------------
29 |
30 | .. autofunction:: value_and_grad
31 |
32 |
33 | scan
34 | ----
35 |
36 | .. autofunction:: scan
37 |
38 |
39 | build_update_fn
40 | ---------------
41 |
42 | .. autofunction:: build_update_fn
43 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # This file is an adaptation from
2 | # https://raw.githubusercontent.com/deepmind/dm-haiku/main/docs/conf.py
3 | # which is under Apache License, Version 2.0.
4 |
5 |
6 | # Configuration file for the Sphinx documentation builder.
7 | #
8 | # This file only contains a selection of the most common options. For a full
9 | # list see the documentation:
10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
11 |
12 | # -- Path setup --------------------------------------------------------------
13 |
14 | # If extensions (or modules to document with autodoc) are in another directory,
15 | # add these directories to sys.path here. If the directory is relative to the
16 | # documentation root, use os.path.abspath to make it absolute, like shown here.
17 | #
18 | import doctest
19 | import inspect
20 | import os
21 | import sys
22 |
23 | sys.path.insert(0, os.path.abspath(".."))
24 |
25 | import pax
26 | import sphinxcontrib.katex as katex
27 |
28 | # -- Project information -----------------------------------------------------
29 |
30 | project = "PAX"
31 | copyright = "2021, Thông Nguyễn"
32 | author = "Thông Nguyễn"
33 |
34 |
35 | # -- General configuration ---------------------------------------------------
36 | master_doc = "index"
37 |
38 |
39 | # Add any Sphinx extension module names here, as strings. They can be
40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
41 | # ones.
42 | extensions = [
43 | "sphinx.ext.autodoc",
44 | "sphinx.ext.autosummary",
45 | "sphinx.ext.doctest",
46 | "sphinx.ext.inheritance_diagram",
47 | "sphinx.ext.intersphinx",
48 | "sphinx.ext.linkcode",
49 | "sphinx.ext.napoleon",
50 | "sphinxcontrib.bibtex",
51 | "sphinxcontrib.katex",
52 | "sphinx_autodoc_typehints",
53 | "nbsphinx",
54 | "IPython.sphinxext.ipython_console_highlighting",
55 | ]
56 |
57 |
58 | # Add any paths that contain templates here, relative to this directory.
59 | templates_path = ["_templates"]
60 |
61 | # List of patterns, relative to source directory, that match files and
62 | # directories to ignore when looking for source files.
63 | # This pattern also affects html_static_path and html_extra_path.
64 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
65 |
66 |
67 | # -- Options for autodoc -----------------------------------------------------
68 |
69 | autodoc_default_options = {
70 | "member-order": "bysource",
71 | "special-members": True,
72 | "exclude-members": "__repr__, __str__, __weakref__",
73 | }
74 |
75 |
76 | # -- Options for HTML output -------------------------------------------------
77 |
78 |
79 | # The theme to use for HTML and HTML Help pages. See the documentation for
80 | # a list of builtin themes.
81 | #
82 | html_theme = "sphinx_rtd_theme"
83 |
84 | # Add any paths that contain custom static files (such as style sheets) here,
85 | # relative to this directory. They are copied after the builtin static files,
86 | # so a file named "default.css" will overwrite the builtin "default.css".
87 | # html_static_path = ["_static"]
88 |
89 |
90 | # -- Options for doctest -----------------------------------------------------
91 |
92 | doctest_test_doctest_blocks = "true"
93 | doctest_global_setup = """
94 | import jax
95 | import jax.numpy as jnp
96 | import pax
97 | import opax
98 | pax.seed_rng_key(42)
99 | """
100 | doctest_default_flags = (
101 | doctest.ELLIPSIS
102 | | doctest.IGNORE_EXCEPTION_DETAIL
103 | | doctest.DONT_ACCEPT_TRUE_FOR_1
104 | | doctest.NORMALIZE_WHITESPACE
105 | )
106 |
107 |
108 | # -- Options for katex ------------------------------------------------------
109 |
110 | # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
111 | latex_macros = r"""
112 | \def \d #1{\operatorname{#1}}
113 | """
114 |
115 | # Translate LaTeX macros to KaTeX and add to options for HTML builder
116 | katex_macros = katex.latex_defs_to_katex_macros(latex_macros)
117 | katex_options = "macros: {" + katex_macros + "}"
118 |
119 | # Add LaTeX macros for LATEX builder
120 | latex_elements = {"preamble": latex_macros}
121 |
122 |
123 | # -- Source code links -------------------------------------------------------
124 |
125 |
126 | def linkcode_resolve(domain, info):
127 | """Resolve a GitHub URL corresponding to Python object."""
128 | if domain != "py":
129 | return None
130 |
131 | try:
132 | mod = sys.modules[info["module"]]
133 | except ImportError:
134 | return None
135 |
136 | obj = mod
137 | try:
138 | for attr in info["fullname"].split("."):
139 | obj = getattr(obj, attr)
140 | except AttributeError:
141 | return None
142 | else:
143 | obj = inspect.unwrap(obj)
144 |
145 | try:
146 | filename = inspect.getsourcefile(obj)
147 | except TypeError:
148 | return None
149 |
150 | try:
151 | source, lineno = inspect.getsourcelines(obj)
152 | except OSError:
153 | return None
154 |
155 | return "https://github.com/ntt123/pax/blob/main/pax/%s#L%d#L%d" % (
156 | os.path.relpath(filename, start=os.path.dirname(pax.__file__)),
157 | lineno,
158 | lineno + len(source) - 1,
159 | )
160 |
161 |
162 | # -- nbsphinx configuration --------------------------------------------------
163 |
164 | nbsphinx_execute = "never"
165 | nbsphinx_codecell_lexer = "ipython"
166 | nbsphinx_kernel_name = "python"
167 | nbsphinx_timeout = 180
168 | nbsphinx_prolog = r"""
169 | {% set docname = 'docs/' + env.doc2path(env.docname, base=None) %}
170 |
171 | .. only:: html
172 |
173 | .. role:: raw-html(raw)
174 | :format: html
175 |
176 | .. nbinfo::
177 |
178 | Interactive online version:
179 | :raw-html:`
`
180 | """
181 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/ntt123/pax/tree/main/docs
2 |
3 |
4 | .. PAX documentation master file, created by
5 | sphinx-quickstart on Fri Sep 3 01:09:13 2021.
6 | You can adapt this file completely to your liking, but it should at least
7 | contain the root `toctree` directive.
8 |
9 | PAX documentation
10 | =================
11 |
12 | PAX is a stateful pytree library for training neural networks with JAX. It is designed to be simple
13 | and easy to use while preserving the benefits of JAX.
14 |
15 |
16 | Installation
17 | ------------
18 |
19 | To install the latest version::
20 |
21 | pip install git+https://github.com/ntt123/pax.git
22 |
23 |
24 | .. toctree::
25 | :caption: Guides
26 | :maxdepth: 1
27 |
28 | notebooks/basics
29 | notebooks/training
30 | notebooks/operators
31 | notebooks/understanding
32 | notebooks/jax_transformations
33 | notebooks/performance
34 |
35 |
36 | .. toctree::
37 | :caption: API Documentation
38 | :maxdepth: 1
39 |
40 | api/basics
41 | api/modules
42 | api/transformations
43 | api/utilities
44 | api/experimental
45 |
46 |
47 |
48 |
49 | PAX is licensed under the MIT License.
50 |
51 | Indices
52 | =======
53 |
54 | * :ref:`genindex`
55 |
--------------------------------------------------------------------------------
/docs/notebooks/operators.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Operators\n",
8 | "\n",
9 | "There are a few operators that help to clean up the implementation.\n",
10 | "\n",
11 | "\n",
12 | "\n",
13 | "| Text | Operator |\n",
14 | "| ----------- | ----------- |\n",
15 | "| `mod, z = pax.purecall(mod, x, y)` | `mod, z = mod % (x, y)` |\n",
16 | "| `mod.parameters()` | `~mod` |\n",
17 | "| `pax.update_pytree(mod1, mod2)` | `mod1 | mod2` |\n",
18 | "| `mod1.update_parameters(mod2)` | `mod1 | ~mod2` |\n",
19 | "| `f = pax.Sequential(mod1, mod2)` | `f = pax.Sequential() >> mod1 >> mod2` |\n",
20 | "\n",
21 | "\n",
22 | "\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": []
29 | }
30 | ],
31 | "metadata": {
32 | "interpreter": {
33 | "hash": "4f946df053fbf2b937619d3c5458e7af74262f9a954d8797ba0b27400bcafe06"
34 | },
35 | "kernelspec": {
36 | "display_name": "Python 3.8.6 64-bit",
37 | "name": "python3"
38 | },
39 | "language_info": {
40 | "codemirror_mode": {
41 | "name": "ipython",
42 | "version": 3
43 | },
44 | "file_extension": ".py",
45 | "mimetype": "text/x-python",
46 | "name": "python",
47 | "nbconvert_exporter": "python",
48 | "pygments_lexer": "ipython3",
49 | "version": "3.8.6"
50 | },
51 | "orig_nbformat": 4
52 | },
53 | "nbformat": 4,
54 | "nbformat_minor": 2
55 | }
56 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | docutils==0.16
2 | ipykernel==5.3.4
3 | ipython==7.16.3
4 | Jinja2==2.11.3
5 | jq==1.1.1
6 | markupsafe==2.0.1
7 | matplotlib==3.3.3
8 | nbsphinx==0.8.0
9 | pandoc==1.0.2
10 | pygments==2.7.4
11 | seaborn==0.11.1
12 | sphinx_rtd_theme==0.5.0
13 | sphinx-autodoc-typehints==1.11.1
14 | sphinx==3.3.0
15 | sphinxcontrib-bibtex==1.0.0
16 | sphinxcontrib-katex==0.7.1
17 |
18 |
19 | # pax requirements
20 | jax
21 | jaxlib
22 | jmp
23 | numpy
--------------------------------------------------------------------------------
/examples/basics.py:
--------------------------------------------------------------------------------
1 | """PAX basic stuffs."""
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import opax
6 | import pax
7 | from opax import GradientTransformation
8 |
9 |
10 | class Linear(pax.Module):
11 | """A linear module with counter."""
12 |
13 | weight: jnp.ndarray
14 | bias: jnp.ndarray
15 | counter: jnp.ndarray
16 | parameters = pax.parameters_method("weight", "bias")
17 |
18 | def __init__(self):
19 | super().__init__()
20 | self.weight = jax.random.normal(pax.next_rng_key(), (1,))
21 | self.bias = jax.random.normal(pax.next_rng_key(), (1,))
22 | self.counter = jnp.array(0)
23 |
24 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
25 | self.counter = self.counter + 1
26 | x = self.weight * x + self.bias
27 | return x
28 |
29 |
30 | def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray):
31 | model, y_hat = pax.purecall(model, x)
32 | loss = jnp.mean(jnp.square(y_hat - y))
33 | return loss, model
34 |
35 |
36 | @jax.jit
37 | def train_step(model: Linear, optimizer: GradientTransformation, x, y):
38 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y)
39 | model, optimizer = opax.apply_gradients(model, optimizer, grads)
40 | return model, optimizer, loss
41 |
42 |
43 | def main():
44 | # random seed
45 | pax.seed_rng_key(42)
46 |
47 | # model & optimizer
48 | net = Linear()
49 | print(net.summary())
50 | opt = opax.adam(1e-1).init(net.parameters())
51 |
52 | # data
53 | x = jax.random.normal(pax.next_rng_key(), (32, 1))
54 | y = jax.random.normal(pax.next_rng_key(), (32, 1))
55 |
56 | # training loop
57 | for _ in range(10):
58 | net, opt, loss = train_step(net, opt, x, y)
59 | print(f"step {net.counter:>2} loss {loss:.3f}")
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/examples/char_rnn.py:
--------------------------------------------------------------------------------
1 | """Train a rnn language model on TPU (if available)."""
2 |
3 | import inspect
4 | import os
5 | from functools import partial
6 | from typing import List, Tuple
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | import jax.tools.colab_tpu
11 | import opax
12 | import pax
13 | import tensorflow as tf
14 | from tqdm.auto import tqdm
15 |
16 | pax.seed_rng_key(42)
17 |
18 |
19 | def setup_tpu_device():
20 | print("Setting up TPU cores")
21 | jax.tools.colab_tpu.setup_tpu()
22 | print(jax.devices())
23 |
24 |
25 | if "COLAB_TPU_ADDR" in os.environ:
26 | # TPU config
27 | setup_tpu_device()
28 | steps_per_update = 50
29 | num_devices = jax.device_count()
30 | batch_size = 32 * num_devices * steps_per_update
31 | seq_len = 128
32 | vocab_size = 256
33 | hidden_dim = 512
34 | num_steps = 50_000
35 | else:
36 | # CPU/GPU config
37 | steps_per_update = 1
38 | num_devices = jax.device_count()
39 | batch_size = 1 * num_devices * steps_per_update
40 | seq_len = 64
41 | vocab_size = 256
42 | hidden_dim = 256
43 | num_steps = 20_000
44 |
45 |
46 | class LM(pax.Module):
47 | """A RNN language model."""
48 |
49 | lstm: pax.Module
50 | embed: pax.Module
51 | output: pax.Module
52 |
53 | vocab_size: int
54 | hidden_dim: int
55 |
56 | def __init__(self, vocab_size: int, hidden_dim: int):
57 | """
58 | Arguments:
59 | vocab_size: int, size of the alphabet.
60 | hidden_dim: int, number of LSTM cells.
61 | """
62 | super().__init__()
63 | self.vocab_size = vocab_size
64 | self.hidden_dim = hidden_dim
65 | self.embed = pax.Embed(vocab_size, hidden_dim)
66 | self.lstm = pax.LSTM(hidden_dim, hidden_dim)
67 | self.output = pax.Linear(hidden_dim, vocab_size)
68 |
69 | def __call__(self, x):
70 | x = self.embed(x)
71 | hx, x = pax.scan(
72 | self.lstm,
73 | self.lstm.initial_state(x.shape[0]),
74 | x,
75 | time_major=False,
76 | )
77 | del hx
78 | logits = self.output(x)
79 | return logits
80 |
81 | def inference(self, prompt: List[int] = [], length=32):
82 | hx = self.lstm.initial_state(1)
83 | if len(prompt) == 0:
84 | prompt = [0]
85 |
86 | x = jnp.array([prompt[0]], dtype=jnp.int32)
87 |
88 | total_len = len(prompt) + length
89 |
90 | out = [x]
91 |
92 | @jax.jit
93 | def step(x, hx):
94 | x = self.embed(x)
95 | hx, x = self.lstm(hx, x)
96 | logits = self.output(x)
97 | return logits, hx
98 |
99 | for i in range(1, total_len):
100 | logits, hx = step(x, hx)
101 | if i >= len(prompt):
102 | x = jnp.argmax(logits, axis=-1)
103 | else:
104 | x = jnp.array([prompt[i]], dtype=jnp.int32)
105 | out.append(x)
106 | return jnp.concatenate(out)
107 |
108 |
109 | def loss_fn(model: LM, batch: jnp.ndarray):
110 | inputs = batch[:, :-1]
111 | targets = batch[:, 1:]
112 |
113 | logits = model(inputs)
114 | log_pr = jax.nn.log_softmax(logits, axis=-1)
115 | targets = jax.nn.one_hot(targets, num_classes=model.vocab_size)
116 | loss = -jnp.mean(jnp.sum(targets * log_pr, axis=-1))
117 | return loss
118 |
119 |
120 | def update_step(model_and_optimizer: Tuple[LM, pax.Module], batch: jnp.ndarray):
121 | model, optimizer = model_and_optimizer
122 | loss, grads = jax.value_and_grad(loss_fn)(model, batch)
123 | grads = jax.lax.pmean(grads, axis_name="i")
124 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
125 | return (model, optimizer), loss
126 |
127 |
128 | @partial(jax.pmap, axis_name="i")
129 | def update_fn(model, optimizer, multi_batch: jnp.ndarray):
130 | (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch)
131 | return model, optimizer, jnp.mean(losses)
132 |
133 |
134 | net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim)
135 |
136 | optimizer = opax.chain(
137 | opax.clip_by_global_norm(1.0),
138 | opax.adam(1e-4),
139 | ).init(net.parameters())
140 |
141 | # replicate on multiple devices
142 | net = jax.device_put_replicated(net, jax.devices())
143 | print(net.summary())
144 | optimizer = jax.device_put_replicated(optimizer, jax.devices())
145 |
146 |
147 | def tokenize(text):
148 | t = [0] + [ord(c) for c in text] # ASCII, 0 is the [START] token
149 | return t
150 |
151 |
152 | def detokenize(tokens):
153 | text = [chr(t) if t != 0 else "[START]" for t in tokens]
154 | return "".join(text)
155 |
156 |
157 | data = inspect.getsource(LM) # a _true_ AGI learns about itself.
158 | data_token = tokenize(data)
159 | test_prompt = "class LM(pax.Module):"
160 |
161 | tfdata = (
162 | tf.data.Dataset.from_tensors(data_token)
163 | .repeat()
164 | .map(
165 | lambda x: tf.image.random_crop(x, [seq_len + 1]),
166 | num_parallel_calls=tf.data.AUTOTUNE,
167 | )
168 | .batch(batch_size)
169 | .prefetch(tf.data.AUTOTUNE)
170 | .as_numpy_iterator()
171 | )
172 |
173 | loss_accum = 0.0, 0
174 | tr = tqdm(range(0, 1 + num_steps, steps_per_update), desc="training")
175 | for step in tr:
176 | batch = next(tfdata)
177 | # (num_devices,) is for jax.pmap, (steps_per_update,) is for pax.scan
178 | batch = jnp.reshape(batch, (num_devices, steps_per_update, -1) + batch.shape[1:])
179 | net, optimizer, losses = update_fn(net, optimizer, batch)
180 | loss_accum = (loss_accum[0] + jnp.mean(losses), loss_accum[1] + 1)
181 | if step % 1000 == 0:
182 | loss = loss_accum[0] / loss_accum[1]
183 | loss_accum = 0.0, 0
184 | # eval on a single device
185 | eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval())
186 | out = eval_net.inference(
187 | prompt=tokenize(test_prompt),
188 | length=(100 if step < num_steps else 1000),
189 | )
190 | text = detokenize(out.tolist())
191 | tr.write(
192 | f"[step {step}] loss {loss:.3f}\n"
193 | f"Prompt: {test_prompt}\n"
194 | f"========\n"
195 | f"{text}\n"
196 | f"========"
197 | )
198 |
199 | del tfdata # needed to avoid exception
200 |
--------------------------------------------------------------------------------
/examples/dataclass_module.py:
--------------------------------------------------------------------------------
1 | """How to implement a PAX module using python dataclass."""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Callable, Optional
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import pax
9 |
10 |
11 | @dataclass
12 | class Linear(pax.Module):
13 | """A linear module"""
14 |
15 | in_dim: int
16 | out_dim: int
17 | with_bias: bool = True
18 | name: Optional[str] = None
19 | weight: jnp.ndarray = field(init=False, repr=False)
20 | bias: Optional[jnp.ndarray] = field(init=False, repr=False)
21 | counter: jnp.ndarray = field(init=False)
22 | w_init: Callable = field(default=jax.nn.initializers.normal(), repr=False)
23 | b_init: Callable = field(default=jax.nn.initializers.zeros, repr=False)
24 | parameters = pax.parameters_method("weight", "bias")
25 |
26 | def __post_init__(self):
27 | self.weight = self.w_init(pax.next_rng_key(), (self.in_dim, self.out_dim))
28 | self.bias = None
29 | if self.with_bias:
30 | self.bias = self.b_init(pax.next_rng_key(), (self.out_dim,))
31 | self.counter = jnp.array(0)
32 |
33 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
34 | self.counter += 1
35 | x = jnp.dot(x, self.weight)
36 | if self.with_bias:
37 | x = x + self.bias
38 | return x
39 |
40 |
41 | pax.seed_rng_key(42)
42 |
43 | fc = Linear(3, 4, name="fc1")
44 |
45 | print("Before:", fc)
46 | dummy_x = jnp.empty((32, 3))
47 | fc, y = pax.purecall(fc, dummy_x)
48 | assert y.shape == (32, 4)
49 | print("After :", fc)
50 |
--------------------------------------------------------------------------------
/examples/denoising_diffusion/README.md:
--------------------------------------------------------------------------------
1 | ## Denoising Diffusion Model
2 |
3 | We transcribe the PyTorch model at https://github.com/lucidrains/denoising-diffusion-pytorch.
4 |
5 | The implementation is almost identical to the PyTorch version.
6 | The main difference is in how PAX manages random keys: PAX's version uses an `RngSeq` submodule to generate new random keys when needed.
7 |
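For illustration, the pattern is roughly the following (a hedged sketch, not the actual model code; `NoisyBlock` is hypothetical, and the `next_rng_key` method name reflects our understanding of `pax.RngSeq`'s API):

```python
import jax
import pax

class NoisyBlock(pax.Module):
    def __init__(self):
        super().__init__()
        self.rng_seq = pax.RngSeq()  # stateful source of random keys

    def __call__(self, x):
        # Assumption: RngSeq.next_rng_key() returns a fresh key and advances the
        # internal state, so this must run inside pax.pure / pax.purecall.
        key = self.rng_seq.next_rng_key()
        return x + jax.random.normal(key, x.shape)
```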
8 | To train the model:
9 |
10 | ```sh
11 | pip install -r requirements.txt
12 | python3 train.py
13 | ```
--------------------------------------------------------------------------------
/examples/denoising_diffusion/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | import tensorflow_datasets as tfds
5 |
6 | ### load celeb_a dataset
7 |
8 | # This is a hack to use custom download links for the celeb_a dataset in tensorflow-datasets:
9 | # we replace the ``tfds.image.CelebA._split_generators`` method with the following method,
10 | # which uses our custom links.
11 |
12 | IMG_ALIGNED_DATA = (
13 | "https://drive.google.com/uc?export=download&"
14 | "id=1iQRFaGXRiPBd-flIm0u-u8Jy6CfJ_q6j"
15 | )
16 |
17 | EVAL_LIST = (
18 | "https://drive.google.com/uc?export=download&"
19 | "id=1ab9MDLOblszbKKXoDe8jumFsSkn6lIX1"
20 | )
21 | # Landmark coordinates: left_eye, right_eye etc.
22 | LANDMARKS_DATA = (
23 | "https://drive.google.com/uc?export=download&"
24 | "id=1y8qfK-jaq1QWl9v_n_mBNIMu5-h3UXK4"
25 | )
26 |
27 | # Attributes in the image (Eyeglasses, Mustache etc).
28 | ATTR_DATA = (
29 | "https://drive.google.com/uc?export=download&"
30 | "id=1BPfcVuIqrAsJAgG40-XGWU7g2wmmQU30"
31 | )
32 |
33 |
34 | def _split_generators(self, dl_manager):
35 | downloaded_dirs = dl_manager.download(
36 | {
37 | "img_align_celeba": IMG_ALIGNED_DATA,
38 | "list_eval_partition": EVAL_LIST,
39 | "list_attr_celeba": ATTR_DATA,
40 | "landmarks_celeba": LANDMARKS_DATA,
41 | }
42 | )
43 |
44 | # Load all images in memory (~1 GiB)
45 | # Use split to convert: `img_align_celeba/000005.jpg` -> `000005.jpg`
46 | all_images = {
47 | os.path.split(k)[-1]: img
48 | for k, img in dl_manager.iter_archive(downloaded_dirs["img_align_celeba"])
49 | }
50 |
51 | return [
52 | tfds.core.SplitGenerator(
53 | name=tfds.Split.TRAIN,
54 | gen_kwargs={
55 | "file_id": 0,
56 | "downloaded_dirs": downloaded_dirs,
57 | "downloaded_images": all_images,
58 | },
59 | ),
60 | tfds.core.SplitGenerator(
61 | name=tfds.Split.VALIDATION,
62 | gen_kwargs={
63 | "file_id": 1,
64 | "downloaded_dirs": downloaded_dirs,
65 | "downloaded_images": all_images,
66 | },
67 | ),
68 | tfds.core.SplitGenerator(
69 | name=tfds.Split.TEST,
70 | gen_kwargs={
71 | "file_id": 2,
72 | "downloaded_dirs": downloaded_dirs,
73 | "downloaded_images": all_images,
74 | },
75 | ),
76 | ]
77 |
78 |
79 | img_mean = 0.5
80 | img_scale = 0.5
81 | image_size = 64 # size of input image: 64x64
82 |
83 | tfds.image.CelebA._split_generators = _split_generators
84 |
85 |
86 | def load_celeb_a():
87 | ds = tfds.load("celeb_a")
88 |
89 | def img_ops(x):
90 | img = tf.cast(x["image"], tf.float32) / 255.0
91 | img = tf.image.resize(
92 | img, (image_size * 2, image_size), preserve_aspect_ratio=True
93 | )
94 | img = tf.image.crop_to_bounding_box(img, 7, 0, 64, 64)
95 | img = (img - img_mean) / img_scale
96 | return img
97 |
98 | dataset = (
99 | ds["train"].concatenate(ds["validation"]).concatenate(ds["test"]).map(img_ops)
100 | )
101 | return dataset
102 |
--------------------------------------------------------------------------------
/examples/denoising_diffusion/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | fire
3 | jax
4 | opax
5 | pax3
6 | pillow
7 | tensorflow
8 | tqdm
--------------------------------------------------------------------------------
/examples/denoising_diffusion/train.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import fire
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 | import opax
8 | import pax
9 | import tensorflow as tf
10 | from PIL import Image
11 | from tqdm.auto import tqdm
12 |
13 | from data_loader import load_celeb_a
14 | from model import GaussianDiffusion, UNet
15 |
16 |
17 | def make_image_grid(images, padding=2):
18 | """Place images in a square grid."""
19 | n = images.shape[0]
20 | size = int(math.sqrt(n))
21 | assert size * size == n, "expecting a square grid"
22 | img = images[0]
23 |
24 | H = img.shape[0] * size + padding * (size + 1)
25 | W = img.shape[1] * size + padding * (size + 1)
26 | out = np.zeros((H, W, img.shape[-1]), dtype=img.dtype)
27 | for i in range(n):
28 | x = i % size
29 | y = i // size
30 | xstart = x * (img.shape[0] + padding) + padding
31 | xend = xstart + img.shape[0]
32 | ystart = y * (img.shape[1] + padding) + padding
33 | yend = ystart + img.shape[1]
34 | out[xstart:xend, ystart:yend, :] = images[i]
35 | return out
36 |
37 |
38 | def train(
39 | batch_size: int = 32,
40 | learning_rate: float = 1e-4,
41 | num_training_steps: int = 10_000,
42 | log_freq: int = 1000,
43 | image_size: int = 64,
44 | random_seed: int = 42,
45 | ):
46 |
47 | pax.seed_rng_key(random_seed)
48 |
49 | model = UNet(dim=64, dim_mults=(1, 2, 4, 8))
50 |
51 | diffusion = GaussianDiffusion(
52 | model,
53 | image_size=image_size,
54 | timesteps=1000,
55 | loss_type="l1", # L1 or L2
56 | )
57 |
58 | dataset = load_celeb_a()
59 |
60 | dataloader = (
61 | dataset.repeat()
62 | .shuffle(batch_size * 100)
63 | .batch(batch_size)
64 | .take(num_training_steps)
65 | .prefetch(tf.data.AUTOTUNE)
66 | )
67 |
68 | def loss_fn(model, inputs):
69 | model, loss = pax.purecall(model, inputs)
70 | return loss, (loss, model)
71 |
72 | update_fn = pax.utils.build_update_fn(loss_fn)
73 | fast_update_fn = jax.jit(update_fn)
74 |
75 | optimizer = opax.adam(learning_rate)(diffusion.parameters())
76 |
77 | total_loss = 0.0
78 | tr = tqdm(dataloader)
79 | for step, batch in enumerate(tr, 1):
80 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
81 | diffusion, optimizer, loss = fast_update_fn(diffusion, optimizer, batch)
82 | total_loss = total_loss + loss
83 |
84 | if step % log_freq == 0:
85 | loss = total_loss / log_freq
86 | total_loss = 0.0
87 | tr.write(f"[step {step:05d}] train loss {loss:.3f}")
88 |
89 | imgs = jax.device_get(diffusion.eval().sample(16))
90 | imgs = ((imgs * 0.5 + 0.5) * 255).astype(jnp.uint8)
91 | imgs = make_image_grid(imgs)
92 | im = Image.fromarray(imgs)
93 | im.save(f"sample_{step:05d}.png")
94 |
95 |
96 | if __name__ == "__main__":
97 | fire.Fire(train)
98 |
--------------------------------------------------------------------------------
/examples/graph_module.py:
--------------------------------------------------------------------------------
1 | """A model as a directed graph."""
2 |
3 | import jax
4 | import pax
5 | import jax.numpy as jnp
6 | from pax.experimental.graph import Node, build_graph_module
7 |
8 | pax.seed_rng_key(42)
9 |
10 |
11 | def residual_net(x: Node):
12 | _, D = x.shape
13 | y = x >> pax.Linear(D, D) >> jax.nn.relu >> pax.Linear(D, D) >> pax.Dropout(0.2)
14 | z = (x | y) >> jax.lax.add
15 | return z
16 |
17 |
18 | inputs = jnp.ones((3, 8))
19 | net = build_graph_module(residual_net)(inputs)
20 | print(net.summary())
21 | net, _ = pax.purecall(net, inputs)
22 |
--------------------------------------------------------------------------------
/examples/lazy_module.py:
--------------------------------------------------------------------------------
1 | """A forward function that builds the model on the fly."""
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import opax
6 | import pax
7 |
8 |
9 | @pax.pure
10 | def forward(net: pax.experimental.LazyModule, x):
11 | fc1 = net.get_or_create("fc1", lambda: pax.Linear(1, 1))
12 | x = jax.nn.relu(fc1(x))
13 | fc2 = net.get_or_create("fc2", lambda: pax.Linear(1, 1))
14 | x = fc2(x)
15 | return net, x
16 |
17 |
18 | def loss_fn(model, x: jnp.ndarray, y: jnp.ndarray):
19 | model, y_hat = forward(model, x)
20 | loss = jnp.mean(jnp.square(y_hat - y))
21 | return loss, model
22 |
23 |
24 | @jax.jit
25 | def train_step(model, optimizer: opax.GradientTransformation, x, y):
26 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, x, y)
27 | model, optimizer = opax.apply_gradients(model, optimizer, grads)
28 | return model, optimizer, loss
29 |
30 |
31 | def train():
32 | "train a lazy model."
33 |
34 | pax.seed_rng_key(42)
35 |
36 | # data
37 | x = jax.random.normal(pax.next_rng_key(), (32, 1))
38 | y = jax.random.normal(pax.next_rng_key(), (32, 1))
39 |
40 | # model & optimizer
41 | net, _ = forward(pax.experimental.LazyModule(), x)
42 | print(net.summary())
43 | opt = opax.adam(1e-1)(net.parameters())
44 |
45 | # training loop
46 | for step in range(10):
47 | net, opt, loss = train_step(net, opt, x, y)
48 | print(f"step {step} loss {loss:.3f}")
49 |
50 | return net
51 |
52 |
53 | if __name__ == "__main__":
54 | train()
55 |
--------------------------------------------------------------------------------
/examples/mnist.py:
--------------------------------------------------------------------------------
1 | """train a handwritten digit classifier."""
2 |
3 | import pickle
4 | from pathlib import Path
5 | from typing import Mapping
6 |
7 | import fire
8 | import jax
9 | import jax.numpy as jnp
10 | import opax
11 | import pax
12 | import tensorflow_datasets as tfds
13 | from opax import GradientTransformation
14 | from tqdm.auto import tqdm
15 |
16 | Batch = Mapping[str, jnp.ndarray]
17 |
18 |
19 | class ConvNet(pax.Module):
20 | """ConvNet module."""
21 |
22 | layers: pax.Sequential
23 |
24 | def __init__(self):
25 | super().__init__()
26 | self.layers = pax.Sequential()
27 | for i in range(5):
28 | self.layers >>= pax.Conv2D((1 if i == 0 else 32), 32, 6, padding="VALID")
29 | self.layers >>= pax.BatchNorm2D(32, True, True, 0.9)
30 | self.layers >>= jax.nn.relu
31 | self.layers >>= pax.Conv2D(32, 10, 3, padding="VALID")
32 |
33 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
34 | x = self.layers(x)
35 | return jnp.squeeze(x, (1, 2))
36 |
37 |
38 | def loss_fn(model: ConvNet, batch: Batch):
39 | x = batch["image"].astype(jnp.float32) / 255
40 | target = batch["label"]
41 | model, logits = pax.purecall(model, x)
42 | log_pr = jax.nn.log_softmax(logits, axis=-1)
43 | log_pr = jnp.sum(jax.nn.one_hot(target, log_pr.shape[-1]) * log_pr, axis=-1)
44 | loss = -jnp.mean(log_pr)
45 | return loss, model
46 |
47 |
48 | @jax.jit
49 | def test_loss_fn(model: ConvNet, batch: Batch):
50 | model = model.eval()
51 | return loss_fn(model, batch)[0]
52 |
53 |
54 | @jax.jit
55 | def update_fn(model: ConvNet, optimizer: GradientTransformation, batch: Batch):
56 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, batch)
57 | params = model.parameters()
58 | optimizer, updates = pax.purecall(optimizer, grads, params)
59 | params = params.map(jax.lax.sub, updates)
60 | model = model.update_parameters(params)
61 | return model, optimizer, loss
62 |
63 |
64 | def load_dataset(split: str):
65 | """Loads the dataset as a tensorflow dataset."""
66 | ds = tfds.load("mnist:3.*.*", split=split)
67 | return ds
68 |
69 |
70 | def save_ckpt(epoch: int, model: ConvNet, path: Path):
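   |     """Save the training epoch and the model's state_dict to a pickle file."""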
71 | model = jax.device_get(model)
72 | with open(path, "wb") as f:
73 | pickle.dump({"epoch": epoch, "state_dict": model.state_dict()}, f)
74 |
75 |
76 | def load_ckpt(model: ConvNet, path: Path):
77 |     """Load the epoch and the model's state_dict from a checkpoint file."""
78 | with open(path, "rb") as f:
79 | dic = pickle.load(f)
80 | return dic["epoch"], model.load_state_dict(dic["state_dict"])
81 |
82 |
83 | def train(
84 | batch_size=32,
85 | num_epochs=10,
86 | learning_rate=1e-4,
87 | weight_decay=1e-4,
88 | ckpt_dir="/tmp",
89 | ):
90 | pax.seed_rng_key(42)
91 |
92 | # model
93 | net = ConvNet()
94 | print(net.summary())
95 |
96 | # optimizer
97 | optimizer = opax.chain(
98 | opax.clip_by_global_norm(1.0),
99 | opax.adamw(learning_rate=learning_rate, weight_decay=weight_decay),
100 | ).init(net.parameters())
101 |
102 | # data
103 | train_data = load_dataset("train").shuffle(10 * batch_size).batch(batch_size)
104 | test_data = load_dataset("test").shuffle(10 * batch_size).batch(batch_size)
105 |
106 | # resume from the latest checkpoint
107 | ckpts = sorted(Path(ckpt_dir).glob("pax_mnist_ckpt_*.pickle"))
108 | if len(ckpts) > 0:
109 | print("loading checkpoint at", ckpts[-1])
110 | last_epoch, net = load_ckpt(net, ckpts[-1])
111 | else:
112 | last_epoch = -1
113 |
114 | # training loop
115 | for epoch in range(last_epoch + 1, num_epochs):
116 | losses = 0.0
117 |
118 | # training
119 | for batch in tqdm(train_data, desc="train", leave=False):
120 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
121 | net, optimizer, loss = update_fn(net, optimizer, batch)
122 | losses = losses + loss
123 | loss = losses / len(train_data)
124 |
125 | # testing
126 | test_losses = 0.0
127 | for batch in tqdm(test_data, desc="test", leave=False):
128 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
129 | test_losses = test_losses + test_loss_fn(net, batch)
130 | test_loss = test_losses / len(test_data)
131 |
132 | save_ckpt(epoch, net, Path(ckpt_dir) / f"pax_mnist_ckpt_{epoch:02d}.pickle")
133 | # logging
134 | print(f"[Epoch {epoch}] train loss {loss:.3f} test loss {test_loss:.3f}")
135 |
136 |
137 | if __name__ == "__main__":
138 | fire.Fire(train)
139 |
--------------------------------------------------------------------------------
/examples/mnist_mixed_precision.py:
--------------------------------------------------------------------------------
1 | """train a handwritten digit classifier with mixed precision."""
2 |
3 | from typing import List, Mapping, Tuple
4 |
5 | import fire
6 | import jax
7 | import jax.numpy as jnp
8 | import jmp
9 | import opax
10 | import pax
11 | import tensorflow_datasets as tfds
12 | from opax.transform import GradientTransformation
13 | from tqdm.auto import tqdm
14 |
15 | Batch = Mapping[str, jnp.ndarray]
16 |
17 |
18 | class ConvNet(pax.Module):
19 | """ConvNet module."""
20 |
21 | layers: List[Tuple[pax.Conv2D, pax.BatchNorm2D]]
22 | output: pax.Conv2D
23 |
24 | def __init__(self):
25 | super().__init__()
26 | self.layers = []
27 | for i in range(5):
28 | conv_in = 1 if i == 0 else 32
29 | conv = pax.Conv2D(conv_in, 32, 6, padding="VALID")
30 | bn = pax.BatchNorm2D(32)
31 | self.layers.append((conv, bn))
32 |
33 | self.output = pax.Conv2D(32, 10, 3, padding="VALID")
34 |
35 | def __call__(self, x: jnp.ndarray):
36 | for conv, bn in self.layers:
37 | x = bn(conv(x))
38 | x = jax.nn.relu(x)
39 | x = self.output(x)
40 | return jnp.squeeze(x, (1, 2))
41 |
42 |
43 | def loss_fn(model: ConvNet, batch: Batch, loss_scale: jmp.LossScale):
44 | x = batch["image"].astype(jnp.float32) / 255
45 | target = batch["label"]
46 | model, logits = pax.purecall(model, x)
47 | log_pr = jax.nn.log_softmax(logits, axis=-1)
48 | log_pr = jnp.sum(jax.nn.one_hot(target, log_pr.shape[-1]) * log_pr, axis=-1)
49 | loss = -jnp.mean(log_pr)
50 | return loss_scale.scale(loss), (loss, model)
51 |
52 |
53 | @jax.jit
54 | def test_loss_fn(model: ConvNet, batch: Batch):
55 | model = model.eval()
56 | return loss_fn(model, batch, jmp.NoOpLossScale())[0]
57 |
58 |
59 | def apply_gradients_w_loss_scale(
60 | model: pax.Module,
61 | optimizer: opax.GradientTransformation,
62 | loss_scale: jmp.LossScale,
63 | grads: pax.Module,
64 | ):
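   |     """Unscale the gradients and apply them; with a dynamic loss scale, adjust
   |     the scale and skip the update when the gradients are not finite."""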
65 | grads = loss_scale.unscale(grads)
66 | skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)
67 | if skip_nonfinite_updates:
68 | grads_finite = jmp.all_finite(grads)
69 | loss_scale = loss_scale.adjust(grads_finite)
70 | model, optimizer = opax.apply_gradients(
71 | model, optimizer, grads=grads, all_finite=grads_finite
72 | )
73 | else:
74 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
75 | return model, optimizer, loss_scale
76 |
77 |
78 | @jax.jit
79 | def update_fn(
80 | model: ConvNet,
81 | optimizer: GradientTransformation,
82 | loss_scale: jmp.LossScale,
83 | batch: Batch,
84 | ):
85 | grad_fn = pax.grad(loss_fn, has_aux=True)
86 | grads, (loss, model) = grad_fn(model, batch, loss_scale=loss_scale)
87 | return apply_gradients_w_loss_scale(model, optimizer, loss_scale, grads) + (loss,)
88 |
89 |
90 | def load_dataset(split: str):
91 | """Loads the dataset as a tensorflow dataset."""
92 | ds = tfds.load("mnist:3.*.*", split=split)
93 | return ds
94 |
95 |
96 | def mp_policy_fn(mod):
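   |     """Run Conv2D modules in half precision; keep BatchNorm2D in full precision."""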
97 | half = jmp.half_dtype()
98 | full = jnp.float32
99 | linear_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=full)
100 | bn_policy = jmp.Policy(compute_dtype=full, param_dtype=full, output_dtype=full)
101 |
102 | if isinstance(mod, pax.Conv2D):
103 | return pax.apply_mp_policy(mod, mp_policy=linear_policy)
104 | elif isinstance(mod, pax.BatchNorm2D):
105 | return pax.apply_mp_policy(mod, mp_policy=bn_policy)
106 | else:
107 | return mod # unchanged
108 |
109 |
110 | def train(batch_size=32, num_epochs=5, learning_rate=1e-4, weight_decay=1e-4):
111 | pax.seed_rng_key(42)
112 |
113 | net = ConvNet()
114 | net = net.apply(mp_policy_fn)
115 | print(net.summary())
116 | optimizer = opax.chain(
117 | opax.clip_by_global_norm(1.0),
118 | opax.adamw(learning_rate=learning_rate, weight_decay=weight_decay),
119 | ).init(net.parameters())
120 |
121 | loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15), period=2000)
122 |
123 | train_data = (
124 | load_dataset("train")
125 | .shuffle(10 * batch_size)
126 | .batch(batch_size, drop_remainder=True)
127 | )
128 | test_data = load_dataset("test").batch(batch_size, drop_remainder=True)
129 |
130 | for epoch in range(0, num_epochs):
131 | losses = 0.0
132 | for batch in tqdm(train_data, desc="train", leave=False):
133 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
134 | net, optimizer, loss_scale, loss = update_fn(
135 | net, optimizer, loss_scale, batch
136 | )
137 | losses = losses + loss
138 | loss = losses / len(train_data)
139 |
140 | test_losses = 0.0
141 | for batch in tqdm(test_data, desc="eval", leave=False):
142 | batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
143 | test_losses = test_losses + test_loss_fn(net, batch)
144 | test_loss = test_losses / len(test_data)
145 |
146 | print(
147 | f"[Epoch {epoch}] train loss {loss:.3f} test loss"
148 | f" {test_loss:.3f} loss scale {loss_scale.loss_scale}"
149 | )
150 |
151 |
152 | if __name__ == "__main__":
153 | fire.Fire(train)
154 |
--------------------------------------------------------------------------------
/examples/notebooks/pretrained_resnet18.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pax
3 | import torchvision
4 |
5 | IMAGENET_MEAN = np.array((0.485, 0.456, 0.406))
6 | IMAGENET_STD = np.array((0.229, 0.224, 0.225))
7 |
8 |
9 | def convert_conv(conv, name=None):
10 | """Return a pax.Conv2D module with weights from pretrained ``conv``."""
11 | weight = conv.weight.data.contiguous().permute(2, 3, 1, 0).contiguous().numpy()[:]
12 |
13 | pax_conv = pax.Conv2D(
14 | in_features=conv.in_channels,
15 | out_features=conv.out_channels,
16 | kernel_shape=conv.kernel_size,
17 | stride=conv.stride,
18 | with_bias=False,
19 | padding=[(conv.padding[0],) * 2, (conv.padding[1],) * 2],
20 | data_format="NCHW",
21 | name=name,
22 | )
23 | assert pax_conv.weight.shape == weight.shape
24 | return pax_conv.replace(weight=weight)
25 |
26 |
27 | def convert_bn(bn, name=None):
28 | """Return a pax.BatchNorm2D module from pretrained ``bn``."""
29 | weight = bn.weight.data.numpy()[None, :, None, None]
30 | bias = bn.bias.data.numpy()[None, :, None, None]
31 | running_mean = bn.running_mean.data.numpy()[None, :, None, None]
32 | running_var = bn.running_var.data.numpy()[None, :, None, None]
33 |
34 | pax_bn = pax.BatchNorm2D(
35 | num_channels=bias.shape[1],
36 | create_offset=True,
37 | create_scale=True,
38 | decay_rate=0.9,
39 | eps=1e-5,
40 | data_format="NCHW",
41 | name=name,
42 | )
43 | assert pax_bn.scale.shape == weight.shape
44 | assert pax_bn.offset.shape == bias.shape
45 | assert pax_bn.ema_mean.averages.shape == running_mean.shape
46 | assert pax_bn.ema_var.averages.shape == running_var.shape
47 |
48 | pax_bn = pax_bn.replace(scale=weight, offset=bias)
49 | pax_bn = pax_bn.replace_node(pax_bn.ema_mean.averages, running_mean)
50 | pax_bn = pax_bn.replace_node(pax_bn.ema_var.averages, running_var)
51 | return pax_bn
52 |
53 |
54 | def convert_basic_block(block):
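   |     """Convert a torchvision BasicBlock into PAX (conv, bn) pairs, plus an
   |     optional projection (conv, bn) pair when the block downsamples."""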
55 | conv1 = convert_conv(block.conv1, name="conv1")
56 | bn1 = convert_bn(block.bn1, name="bn1")
57 | conv2 = convert_conv(block.conv2, name="conv2")
58 | bn2 = convert_bn(block.bn2, name="bn2")
59 |
60 | if block.downsample is not None:
61 | conv0 = convert_conv(block.downsample[0], name="proj_conv")
62 | bn0 = convert_bn(block.downsample[1], name="proj_bn")
63 | return ((conv1, bn1), (conv2, bn2)), (conv0, bn0)
64 | else:
65 | return (((conv1, bn1), (conv2, bn2)),)
66 |
67 |
68 | def convert_block_group(group):
69 | out = []
70 | for i in range(len(group)):
71 | out.append(convert_basic_block(group[i]))
72 | return out
73 |
74 |
75 | def convert_linear(linear):
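   |     """Return a pax.Linear module with weights from pretrained ``linear``."""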
76 | weight = linear.weight.data.numpy()
77 | bias = linear.bias.data.numpy()
78 | pax_linear = pax.Linear(
79 | in_dim=weight.shape[1], out_dim=weight.shape[0], with_bias=True
80 | )
81 | weight = np.transpose(weight)
82 | assert pax_linear.bias.shape == bias.shape
83 | assert pax_linear.weight.shape == weight.shape
84 |
85 | return pax_linear.replace(weight=weight, bias=bias)
86 |
87 |
88 | def load_pretrained_resnet18():
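   |     """Build a pax.nets.ResNet18 with torchvision's pretrained weights."""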
89 | resnet18 = pax.nets.ResNet18(3, 1000)
90 | resnet18_pt = torchvision.models.resnet18(pretrained=True).eval()
91 | pax_resnet = [
92 | convert_conv(resnet18_pt.conv1),
93 | convert_bn(resnet18_pt.bn1),
94 | convert_block_group(resnet18_pt.layer1),
95 | convert_block_group(resnet18_pt.layer2),
96 | convert_block_group(resnet18_pt.layer3),
97 | convert_block_group(resnet18_pt.layer4),
98 | convert_linear(resnet18_pt.fc),
99 | ]
100 |
101 | def replace_parts(resnet18):
102 | # replace resnet18 part by part
103 | resnet18.initial_conv = pax_resnet[0]
104 | resnet18.initial_batchnorm = pax_resnet[1]
105 | for i in range(len(resnet18.block_groups)):
106 | bg = resnet18.block_groups[i]
107 | for j in range(len(bg.blocks)):
108 | b = bg.blocks[j]
109 | mods = pax_resnet[2 + i][j]
110 | b.layers = mods[0]
111 | if b.use_projection:
112 | b.proj_conv = mods[1][0]
113 | b.proj_batchnorm = mods[1][1]
114 |
115 | resnet18.logits = pax_resnet[-1]
116 | # make sure we are in `eval` mode when doing evaluation.
117 | return resnet18.eval()
118 |
119 | return pax.pure(replace_parts)(resnet18)
120 |
--------------------------------------------------------------------------------
/examples/transformer/data.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 |
6 | def tokenize(text):
7 | t = [0] + [ord(c) for c in text] # ASCII, 0 is the [START] token
8 | return t
9 |
10 |
11 | def detokenize(tokens):
12 | text = [chr(t) if t != 0 else "[START]" for t in tokens]
13 | return "".join(text)
14 |
15 |
16 | def _device_put_sharded(sharded_tree, devices):
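   |     """Split the leading axis of every leaf in ``sharded_tree`` and place
   |     one shard on each device."""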
17 | leaves, treedef = jax.tree_util.tree_flatten(sharded_tree)
18 | n = leaves[0].shape[0]
19 | return jax.device_put_sharded(
20 | [
21 | jax.tree_util.tree_unflatten(treedef, [l[i] for l in leaves])
22 | for i in range(n)
23 | ],
24 | devices,
25 | )
26 |
27 |
28 | # Source: https://github.com/deepmind/dm-haiku/blob/8fad8c7503c5f56fa9ea9b53f71b7082704e3a3e/examples/imagenet/dataset.py#L163
29 | def double_buffer(ds, num_devices, steps_per_update):
30 |     """Keeps at least two batches on the accelerator.
   |
31 |     The current GPU allocator design reuses previous allocations. For a training
32 |     loop this means batches will (typically) occupy the same region of memory as
33 |     the previous batch. An issue with this is that it means we cannot overlap a
34 |     host->device copy for the next batch until the previous step has finished and
35 |     the previous batch has been freed.
   |
36 |     By double buffering we ensure that there are always two batches on the device.
37 |     This means that a given batch waits on the N-2'th step to finish and free,
38 |     meaning that it can allocate and copy the next batch to the accelerator in
39 |     parallel with the N-1'th step being executed.
   |
40 |     Args:
41 |       ds: Iterable of batches of numpy arrays.
   |       num_devices: Number of devices to shard each batch across.
   |       steps_per_update: Number of per-device steps packed into each batch.
   |
42 |     Yields:
43 |       Batches of sharded device arrays.
44 |     """
45 | batch = None
46 | devices = jax.devices()
47 | for next_batch in ds:
48 | assert next_batch is not None
49 | next_batch = np.reshape(
50 | next_batch, (num_devices, steps_per_update, -1) + next_batch.shape[1:]
51 | )
52 | next_batch = _device_put_sharded(next_batch, devices)
53 | if batch is not None:
54 | yield batch
55 | batch = next_batch
56 | if batch is not None:
57 | yield batch
58 |
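   | # A minimal usage sketch (assuming `ds` yields numpy batches, e.g. from
   | # `tf.data`'s `as_numpy_iterator()` as in `make_data_loader` below):
   | #
   | #     for sharded in double_buffer(ds, jax.device_count(), steps_per_update=1):
   | #         ...  # the next host->device copy overlaps with this step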
59 |
60 | def make_data_loader(data, seq_len, batch_size, num_devices, steps_per_update):
61 | data_token = tokenize(data)
62 | data_token = [0] * seq_len + data_token
63 |
64 | tfdata = (
65 | tf.data.Dataset.from_tensors(data_token)
66 | .repeat()
67 | .map(
68 | lambda x: tf.image.random_crop(x, [seq_len + 1]),
69 | num_parallel_calls=tf.data.AUTOTUNE,
70 | )
71 | .batch(batch_size)
72 | .prefetch(tf.data.AUTOTUNE)
73 | .as_numpy_iterator()
74 | )
75 |
76 | return double_buffer(tfdata, num_devices, steps_per_update)
77 |
--------------------------------------------------------------------------------
/examples/transformer/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Sequence
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import pax
7 | from pax.nets import Transformer
8 |
9 |
10 | def positional_encoding(x):
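   |     """Add sinusoidal positional encodings to ``x`` of shape ``[N, L, D]``.
   |
   |     The first D/2 features hold sin(pos * 10000^(-2i/D)) terms and the last
   |     D/2 features hold the matching cos terms.
   |     """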
11 | _, L, D = x.shape
12 | position = jnp.arange(0, L, dtype=x.dtype)[:, None]
13 | div_term = jnp.exp(jnp.arange(0, D, 2, dtype=x.dtype) * (-math.log(10_000.0) / D))
14 | x1 = jnp.sin(position * div_term[None, :])
15 | x2 = jnp.cos(position * div_term[None, :])
16 | x_pos = jnp.concatenate((x1, x2), axis=-1)
17 | return x + x_pos[None, :, :]
18 |
19 |
20 | class LM(pax.Module):
21 | """A Transformer language model."""
22 |
23 | transformer: Transformer
24 | embed: pax.Module
25 | output: pax.Module
26 |
27 | vocab_size: int
28 | hidden_dim: int
29 |
30 | def __init__(
31 | self, vocab_size: int, hidden_dim: int, num_layers: int, dropout: float = 0.1
32 | ):
33 | """
34 | Arguments:
35 | vocab_size: int, size of the alphabet.
36 | hidden_dim: int, hidden dim.
37 |             num_layers: int, number of transformer blocks.
   |             dropout: float, dropout rate.
38 | """
39 | super().__init__()
40 | self.vocab_size = vocab_size
41 | self.hidden_dim = hidden_dim
42 | self.embed = pax.Embed(
43 | vocab_size,
44 | hidden_dim,
45 | w_init=jax.nn.initializers.variance_scaling(
46 | 1.0, mode="fan_out", distribution="normal"
47 | ),
48 | )
49 | self.transformer = Transformer(
50 | hidden_dim, hidden_dim // 64, num_layers, dropout_rate=dropout
51 | )
52 | self.output = pax.Linear(hidden_dim, vocab_size)
53 |
54 | def __call__(self, x):
55 | x = self.embed(x)
56 | x = positional_encoding(x)
57 | x = self.transformer(x)
58 | logits = self.output(x)
59 | return logits
60 |
61 | @pax.pure
62 | def inference(self, prompt: Sequence[int] = (), length=1024, train_seq_len=256):
63 | def step(inputs, _):
64 | logits = self(inputs)
65 | x = jnp.argmax(logits[:, -1], axis=-1)
66 | next_inputs = jnp.concatenate((inputs[:, 1:], x[:, None]), axis=-1)
67 | return next_inputs, x
68 |
69 | if len(prompt) > train_seq_len:
70 | inputs = prompt[-train_seq_len:]
71 | else:
72 | inputs = prompt
73 | pad_len = train_seq_len - len(inputs)
74 | padded_inputs = [0] * pad_len + inputs
75 | x = jnp.array([padded_inputs], dtype=jnp.int32)
76 | L = length - len(prompt)
77 | _, out = pax.scan(step, x, None, length=L, time_major=False)
78 | return prompt + out[0].tolist()
79 |
--------------------------------------------------------------------------------
/examples/transformer/train.py:
--------------------------------------------------------------------------------
1 | """Train a transformer language model on TPU (if available)."""
2 |
3 | import inspect
4 | import os
5 | from functools import partial
6 | from typing import Tuple
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | import jax.tools.colab_tpu
11 | import opax
12 | import pax
13 | from opax import GradientTransformation
14 | from tqdm.auto import tqdm
15 |
16 | from data import detokenize, make_data_loader, tokenize
17 | from model import LM
18 |
19 |
20 | def setup_tpu_device():
21 | print("Setting up TPU cores")
22 | jax.tools.colab_tpu.setup_tpu()
23 | print(jax.devices())
24 |
25 |
26 | # shared config
27 | dropout = 0.1
28 | learning_rate = 1e-4
29 | vocab_size = 256
30 | pax.seed_rng_key(42)
31 |
32 | if "COLAB_TPU_ADDR" in os.environ:
33 | # TPU config
34 | # need to config TPU cores _before_ calling `jax.device_count`.
35 | setup_tpu_device()
36 | steps_per_update = 50
37 | num_devices = jax.device_count()
38 | batch_size = 32 * num_devices * steps_per_update
39 | seq_len = 256
40 | hidden_dim = 512
41 | num_steps = 1_000
42 | num_layers = 6
43 | else:
44 | # CPU/GPU config
45 | steps_per_update = 1
46 | num_devices = jax.device_count()
47 | batch_size = 8 * num_devices * steps_per_update
48 | seq_len = 64
49 | hidden_dim = 256
50 | num_steps = 20_000
51 | num_layers = 2
52 |
53 |
54 | def loss_fn(model: LM, batch: jnp.ndarray):
55 | inputs = batch[:, :-1]
56 | targets = batch[:, 1:]
57 |
58 | model, logits = pax.purecall(model, inputs)
59 | log_pr = jax.nn.log_softmax(logits, axis=-1)
60 | targets = jax.nn.one_hot(targets, num_classes=model.vocab_size)
61 | loss = -jnp.mean(jnp.sum(targets * log_pr, axis=-1))
62 | return loss, model
63 |
64 |
65 | def update_step(model_and_optim: Tuple[LM, GradientTransformation], batch: jnp.ndarray):
66 | model, optimizer = model_and_optim
67 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, batch)
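   |     # average gradients across devices (axis "i" is bound by jax.pmap in update_fn)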
68 | grads = jax.lax.pmean(grads, axis_name="i")
69 | params = model.parameters()
70 | optimizer, updates = pax.purecall(optimizer, grads, params)
71 | params = params.map(jax.lax.sub, updates)
72 | model = model.update_parameters(params)
73 | return (model, optimizer), loss
74 |
75 |
76 | @partial(jax.pmap, axis_name="i")
77 | def update_fn(model: LM, optimizer: GradientTransformation, multi_batch: jnp.ndarray):
78 | (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch)
79 | return model, optimizer, jnp.sum(losses)
80 |
81 |
82 | def train():
83 | net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers)
84 | print(net.summary())
85 | optimizer = opax.chain(
86 | opax.clip_by_global_norm(1.0),
87 | opax.adam(learning_rate),
88 | ).init(net.parameters())
89 |
90 | data = inspect.getsource(LM) # a _true_ AGI learns about itself.
91 | test_prompt = data[:20]
92 | data_iter = make_data_loader(
93 | data,
94 | seq_len=seq_len,
95 | batch_size=batch_size,
96 | num_devices=num_devices,
97 | steps_per_update=steps_per_update,
98 | )
99 |
100 | # replicate on multiple devices
101 | net = jax.device_put_replicated(net, jax.devices())
102 | optimizer = jax.device_put_replicated(optimizer, jax.devices())
103 |
104 | total_losses = 0.0
105 | tr = tqdm(range(0, 1 + num_steps, steps_per_update), desc="training")
106 | for step in tr:
107 | batch = next(data_iter)
108 | # (num_devices,) is for jax.pmap, (steps_per_update,) is for pax.scan
109 | net, optimizer, loss = update_fn(net, optimizer, batch)
110 | total_losses = total_losses + loss
111 | if step % 1000 == 0:
112 | loss = jnp.mean(total_losses) / (1000 if step > 0 else steps_per_update)
113 | total_losses = jnp.zeros_like(total_losses)
114 | # eval on a single device
115 | eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval())
116 | out = eval_net.inference(
117 | prompt=tokenize(test_prompt),
118 | length=(128 if step < num_steps else 1024),
119 | train_seq_len=seq_len,
120 | )
121 | text = detokenize(out)
122 | tr.write(
123 | f"[step {step}] loss {loss:.3f}\n"
124 | f"Prompt: {test_prompt}\n"
125 | f"========\n"
126 | f"{text}\n"
127 | f"========"
128 | )
129 |
130 |
131 | if __name__ == "__main__":
132 | train()
133 |
--------------------------------------------------------------------------------
/examples/wave_gru/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | This example is an implementation of the [Lyra](https://github.com/google/lyra) WaveGRU network.
4 | However, we predict the 8-bit mu-law-compressed waveform instead of the raw 16-bit waveform.
5 |
6 |
7 | ## Data preparation
8 |
9 | We use `ffmpeg` and `sox` for audio conversion and silence trimming.
10 |
11 |
12 | To prepare the audio clip:
13 |
14 | pip install -r requirements.txt
15 | bash prepare_data.sh
16 |
17 | ## Train WaveGRU
18 |
19 | python3 train.py # 1 hour on a Tesla T4
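   |
   | The training script exposes its arguments through `fire`, so hyperparameters can
   | be overridden on the command line, for example:
   |
   |     python3 train.py --batch_size=64 --learning_rate=1e-3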
20 |
--------------------------------------------------------------------------------
/examples/wave_gru/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import librosa
5 | import numpy as np
6 |
7 |
8 | def data_loader(
9 | batch_size: int,
10 | n_mels: int,
11 | n_fft: int,
12 | hop_length: int,
13 | win_length: int,
14 | sample_rate: int,
15 | fmin: int,
16 | fmax: int,
17 | mu: int,
18 | n_frames: int,
19 | split="train",
20 | pad: int = 31,
21 | ):
22 | if not os.path.exists("/tmp/wave_gru_clip.wav"):
23 | os.system("bash /tmp/prepare_clip.sh")
24 |
25 | wav, _ = librosa.load("/tmp/wave_gru_clip.wav", sr=sample_rate)
26 |
27 | L = len(wav) * 9 // 10
28 | if split == "train":
29 | wav = wav[:L]
30 | else:
31 | wav = wav[L:]
32 |
33 | mel = librosa.feature.melspectrogram(
34 | n_mels=n_mels,
35 | y=wav,
36 | sr=sample_rate,
37 | n_fft=n_fft,
38 | hop_length=hop_length,
39 | win_length=win_length,
40 | fmin=fmin,
41 | fmax=fmax,
42 | center=False,
43 | )
44 |
45 | mel = mel.T
46 |
47 | logmel = np.log(1e-3 + mel)
48 | mu_wav = librosa.mu_compress(wav, mu=mu, quantize=True) + mu // 2
49 |
50 | if split == "test":
51 | yield (logmel, mu_wav)
52 | return
53 |
54 | batch = []
55 | while True:
56 | left = random.randint(0, logmel.shape[0] - n_frames - pad * 2)
57 | right = left + pad + n_frames + pad
58 |         cond = logmel[left:right]  # includes `pad` frames of context on each side
59 | x = mu_wav[(left + pad) * hop_length : (right - pad) * hop_length + 1]
60 | batch.append((cond, x))
61 | if len(batch) == batch_size:
62 | conds, xs = zip(*batch)
63 | conds = np.array(conds)
64 | xs = np.array(xs, dtype=np.int16)
65 | yield (conds, xs)
66 | batch = []
67 |
--------------------------------------------------------------------------------
/examples/wave_gru/model.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import pax
6 |
7 |
8 | class UpsampleNet(pax.Module):
9 | """Upsampling melspectrogram."""
10 |
11 | def __init__(self, n_mels, num_output_channels):
12 | super().__init__()
13 | self.input_conv = pax.Conv1D(n_mels, 512, 1, padding="VALID")
14 | self.dilated_convs = []
15 | self.bns = []
16 | for i in range(5):
17 | conv = pax.Conv1D(512, 512, 3, rate=2 ** i, padding="VALID")
18 | self.dilated_convs.append(conv)
19 | self.bns.append(pax.BatchNorm1D(512, True, True, 0.99))
20 | self.upsample_conv_1 = pax.Conv1DTranspose(512, 512, 4, stride=4)
21 | self.upsample_bn1 = pax.BatchNorm1D(512, True, True, 0.99)
22 | self.upsample_conv_2 = pax.Conv1DTranspose(512, 512, 4, stride=4)
23 | self.upsample_bn2 = pax.BatchNorm1D(512, True, True, 0.99)
24 | self.output_conv = pax.Conv1D(512, num_output_channels, 1, padding="VALID")
25 |
26 | def __call__(self, mel):
27 | x = self.input_conv(mel)
28 |
29 | # Large receptive fields
30 | for conv, batch_norm in zip(self.dilated_convs, self.bns):
31 | residual = jax.nn.relu(batch_norm(conv(x)))
32 | pad = (x.shape[1] - residual.shape[1]) // 2
33 | x = x[:, pad:-pad] + residual
34 |
35 | # upsample
36 | x = jax.nn.relu(self.upsample_bn1(self.upsample_conv_1(x)))
37 | x = jax.nn.relu(self.upsample_bn2(self.upsample_conv_2(x)))
38 |
39 | x = self.output_conv(x)
40 |
41 | # tile x16
42 | N, L, D = x.shape
43 | x = jnp.tile(x[:, :, None, :], (1, 1, 16, 1))
44 | x = jnp.reshape(x, (N, -1, D))
45 |
46 | return x
47 |
48 |
49 | class WaveGRU(pax.Module):
50 | def __init__(self, n_mels, hidden_dim, n_mu_bits=8):
51 | super().__init__()
52 | self.n_mu_bits = n_mu_bits
53 | self.hidden_dim = hidden_dim
54 |
55 | self.upsampling = UpsampleNet(n_mels, hidden_dim)
56 | self.gru = pax.GRU(hidden_dim, hidden_dim)
57 | self.logits = pax.Linear(hidden_dim, 2 ** n_mu_bits)
58 | self.embed = pax.Embed(2 ** n_mu_bits, hidden_dim)
59 |
60 | def __call__(self, inputs):
61 | logmel, wav = inputs
62 | x = self.upsampling(logmel)
63 | hx = self.gru.initial_state(x.shape[0])
64 | wav = self.embed(wav)
65 | assert x.shape == wav.shape
66 | x = x + wav
67 | _, x = pax.scan(self.gru, hx, x, time_major=False)
68 | x = self.logits(x)
69 | return x
70 |
71 | def inference(self, logmel, rng_key=None):
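   |         """Autoregressive sampling: at each upsampled frame, embed the previous
   |         sample, update the GRU state, and draw the next sample from the
   |         predicted categorical distribution."""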
72 | if rng_key is None:
73 | rng_key = pax.next_rng_key()
74 |
75 | x = jnp.array([2 ** (self.n_mu_bits - 1)], dtype=jnp.int32)
76 | hx = self.gru.initial_state(1)
77 |
78 | conds = self.upsampling(logmel)
79 |
80 | def loop(prev_state, inputs):
81 | x, hx, rng_key = prev_state
82 | rng_key, next_rng_key = jax.random.split(rng_key)
83 |
84 | x = self.embed(x) + inputs
85 | hx, x = self.gru(hx, x)
86 | x = self.logits(x)
87 | x = jax.random.categorical(rng_key, x)
88 | return (x, hx, next_rng_key), x
89 |
90 | _, x = pax.scan(loop, (x, hx, rng_key), conds, time_major=False)
91 | return x
92 |
--------------------------------------------------------------------------------
/examples/wave_gru/prepare_data.sh:
--------------------------------------------------------------------------------
1 | # "Yoshua Bengio: Deep Learning Cognition | Full Keynote - AI in 2020 & Beyond"
2 | youtube-dl -f 139 https://www.youtube.com/watch?v=GibjI5FoZsE --output /tmp/wave_gru_clip.m4a
3 | # convert m4a to wav
4 | ffmpeg -i /tmp/wave_gru_clip.m4a -ac 1 -ar 16000 -acodec pcm_s16le /tmp/wave_gru_clip_.wav
5 | # trim silences
6 | sox /tmp/wave_gru_clip_.wav /tmp/wave_gru_clip.wav silence -l 1 0.1 1% -1 1.0 1%
7 |
--------------------------------------------------------------------------------
/examples/wave_gru/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | librosa
3 | opax
4 | soundfile
5 | tqdm
6 | youtube-dl
--------------------------------------------------------------------------------
/examples/wave_gru/train.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import librosa
6 | import opax
7 | import pax
8 | import soundfile
9 | from tqdm.auto import tqdm
10 |
11 | from data_loader import data_loader
12 | from model import WaveGRU
13 |
14 |
15 | def loss_fn(model: WaveGRU, inputs):
16 | logmel, wav = inputs
17 | input_wav = wav[:, :-1]
18 | target_wav = wav[:, 1:]
19 | model, logits = pax.purecall(model, (logmel, input_wav))
20 | log_pr = jax.nn.log_softmax(logits, axis=-1)
21 | target_wave = jax.nn.one_hot(target_wav, num_classes=logits.shape[-1])
22 | log_pr = jnp.sum(log_pr * target_wave, axis=-1)
23 | loss = -jnp.mean(log_pr)
24 | return loss, (loss, model)
25 |
26 |
27 | def generate_test_sample(step, test_logmel, wave_gru, length, sample_rate, mu):
28 | generated_mu = wave_gru.eval().inference(test_logmel[None, :length, :])
29 | generated_mu = jax.device_get(generated_mu)
30 | synthesized_clip = librosa.mu_expand(
31 | generated_mu[0] - mu // 2, mu=mu, quantize=True
32 | )
33 | file_name = f"/tmp/wave_gru_sample_{step:05d}.wav"
34 | soundfile.write(
35 | file_name,
36 | synthesized_clip,
37 | samplerate=sample_rate,
38 | )
39 | return file_name
40 |
41 |
42 | def train(
43 | hidden_dim: int = 512,
44 | num_training_steps: int = 5_000,
45 | batch_size: int = 128,
46 | learning_rate: float = 5e-4,
47 | sample_rate: int = 16_000,
48 | max_global_norm: float = 1.0,
49 | n_fft=1024,
50 | hop_length=256,
51 | win_length=1024,
52 | n_mels=80,
53 | fmin=0,
54 | fmax=8000,
55 | seq_len=2 ** 10,
56 | n_mu_bits=8,
57 | log_freq: int = 1000,
58 | random_seed=42,
59 | ):
60 | pax.seed_rng_key(random_seed)
61 | mu = 2 ** n_mu_bits - 1
62 | n_frames = seq_len // hop_length
63 | wave_gru = WaveGRU(n_mels, hidden_dim)
64 | print(wave_gru.summary())
65 |
66 | optimizer = opax.chain(
67 | opax.clip_by_global_norm(max_global_norm),
68 | opax.adam(learning_rate),
69 | ).init(wave_gru.parameters())
70 |
71 | split_loader = partial(
72 | data_loader,
73 | batch_size=batch_size,
74 | n_mels=n_mels,
75 | n_fft=n_fft,
76 | hop_length=hop_length,
77 | win_length=win_length,
78 | sample_rate=sample_rate,
79 | mu=mu,
80 | n_frames=n_frames,
81 | fmin=fmin,
82 | fmax=fmax,
83 | )
84 | data_iter = split_loader(split="train")
85 | test_iter = split_loader(split="test")
86 | test_logmel, _ = next(test_iter)
87 |
88 | update_fn = jax.jit(pax.utils.build_update_fn(loss_fn))
89 | total_loss = 0.0
90 | tr = tqdm(range(1, 1 + num_training_steps))
91 | for step in tr:
92 | batch = next(data_iter)
93 | wave_gru, optimizer, loss = update_fn(wave_gru, optimizer, batch)
94 | total_loss = total_loss + loss
95 |
96 | if step % log_freq == 0:
97 | loss = total_loss / log_freq
98 | total_loss = 0.0
99 | file_name = generate_test_sample(
100 | step, test_logmel, wave_gru, 1000, sample_rate, mu
101 | )
102 | tr.write(
103 | f"[step {step}] train loss {loss:.3f} synthesized clip {file_name}"
104 | )
105 |
106 |
107 | if __name__ == "__main__":
108 | import fire
109 |
110 | fire.Fire(train)
111 |
--------------------------------------------------------------------------------
/images/pax_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTT123/pax/13916cb86ede38c56750cf1bde3ac37c63674014/images/pax_logo.png
--------------------------------------------------------------------------------
/pax/__init__.py:
--------------------------------------------------------------------------------
1 | """PAX package."""
2 |
3 | from pax import experimental, nets, utils
4 | from pax._src.core import (
5 | EmptyNode,
6 | Module,
7 | ParameterModule,
8 | StateModule,
9 | apply_mp_policy,
10 | assert_structure_equal,
11 | enable_eval_mode,
12 | enable_train_mode,
13 | freeze_parameters,
14 | module_and_value,
15 | parameters_method,
16 | pure,
17 | purecall,
18 | select_parameters,
19 | unfreeze_parameters,
20 | unwrap_mp_policy,
21 | update_parameters,
22 | )
23 | from pax._src.core.rng import next_rng_key, seed_rng_key
24 | from pax._src.nn import (
25 | EMA,
26 | GRU,
27 | LSTM,
28 | BatchNorm1D,
29 | BatchNorm2D,
30 | Conv1D,
31 | Conv1DTranspose,
32 | Conv2D,
33 | Conv2DTranspose,
34 | Dropout,
35 | Embed,
36 | GroupNorm,
37 | GRUState,
38 | Identity,
39 | Lambda,
40 | LayerNorm,
41 | Linear,
42 | LSTMState,
43 | MultiHeadAttention,
44 | RngSeq,
45 | Sequential,
46 | VanillaRNN,
47 | VanillaRNNState,
48 | avg_pool,
49 | max_pool,
50 | )
51 | from pax._src.nn.dropout import dropout
52 | from pax._src.utils import build_update_fn, grad, scan, value_and_grad
53 |
54 | __version__ = "0.5.9"
55 |
56 | __all__ = (
57 | "apply_mp_policy",
58 | "assert_structure_equal",
59 | "avg_pool",
60 | "BatchNorm1D",
61 | "BatchNorm2D",
62 | "build_update_fn",
63 | "Conv1D",
64 | "Conv1DTranspose",
65 | "Conv2D",
66 | "Conv2DTranspose",
67 | "dropout",
68 | "Dropout",
69 | "EMA",
70 | "Embed",
71 | "EmptyNode",
72 | "enable_eval_mode",
73 | "enable_train_mode",
74 | "experimental",
75 | "freeze_parameters",
76 | "grad",
77 | "GroupNorm",
78 | "GRU",
79 | "GRUState",
80 | "Identity",
81 | "Lambda",
82 | "LayerNorm",
83 | "Linear",
84 | "LSTM",
85 | "LSTMState",
86 | "max_pool",
87 | "module_and_value",
88 | "Module",
89 | "MultiHeadAttention",
90 | "nets",
91 | "next_rng_key",
92 | "ParameterModule",
93 | "parameters_method",
94 | "pure",
95 | "purecall",
96 | "RngSeq",
97 | "scan",
98 | "seed_rng_key",
99 | "select_parameters",
100 | "Sequential",
101 | "StateModule",
102 | "unfreeze_parameters",
103 | "unwrap_mp_policy",
104 | "update_parameters",
105 | "utils",
106 | "value_and_grad",
107 | "VanillaRNN",
108 | "VanillaRNNState",
109 | )
110 |
111 |
112 | try:
113 | del _src # pylint: disable=undefined-variable
114 | except NameError:
115 | pass
116 |
--------------------------------------------------------------------------------
/pax/_src/__init__.py:
--------------------------------------------------------------------------------
1 | ###
2 | ### Empty init
3 | ###
4 |
--------------------------------------------------------------------------------
/pax/_src/core/__init__.py:
--------------------------------------------------------------------------------
1 | """PAX Module"""
2 |
3 | from .graph_module import GraphModule, InputNode, build_graph_module
4 | from .mixed_precision import apply_mp_policy, unwrap_mp_policy
5 | from .module import EmptyNode, Module, parameters_method
6 | from .module_and_value import module_and_value
7 | from .mutable import mutable
8 | from .pure import pure, purecall
9 | from .transforms import (
10 | enable_eval_mode,
11 | enable_train_mode,
12 | freeze_parameters,
13 | select_parameters,
14 | unfreeze_parameters,
15 | update_parameters,
16 | )
17 | from .utility_modules import Flattener, LazyModule, ParameterModule, StateModule
18 | from .utils import assert_structure_equal
19 |
--------------------------------------------------------------------------------
/pax/_src/core/base.py:
--------------------------------------------------------------------------------
1 | """PAX BaseModule."""
2 |
3 | # Note: This file is originated from
4 | # https://raw.githubusercontent.com/cgarciae/treex/32e4cce5ca0cc991cda8076903853621d0aa4ab9/treex/module.py
5 | # which is under MIT License.
6 |
7 | from typing import Any, List, Mapping, Optional, Tuple, TypeVar
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.tree_util
12 | import numpy as np
13 |
14 | T = TypeVar("T", bound="BaseModule")
15 | M = TypeVar("M")
16 |
17 |
18 | class BaseModule:
19 | """BaseModule manages all information related to the pytree.
20 |
21 | There are two important methods:
22 |
23 | - ``tree_flatten`` converts a module to ``(leaves, treedef)``
24 | - ``tree_unflatten`` restores the module.
25 |
26 | BaseModule maintains a ``pytree_attributes`` tuple that lists all subtree attribute names.
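   |
   |     A minimal round-trip sketch (using a concrete subclass such as ``pax.Linear``):
   |
   |     >>> fc = pax.Linear(2, 2)
   |     >>> leaves, treedef = jax.tree_util.tree_flatten(fc)
   |     >>> fc = jax.tree_util.tree_unflatten(treedef, leaves)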
27 | """
28 |
29 | _pytree_attributes: Tuple[str, ...] = ()
30 | _mixed_pytree_attributes: Optional[Tuple[str, ...]] = None
31 |
32 | @property
33 | def pytree_attributes(self):
34 | if self._mixed_pytree_attributes is not None:
35 | return self._pytree_attributes + self._mixed_pytree_attributes
36 | else:
37 | return self._pytree_attributes
38 |
39 | def find_and_register_pytree_attributes(self: T):
40 | """Find and register ndarrays and submodules."""
41 |         is_mod_or_node = lambda x: isinstance(x, (BaseModule, EmptyNode))
42 |         pytree_cls = (jnp.ndarray, np.ndarray, BaseModule, EmptyNode)
43 |         is_pytree = lambda x: isinstance(x, pytree_cls)
44 |
45 |         pytree_attributes = []
46 |         mixed_pytree_attributes = []
47 |         for name, value in self.__dict__.items():
48 |             leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_mod_or_node)
49 |             any_pytree = any(map(is_pytree, leaves))
50 | all_pytree = all(map(is_pytree, leaves))
51 | if any_pytree and all_pytree:
52 | pytree_attributes.append(name)
53 | elif any_pytree:
54 | mixed_pytree_attributes.append(name)
55 | self._pytree_attributes = tuple(pytree_attributes)
56 | if len(mixed_pytree_attributes) > 0:
57 | self._mixed_pytree_attributes = tuple(mixed_pytree_attributes)
58 | else:
59 | self._mixed_pytree_attributes = None
60 |
61 | def tree_flatten(self) -> Tuple[List[jnp.ndarray], Mapping[str, Any]]:
62 | """Convert a module to ``(children, treedef)``."""
63 | aux = dict(self.__dict__)
64 | children = [aux.pop(name) for name in self._pytree_attributes]
65 | if self._mixed_pytree_attributes is not None:
66 | is_module = lambda x: isinstance(x, BaseModule)
67 | array_mod_cls = (jnp.ndarray, np.ndarray, BaseModule)
68 | is_array_mod = lambda x: isinstance(x, array_mod_cls)
69 | for name in self._mixed_pytree_attributes:
70 | value = aux.pop(name)
71 | leaves, treedef = jax.tree_util.tree_flatten(value, is_leaf=is_module)
72 | leaves = (v if is_array_mod(v) else ValueNode(v) for v in leaves)
73 | value = jax.tree_util.tree_unflatten(treedef, leaves)
74 | children.append(value)
75 | return children, aux
76 |
77 | @classmethod
78 | def tree_unflatten(cls, aux, children):
79 | """Recreate a module from its ``(children, treedef)``."""
80 | module = object.__new__(cls)
81 | module_dict = module.__dict__
82 | module_dict.update(aux)
83 | module_dict.update(zip(module._pytree_attributes, children))
84 | if module._mixed_pytree_attributes is not None:
85 | L = len(module._pytree_attributes)
86 | is_leaf = lambda x: isinstance(x, (ValueNode, BaseModule))
87 | unwrap = lambda x: x.value if isinstance(x, ValueNode) else x
88 | for name, value in zip(module._mixed_pytree_attributes, children[L:]):
89 | module_dict[name] = jax.tree_util.tree_map(
90 | unwrap, value, is_leaf=is_leaf
91 | )
92 | return module
93 |
94 | def __init_subclass__(cls):
95 | """Any subclass of ``Module`` is also registered as pytree."""
96 | jax.tree_util.register_pytree_node_class(cls)
97 |
98 | def __eq__(self, o: object) -> bool:
99 | """Compare two modules."""
100 | if id(self) == id(o):
101 | return True
102 |
103 | if type(self) is not type(o):
104 | return False
105 |
106 | self_leaves, self_treedef = jax.tree_util.tree_flatten(self)
107 | o_leaves, o_treedef = jax.tree_util.tree_flatten(o)
108 |
109 | if len(self_leaves) != len(o_leaves):
110 | return False
111 |
112 | if self_treedef != o_treedef:
113 | return False
114 |
115 | leaves_equal = jax.tree_util.tree_map(
116 | lambda a, b: a is b, self_leaves, o_leaves
117 | )
118 | return all(leaves_equal)
119 |
120 | def __hash__(self) -> int:
121 | leaves, treedef = jax.tree_util.tree_flatten(self)
122 | leaves = jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), leaves)
123 | return hash((tuple(leaves), treedef))
124 |
125 |
126 | # Note: this class is inspired by treex's `Nothing` class.
127 | @jax.tree_util.register_pytree_node_class
128 | class EmptyNode:
129 | """Mark an uninitialized or deleted pytree node."""
130 |
131 | def tree_flatten(self):
132 | """Flatten empty node."""
133 | return (), None
134 |
135 | @classmethod
136 | def tree_unflatten(cls, aux, children):
137 | """Unflatten empty node."""
138 | del aux, children
139 | return EmptyNode()
140 |
141 | def __repr__(self) -> str:
142 | return "EmptyNode"
143 |
144 | def __eq__(self, o: object) -> bool:
145 | if isinstance(o, EmptyNode):
146 | return True
147 | return False
148 |
149 |
150 | @jax.tree_util.register_pytree_node_class
151 | class ValueNode:
152 | """We use this class to store a value in treedef."""
153 |
154 | def __init__(self, value):
155 | super().__init__()
156 | self.value = value
157 |
158 | def tree_flatten(self):
159 | return (), self.value
160 |
161 | @classmethod
162 | def tree_unflatten(cls, value, children):
163 | return ValueNode(value)
164 |
165 | def __repr__(self) -> str:
166 | return f"ValueNode({self.value})"
167 |
--------------------------------------------------------------------------------
/pax/_src/core/mixed_precision.py:
--------------------------------------------------------------------------------
1 | """Enforce mixed-precision policy."""
2 |
3 | import functools
4 | from typing import TypeVar
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import jmp
9 |
10 | from .module import Module
11 | from .safe_module import find_descriptor
12 |
13 | T = TypeVar("T", bound=Module)
14 |
15 |
16 | def _wrap_method(func):
17 |     """Wrap a class's method to enforce the mixed-precision policy."""
18 |
19 | @functools.wraps(func)
20 | def mp_method_wrapper(self, *args, **kwargs):
21 | """A mixed-precision method.
22 |
23 | - Convert all weights to compute dtype.
24 | - Cast all arguments to compute dtype.
25 | - Call the original method.
26 | - Convert all weights to param dtype.
27 | - Cast output to output dtype.
28 |
29 | We bypass PAX mutability checking to make mixed-precision
30 | policy transparent from the user's point of view.
31 | """
32 | original_values = {}
33 | casted_original = {}
34 | # pylint: disable=protected-access
35 |
36 | # convert weights to compute dtype
37 | for name in self.pytree_attributes:
38 | value = getattr(self, name)
39 | if not _has_module(value):
40 | casted_value = self._pax_mp_policy.cast_to_compute(value)
41 | self.__dict__[name] = casted_value
42 | original_values[name] = value
43 | casted_original[name] = casted_value
44 |
45 | # cast arguments to compute dtype
46 | args, kwargs = self._pax_mp_policy.cast_to_compute((args, kwargs))
47 | output = func.__get__(self, type(self))(*args, **kwargs) # type:ignore
48 |
49 | # convert weights to param dtype
50 | for name in self.pytree_attributes:
51 | value = getattr(self, name)
52 | if not _has_module(value):
53 | if value is not casted_original[name]: # modified
54 | casted_value = self._pax_mp_policy.cast_to_param(value)
55 | setattr(self, name, casted_value)
56 | else:
57 | # avoid casting operation
58 | self.__dict__[name] = original_values[name]
59 |
60 | # cast output to output dtype
61 | output = self._pax_mp_policy.cast_to_output(output)
62 | return output
63 |
64 | return mp_method_wrapper
65 |
66 |
67 | def _mp_repr(mp_policy):
68 | dtype_to_name = {
69 | jnp.bfloat16: "H",
70 | jnp.float16: "H",
71 | jnp.float32: "F",
72 | jnp.float64: "F",
73 | }
74 |
75 | return (
76 | dtype_to_name[mp_policy.param_dtype]
77 | + dtype_to_name[mp_policy.compute_dtype]
78 | + dtype_to_name[mp_policy.output_dtype]
79 | )
80 |
81 |
82 | def apply_mp_policy(module: T, mp_policy: jmp.Policy) -> T:
83 | """Create a mixed-precision module.
84 |
85 | Create a subclass on the fly to enforce the mixed-precision policy.
86 |
87 | >>> import jmp
88 | >>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
89 | >>> net = pax.Linear(3, 3)
90 | >>> net = pax.apply_mp_policy(net, mp_policy)
91 | >>> print(net.summary())
92 | Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
93 | """
94 |
95 | if hasattr(module, "_pax_mp_policy"):
96 | raise ValueError(
97 | "Cannot apply multiple mixed-precision policies on an object.\n"
98 | "Call `pax.unwrap_mp_policy(...)` to remove the policy first."
99 | )
100 |
101 | # pylint: disable=protected-access
102 | cls_name = module.__class__.__name__
103 | module_methods = dir(Module)
104 | base = module.__class__
105 |
106 | methods = {}
107 | for name in dir(base):
108 | if name != "__call__" and name.startswith("__"):
109 | continue
110 | if name == "__call__" or name not in module_methods:
111 | value = getattr(base, name)
112 | if callable(value):
113 | value = find_descriptor(base, name)
114 | if value is None:
115 | continue
116 | if isinstance(value, (staticmethod, classmethod)):
117 | methods[name] = value
118 | else:
119 | methods[name] = _wrap_method(value)
120 |
121 | def _repr(self, info=None):
122 | if info is None:
123 | info = {}
124 | info["mp_policy"] = _mp_repr(self._pax_mp_policy)
125 | return super(base, self)._repr(info) # type: ignore
126 |
127 | methods["_repr"] = _repr
128 |
129 | cls = type(cls_name, (base,), methods)
130 | obj = object.__new__(cls)
131 | obj.__dict__.update(module.__dict__)
132 | obj.__dict__["_pax_mp_policy"] = mp_policy
133 | for name in obj.pytree_attributes:
134 | value = getattr(obj, name)
135 | if not _has_module(value):
136 | obj.__dict__[name] = mp_policy.cast_to_param(obj.__dict__[name])
137 | return obj
138 |
139 |
140 | def unwrap_mp_policy(module: T) -> T:
141 | """Unwrap a mixed-precision module to recreate the original module.
142 |
143 | >>> import jmp
144 | >>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
145 | >>> net = pax.Linear(3, 3)
146 | >>> net = pax.apply_mp_policy(net, mp_policy)
147 | >>> print(net.summary())
148 | Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
149 | >>> net = pax.unwrap_mp_policy(net)
150 | >>> print(net.summary())
151 | Linear(in_dim=3, out_dim=3, with_bias=True)
152 | """
153 | if not hasattr(module, "_pax_mp_policy"):
154 | raise ValueError("Expected a mixed-precision module.")
155 |
156 | base = module.__class__.__base__
157 | original = object.__new__(base)
158 | original.__dict__.update(module.__dict__)
159 | del original.__dict__["_pax_mp_policy"]
160 | return original
161 |
162 |
163 | def _has_module(mod):
164 |     is_mod = lambda x: isinstance(x, Module)
165 | leaves, _ = jax.tree_util.tree_flatten(mod, is_leaf=is_mod)
166 | return any(map(is_mod, leaves))
167 |
--------------------------------------------------------------------------------
/pax/_src/core/module_and_value.py:
--------------------------------------------------------------------------------
1 | """PAX mechanisms to make PAX method pure."""
2 |
3 | from functools import partial
4 | from types import MethodType
5 | from typing import Callable, Tuple, TypeVar
6 |
7 | from .base import BaseModule
8 | from .pure import pure
9 |
10 | O = TypeVar("O")
11 | T = TypeVar("T", bound=BaseModule)
12 |
13 |
14 | def module_and_value(module_or_method: Callable[..., O]) -> Callable[..., Tuple[T, O]]:
15 | """Return a pure function that executes a module's method.
16 |
17 | This pure function also returns the updated input module in the output.
18 |
19 | Example:
20 |
21 | >>> net = pax.Linear(1, 1)
22 | >>> x = jnp.ones((32, 1))
23 | >>> net, y = pax.module_and_value(net)(x) # note: `net` is also returned.
24 |
25 |
26 | Arguments:
27 | module_or_method: Either a PAX module or a method of a PAX module.
28 |
29 | Returns:
30 | A pure function.
31 | """
32 | is_bound_method = True
33 | if isinstance(module_or_method, MethodType): # a method
34 | mod = module_or_method.__self__
35 | func = module_or_method.__func__
36 | elif isinstance(module_or_method, BaseModule): # a module
37 | mod = module_or_method
38 | assert hasattr(mod, "__call__"), "Expecting a callable module."
39 | func = module_or_method.__call__.__func__
40 | elif callable(module_or_method):
41 | is_bound_method = False
42 | func = module_or_method
43 | else:
44 | raise ValueError("Expecting a module or a module's method.")
45 |
46 | @pure
47 | def _run(mod, *args, **kwargs):
48 | assert isinstance(mod, BaseModule), "Expecting a PAX module."
49 | out = func(mod, *args, **kwargs)
50 | return mod, out
51 |
52 | if is_bound_method:
53 | return partial(_run, mod)
54 | else:
55 | return _run
56 |
--------------------------------------------------------------------------------
/pax/_src/core/mutable.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | from .module import Module
4 | from .pure import get_all_submodules
5 | from .threading_local import allow_mutation
6 |
7 |
8 | @contextmanager
9 | def mutable(module: Module):
10 |     """A context manager that yields a mutable copy of ``module`` for use inside the context.
11 |
12 | >>> net = pax.Linear(1, 2)
13 | >>> with pax.experimental.mutable(net) as net:
14 | ... net.bias = jnp.array(0.)
15 | >>> assert net.bias.item() == 0.
16 | """
17 |
18 | copy = module.copy()
19 | all_submodules = get_all_submodules(copy)
20 |
21 | with allow_mutation(all_submodules):
22 | try:
23 | yield copy
24 | finally:
25 | copy.find_and_register_pytree_attributes()
26 | copy.scan_bugs()
27 |
--------------------------------------------------------------------------------
/pax/_src/core/pure.py:
--------------------------------------------------------------------------------
1 | """PAX mechanisms to make PAX functions pure."""
2 |
3 | import functools
4 | from types import MethodType
5 | from typing import Any, Callable, Tuple, TypeVar
6 |
7 | import jax
8 |
9 | from .base import BaseModule
10 | from .threading_local import allow_mutation
11 |
12 | T = TypeVar("T")
13 | O = TypeVar("O")
14 |
15 |
16 | def pure(func: Callable):
17 | """Make a function pure by copying the inputs.
18 |
19 | Any modification on the copy will not affect the original inputs.
20 |
21 | **Note**: only functions that are wrapped by `pax.pure` are allowed to modify PAX's Modules.
22 |
23 | Example:
24 |
25 | >>> f = pax.Linear(3,3)
26 | >>> f.a_list = []
27 | Traceback (most recent call last):
28 | ...
29 | ValueError: Cannot modify a module in immutable mode.
30 | Please do this computation inside a function decorated by `pax.pure`.
31 | >>>
32 | >>> @pax.pure
33 | ... def add_list(m):
34 | ... m.a_list = []
35 | ... return m
36 | ...
37 | >>> f = add_list(f)
38 | >>> print(f.a_list)
39 | []
40 |
41 | Arguments:
42 | func: A function.
43 |
44 | Returns:
45 | A pure function.
46 | """
47 |
48 | @functools.wraps(func)
49 | def wrapper(*args, **kwargs):
50 | for m in _get_modules((func, args, kwargs)):
51 | m.scan_bugs()
52 |
53 | # support calling method
54 | if isinstance(func, MethodType):
55 | args = (func.__self__, *args)
56 | unbound_func = func.__func__
57 | # or calling a module
58 | elif isinstance(func, BaseModule) and callable(func):
59 | args = (func, *args)
60 | unbound_func = func.__call__.__func__
61 | elif callable(func):
62 | unbound_func = func
63 | else:
64 | raise ValueError("Not supported")
65 |
66 | args, kwargs = _copy((args, kwargs))
67 | modules = get_all_submodules((args, kwargs))
68 | with allow_mutation(modules):
69 | out = unbound_func(*args, **kwargs)
70 |
71 | for m in modules:
72 | m.find_and_register_pytree_attributes()
73 | m.scan_bugs()
74 | return out
75 |
76 | return wrapper
77 |
78 |
79 | @pure
80 | def purecall(module: Callable[..., O], *args, **kwargs) -> Tuple[Any, O]:
81 | """Call a module and return the updated module.
82 |
83 | A shortcut for `pax.pure(lambda f, x: [f, f(x)])`.
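   |
   |     Example:
   |
   |     >>> net = pax.Linear(3, 3)
   |     >>> x = jnp.ones((1, 3))
   |     >>> net, y = pax.purecall(net, x)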
84 | """
85 | assert isinstance(module, BaseModule)
86 | assert callable(module)
87 | return module, module(*args, **kwargs)
88 |
89 |
90 | def _get_modules(tree):
91 | "Return a list of modules in the pytree `tree`."
92 | modules = jax.tree_util.tree_flatten(
93 | tree, is_leaf=lambda x: isinstance(x, BaseModule)
94 | )[0]
95 | modules = [m for m in modules if isinstance(m, BaseModule)]
96 | return modules
97 |
98 |
99 | def get_all_submodules(value):
100 | submods = _get_modules(value)
101 | out = list(submods)
102 | for mod in submods:
103 | out.extend(get_all_submodules(mod.submodules()))
104 | return out
105 |
106 |
107 | def _copy(value: T) -> T:
108 | leaves, treedef = jax.tree_util.tree_flatten(value)
109 | return jax.tree_util.tree_unflatten(treedef, leaves)
110 |
--------------------------------------------------------------------------------
/pax/_src/core/rng.py:
--------------------------------------------------------------------------------
1 | """Random Number Generator."""
2 |
3 | from .threading_local import KeyArray, next_rng_key, seed_rng_key
4 |
5 | __all__ = (
6 | "KeyArray",
7 | "next_rng_key",
8 | "seed_rng_key",
9 | )
10 |
--------------------------------------------------------------------------------
/pax/_src/core/safe_module.py:
--------------------------------------------------------------------------------
1 | """Safeguards to prevent potential bugs."""
2 |
3 | import inspect
4 | from typing import Iterable, List, Type, TypeVar
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 | from .base import BaseModule
11 | from .threading_local import allow_mutation, is_mutable
12 |
13 | T = TypeVar("T")
14 |
15 |
16 | class SafeBaseModuleMetaclass(type):
17 | """Metaclass for `SafeBaseModule`."""
18 |
19 | def __call__(cls: Type[T], *args, **kwargs) -> T:
20 | module = cls.__new__(cls, *args, **kwargs) # type: ignore
21 |
22 | with allow_mutation(module):
23 | cls.__init__(module, *args, **kwargs)
24 | module.find_and_register_pytree_attributes()
25 |
26 | # scan module after initialization for potential bugs
27 | if hasattr(module, "__slots__"):
28 | raise ValueError("`__slots__` is not supported by PAX modules.")
29 | module._assert_not_shared_module()
30 | module._assert_not_shared_weight()
31 | module._scan_fields(module._class_fields())
32 | return module
33 |
34 |
35 | class SafeBaseModule(BaseModule, metaclass=SafeBaseModuleMetaclass):
36 |     """Add safeguards to BaseModule to prevent bugs."""
37 |
38 | def _class_fields(self):
39 | for name, value in inspect.getmembers(self):
40 | if name.startswith("__") or inspect.ismethod(value):
41 | continue
42 |
43 | if name in self.__dict__:
44 | continue
45 |
46 | if find_descriptor(self.__class__, name) is not None:
47 | # ignore descriptors
48 | continue
49 |
50 | yield name
51 |
52 | def _assert_mutability(self):
53 | if not is_mutable(self):
54 | raise ValueError(
55 | "Cannot modify a module in immutable mode.\n"
56 | "Please do this computation inside a function decorated by `pax.pure`."
57 | )
58 |
59 | def _assert_not_shared_module(self):
60 | """Shared module is not allowed."""
61 | shared_module = _find_shared_module(self)
62 | if shared_module is not None:
63 | raise ValueError(
64 | f"The module `{shared_module}` is shared between two nodes of the pytree.\n"
65 |                 f"This is not allowed in order to prevent potential silent bugs."
66 | )
67 |
68 | def _assert_not_shared_weight(self):
69 | """Shared weight is not allowed."""
70 | leaves = jax.tree_util.tree_leaves(self)
71 | leaf_ids = set()
72 | for leaf in leaves:
73 | if id(leaf) in leaf_ids:
74 | raise ValueError(
75 | f"Detected a shared ndarray. This is not allowed.\n"
76 | f"Shape={leaf.shape}\n"
77 | f"Dtype={leaf.dtype}\n"
78 | f"Value={leaf}",
79 | )
80 | leaf_ids.add(id(leaf))
81 |
82 | def _scan_fields(self, fields: Iterable[str]):
83 | """Scan fields for *potential* bugs."""
84 |
85 | for name in fields:
86 | if name in self.pytree_attributes:
87 | continue
88 |
89 | value = getattr(self, name)
90 | is_mod = lambda x: isinstance(x, BaseModule)
91 | is_ndarray = lambda x: isinstance(x, (jnp.ndarray, np.ndarray))
92 | mods, _ = jax.tree_util.tree_flatten(value, is_leaf=is_mod)
93 | leaves = jax.tree_util.tree_leaves(value)
94 | has_mods = any(map(is_mod, mods))
95 |             has_arrays = any(map(is_ndarray, leaves))
96 |
97 | if has_mods:
98 | raise ValueError(
99 | f"\n"
100 | f"Unregistered field `{self}.{name}`:\n"
101 | f" value={value}\n"
102 | f"contains a module leaf.\n"
103 | )
104 |
105 | if has_arrays:
106 | raise ValueError(
107 | f"\n"
108 | f"Unregistered field `{self}.{name}`:\n"
109 | f" value={value}\n"
110 | f"contains a ndarray leaf.\n"
111 | )
112 |
113 |
114 | def _find_shared_module(module: BaseModule):
115 | """Find shared module.
116 |
117 | - Return the first module that is shared by two nodes of the pytree.
118 | - Return `None` if there is no shared module.
119 | """
120 |
121 | def _get_all_modules(mod: BaseModule, lst: List):
122 | lst.append(mod)
123 | is_mod = lambda x: isinstance(x, BaseModule) and x is not mod
124 | submodules, _ = jax.tree_util.tree_flatten(mod, is_leaf=is_mod)
125 | submodules = (m for m in submodules if is_mod(m))
126 | for m in submodules:
127 | _get_all_modules(m, lst)
128 |
129 | mods = []
130 | _get_all_modules(module, mods)
131 | module_ids = set()
132 | for m in mods:
133 | if id(m) in module_ids:
134 | return m
135 | module_ids.add(id(m))
136 |
137 | return None
138 |
139 |
140 | # source: https://stackoverflow.com/a/21963090
141 | def find_descriptor(cls, attrname):
142 | """Find the descriptor of an attribute."""
143 |
144 | def hasspecialmethod(obj, name):
145 | return any(name in klass.__dict__ for klass in type(obj).__mro__)
146 |
147 | for klass in cls.__mro__:
148 | if attrname in klass.__dict__:
149 | descriptor = klass.__dict__[attrname]
150 | if not hasspecialmethod(descriptor, "__get__"):
151 | return None
152 | return descriptor
153 | return None
154 |
--------------------------------------------------------------------------------
/pax/_src/core/threading_local.py:
--------------------------------------------------------------------------------
1 | """
2 | Manage thread local states
3 | """
4 |
5 | import random
6 | import threading
7 | import weakref
8 | from contextlib import contextmanager
9 | from typing import Any, Optional, Tuple, Union
10 |
11 | import jax
12 | import jax.numpy as jnp
13 | import jax.tree_util
14 |
15 | KeyArray = Union[Any, jnp.ndarray]
16 |
17 |
18 | class PaxThreadingLocalState(threading.local):
19 | """Manage all thread local states used by PAX"""
20 |
21 | _mutable_module_ref_list: Tuple[weakref.ReferenceType, ...]
22 | _mutable_module_level: int
23 | _rng: Optional[random.Random]
24 |
25 | def __init__(self):
26 | super().__init__()
27 | self._mutable_module_ref_list = ()
28 | self._mutable_module_level = _jax_cur_level()
29 | self._rng = random.Random(42)
30 |
31 | def add_mutable_module(self, module):
32 | """add `module` to mutable list"""
33 | self._mutable_module_ref_list = (
34 | weakref.ref(module),
35 | *self._mutable_module_ref_list,
36 | )
37 |
38 | def is_mutable(self, module):
39 | """Is `module` mutable?"""
40 |
41 | # cannot modify a module whose level of abstraction
42 | # is lower than the current level
43 | if self._mutable_module_level < _jax_cur_level():
44 | return False
45 |
46 | for ref in self._mutable_module_ref_list:
47 | if module is ref():
48 | return True
49 |
50 | return False
51 |
52 | @contextmanager
53 | def allow_mutation(self, modules):
54 | """A context manager that turns on mutability."""
55 |
56 | if not isinstance(modules, (tuple, list)):
57 | modules = (modules,)
58 | modules = tuple(weakref.ref(mod) for mod in modules)
59 |
60 | prev = self._mutable_module_ref_list
61 | prev_abstraction_level = self._mutable_module_level
62 | try:
63 | self._mutable_module_ref_list = modules
64 | self._mutable_module_level = _jax_cur_level()
65 | yield
66 | finally:
67 | self._mutable_module_ref_list = prev
68 | self._mutable_module_level = prev_abstraction_level
69 |
70 | def seed_rng_key(self, seed: int) -> None:
71 | """Set ``self._rng = random.Random(seed)``.
72 |
73 | Arguments:
74 | seed: an integer seed.
75 | """
76 | assert isinstance(seed, int)
77 | self._rng = random.Random(seed)
78 |
79 | def next_rng_key(self) -> KeyArray:
80 | """Return a random rng key. Renew the global random state."""
81 | seed = self._rng.randint(1, 999999999)
82 | return jax.random.PRNGKey(seed)
83 |
84 | def get_rng_state(self):
85 | """Return internal random states."""
86 | return self._rng.getstate()
87 |
88 | def set_rng_state(self, state):
89 | """Set internal random states."""
90 | self._rng.setstate(state)
91 |
92 |
93 | def _jax_cur_level():
94 | """
95 | Return the level of current jax trace.
96 |
97 | If it is an eval_trace, return -1.
98 | """
99 | trace = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
100 | if trace.trace_type == jax.core.EvalTrace:
101 | return -1
102 | else:
103 | return trace.level
104 |
105 |
106 | PAX_STATE = PaxThreadingLocalState()
107 | add_mutable_module = PAX_STATE.add_mutable_module
108 | allow_mutation = PAX_STATE.allow_mutation
109 | get_rng_state = PAX_STATE.get_rng_state
110 | is_mutable = PAX_STATE.is_mutable
111 | next_rng_key = PAX_STATE.next_rng_key
112 | seed_rng_key = PAX_STATE.seed_rng_key
113 | set_rng_state = PAX_STATE.set_rng_state
114 |
--------------------------------------------------------------------------------
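The aliases at the end of this file form a small functional API. A minimal sketch of how they compose, assuming PAX is installed and using the internal import path shown above:

    from pax._src.core import threading_local as tls

    tls.seed_rng_key(42)           # reset the thread-local Python RNG
    state = tls.get_rng_state()    # snapshot the RNG state
    key_a = tls.next_rng_key()     # a fresh jax.random.PRNGKey
    tls.set_rng_state(state)       # restore the snapshot
    key_b = tls.next_rng_key()     # replays the same draw
    assert (key_a == key_b).all()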
/pax/_src/core/transforms.py:
--------------------------------------------------------------------------------
1 | """Transform a module to a new one."""
2 | from typing import Any, TypeVar
3 |
4 | import jax
5 |
6 | from .module import Module, parameters_method, update_pytree
7 |
8 | TreeDef = Any
9 |
10 | T = TypeVar("T", bound=Module)
11 | K = TypeVar("K", bound=Module)
12 | O = TypeVar("O", bound=Module)
13 |
14 |
15 | def enable_train_mode(mod: T) -> T:
16 | """Return a module in training mode."""
17 | return mod.train()
18 |
19 |
20 | def enable_eval_mode(mod: T) -> T:
21 | """Return a module in evaluation mode."""
22 | return mod.eval()
23 |
24 |
25 | def freeze_parameters(mod: T) -> T:
26 | """Return a copy module with all trainable parameters are converted to non-trainable states."""
27 |
28 | def _freeze_apply_fn(mod: T) -> T:
29 | return mod.replace_method(parameters=parameters_method())
30 |
31 | return mod.apply(_freeze_apply_fn)
32 |
33 |
34 | def unfreeze_parameters(mod: T, *, origin: T) -> T:
35 | """Return a copy module with all trainable parameters are converted to non-trainable states."""
36 | tree_def = jax.tree_util.tree_structure(origin)
37 | leaves = jax.tree_util.tree_leaves(mod)
38 | return jax.tree_util.tree_unflatten(tree_def, leaves)
39 |
40 |
41 | def select_parameters(mod: T) -> T:
42 | """Select `PARAMETER` leaves only."""
43 | return mod.parameters()
44 |
45 |
46 | def update_parameters(mod: T, *, params: T) -> T:
47 | """Return a module that uses trainable parameters in `params`."""
48 | return update_pytree(mod, other=params.parameters())
49 |
--------------------------------------------------------------------------------
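A quick sketch of the freeze/unfreeze pair above (a usage illustration only; `pax.Linear` is used as in the doctests elsewhere in this listing, and the transforms are imported from the internal module path for concreteness):

    import pax
    from pax._src.core.transforms import freeze_parameters, unfreeze_parameters

    net = pax.Linear(3, 3)
    frozen = freeze_parameters(net)                     # weights become non-trainable states
    restored = unfreeze_parameters(frozen, origin=net)  # trainable partition recovered from `net`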
/pax/_src/core/utility_modules.py:
--------------------------------------------------------------------------------
1 | """Utility Modules."""
2 |
3 |
4 | from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
5 |
6 | import jax
7 | import jax.numpy as jnp
8 |
9 | from .module import Module, parameters_method
10 |
11 | T = TypeVar("T", bound=Module)
12 | O = TypeVar("O")
13 |
14 |
15 | class ParameterModule(Module):
16 | """A PAX module that registers attributes as parameters by default."""
17 |
18 | def parameters(self):
19 | return self.apply_submodules(lambda x: x.parameters())
20 |
21 |
22 | class StateModule(Module):
23 | """A PAX module that registers attributes as states by default."""
24 |
25 | parameters = parameters_method()
26 |
27 |
28 | class LazyModule(Module):
29 | """A lazy module is a module that only creates submodules when needed.
30 |
31 |
32 | Example:
33 |
34 | >>> from dataclasses import dataclass
35 | >>> @dataclass
36 | ... class MLP(pax.experimental.LazyModule):
37 | ... features: list
38 | ...
39 | ... def __call__(self, x):
40 | ... sizes = zip(self.features[:-1], self.features[1:])
41 | ... for i, (in_dim, out_dim) in enumerate(sizes):
42 | ... fc = self.get_or_create(f"fc_{i}", lambda: pax.Linear(in_dim, out_dim))
43 | ... x = jax.nn.relu(fc(x))
44 | ... return x
45 | ...
46 | ...
47 | >>> mlp, _ = MLP([1, 2, 3]) % jnp.ones((1, 1))
48 | >>> print(mlp.summary())
49 | MLP(features=[1, 2, 3])
50 | ├── Linear(in_dim=1, out_dim=2, with_bias=True)
51 | └── Linear(in_dim=2, out_dim=3, with_bias=True)
52 | """
53 |
54 | def get_or_create(self, name, create_fn: Callable[[], T]) -> T:
55 | """Create and register a new attribute when it is not exist.
56 |
57 | Return the attribute.
58 | """
59 | if hasattr(self, name):
60 | value = getattr(self, name)
61 | else:
62 | assert callable(create_fn), "Expect a callable function"
63 | value = create_fn()
64 | setattr(self, name, value)
65 | return value
66 |
67 |
68 | class Lambda(Module):
69 | """Convert a function to a module.
70 |
71 | Example:
72 |
73 | >>> net = pax.Lambda(jax.nn.relu)
74 | >>> print(net.summary())
75 | x => relu(x)
76 | >>> y = net(jnp.array(-1))
77 | >>> y
78 | DeviceArray(0, dtype=int32, weak_type=True)
79 | """
80 |
81 | func: Callable
82 |
83 | def __init__(self, func: Callable, name: Optional[str] = None):
84 | super().__init__(name=name)
85 | self.func = func
86 |
87 | def __call__(self, *args, **kwargs):
88 | return self.func(*args, **kwargs)
89 |
90 | def __repr__(self) -> str:
91 | if self.name is not None:
92 | return super().__repr__()
93 | else:
94 | return f"{self.__class__.__qualname__}({self.func.__name__})"
95 |
96 | def summary(self, return_list: bool = False) -> Union[str, List[str]]:
97 | if self.name is not None:
98 | name = self.name
99 | elif isinstance(self.func, jax.custom_jvp) and hasattr(self.func, "fun"):
100 | if hasattr(self.func.fun, "__name__"):
101 | name = self.func.fun.__name__
102 | else:
103 | name = f"{self.func.fun}"
104 | elif hasattr(self.func, "__name__"):
105 | name = self.func.__name__
106 | else:
107 | name = f"{self.func}"
108 | output = f"x => {name}(x)"
109 | return [output] if return_list else output
110 |
111 |
112 | class Flattener(Module):
113 | """Flatten PAX modules for better performance.
114 |
115 | Example:
116 |
117 | >>> net = pax.Linear(3, 3)
118 | >>> opt = opax.adam(1e-3)(net.parameters())
119 | >>> flat_mods = pax.experimental.Flattener(model=net, optimizer=opt)
120 | >>> net, opt = flat_mods.model, flat_mods.optimizer
121 | >>> print(net.summary())
122 | Linear(in_dim=3, out_dim=3, with_bias=True)
123 | >>> print(opt.summary())
124 | chain..Chain
125 | ├── scale_by_adam..ScaleByAdam
126 | │ ├── Linear(in_dim=3, out_dim=3, with_bias=True)
127 | │ └── Linear(in_dim=3, out_dim=3, with_bias=True)
128 | └── scale..Scale
129 | """
130 |
131 | treedef_dict: Dict[str, Any]
132 | leaves_dict: Dict[str, Sequence[jnp.ndarray]]
133 |
134 | def __init__(self, **kwargs):
135 | """Create a new flattener."""
136 | super().__init__()
137 | self.treedef_dict = {}
138 | self.leaves_dict = {}
139 |
140 | for name, value in kwargs.items():
141 | leaves, treedef = jax.tree_util.tree_flatten(value)
142 | self.treedef_dict[name] = treedef
143 | self.leaves_dict[name] = leaves
144 |
145 | def __getattr__(self, name: str) -> Any:
146 | if name in self.treedef_dict:
147 | treedef = self.treedef_dict[name]
148 | leaves = self.leaves_dict[name]
149 | value = jax.tree_util.tree_unflatten(treedef, leaves)
150 | return value
151 | else:
152 | raise AttributeError()
153 |
154 | def update(self: T, **kwargs) -> T:
155 | """Update the flattener.
156 |
157 | Example:
158 |
159 | >>> net = pax.Linear(3, 3)
160 | >>> flats = pax.experimental.Flattener(net=net)
161 | >>> flats = flats.update(net=pax.Linear(4, 4))
162 | >>> print(flats.net.summary())
163 | Linear(in_dim=4, out_dim=4, with_bias=True)
164 | """
165 | new_self = self.copy()
166 | for name, value in kwargs.items():
167 | leaves, treedef = jax.tree_util.tree_flatten(value)
168 | new_self.treedef_dict[name] = treedef
169 | new_self.leaves_dict[name] = leaves
170 | return new_self
171 |
172 | def parameters(self: T) -> T:
173 | """Raise an error.
174 |
175 | Need to reconstruct the original module before getting parameters.
176 | """
177 |
178 | raise ValueError(
179 | "A flattener only stores ndarray leaves as non-trainable states.\n"
180 | "Reconstruct the original module before getting parameters."
181 | )
182 |
--------------------------------------------------------------------------------
/pax/_src/core/utils.py:
--------------------------------------------------------------------------------
1 | """Useful functions."""
2 |
3 | from typing import TypeVar
4 | from unittest import TestCase
5 |
6 | import jax
7 |
8 | from .module import Module
9 |
10 | T = TypeVar("T", bound=Module)
11 |
12 |
13 | def assert_structure_equal(tree_a: T, tree_b: T):
14 | """Assert that the two pytrees are structurally the same.
15 |
16 | Print out the difference.
17 | """
18 | if jax.tree_util.tree_structure(tree_a) == jax.tree_util.tree_structure(tree_b):
19 | return True
20 |
21 | def check(subtree_a, subtree_b):
22 | if isinstance(subtree_a, Module) and isinstance(subtree_b, Module):
23 | assert_structure_equal(subtree_a, subtree_b)
24 |
25 | has_error = False
26 | try:
27 | jax.tree_util.tree_map(
28 | check,
29 | tree_a,
30 | tree_b,
31 | is_leaf=lambda x: isinstance(x, Module)
32 | and x is not tree_a
33 | and x is not tree_b,
34 | )
35 | except ValueError:
36 | has_error = True
37 |
38 | if has_error:
39 | test_case = TestCase()
40 | test_case.maxDiff = None
41 | # do not compare weights
42 | tree_a_w_none_leaves = jax.tree_util.tree_map(lambda _: None, tree_a)
43 | tree_b_w_none_leaves = jax.tree_util.tree_map(lambda _: None, tree_b)
44 | test_case.assertDictEqual(
45 | vars(tree_a_w_none_leaves), vars(tree_b_w_none_leaves)
46 | )
47 |
48 | return has_error
49 |
--------------------------------------------------------------------------------
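A small sketch of the structural check above (assuming `pax.Linear` is available at the top level, as in the doctests elsewhere in this listing):

    import pax
    from pax._src.core.utils import assert_structure_equal

    a = pax.Linear(2, 4)
    b = pax.Linear(2, 4)
    assert_structure_equal(a, b)   # identical structure: returns without printing a diff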
/pax/_src/nets/__init__.py:
--------------------------------------------------------------------------------
1 | """Public nets."""
2 |
3 | from .resnet import (
4 | ResNet,
5 | ResNet18,
6 | ResNet34,
7 | ResNet50,
8 | ResNet101,
9 | ResNet152,
10 | ResNet200,
11 | )
12 | from .transformer import Transformer
13 |
--------------------------------------------------------------------------------
/pax/_src/nets/transformer.py:
--------------------------------------------------------------------------------
1 | """Transformer Decoder Stack."""
2 |
3 | from typing import Dict, Optional, Sequence
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from ..core import Module
10 | from ..nn import LayerNorm, Linear, MultiHeadAttention, RngSeq
11 | from ..nn.dropout import dropout
12 |
13 |
14 | class CausalSelfAttention(MultiHeadAttention):
15 | """Self attention with a causal mask applied."""
16 |
17 | def __call__(
18 | self,
19 | query: jnp.ndarray,
20 | key: Optional[jnp.ndarray] = None,
21 | value: Optional[jnp.ndarray] = None,
22 | mask: Optional[jnp.ndarray] = None,
23 | ) -> jnp.ndarray:
24 | key = key if key is not None else query
25 | value = value if value is not None else query
26 |
27 | seq_len = query.shape[1]
28 | causal_mask = np.tril(np.ones((seq_len, seq_len)))
29 | mask = mask * causal_mask if mask is not None else causal_mask
30 |
31 | return super().__call__(query, key, value, mask)
32 |
33 |
34 | class DenseBlock(Module):
35 | """A 2-layer MLP which widens then narrows the input."""
36 |
37 | def __init__(self, in_dim: int, init_scale: float, widening_factor: int = 4):
38 | super().__init__()
39 | self._init_scale = init_scale
40 | initializer = jax.nn.initializers.variance_scaling(
41 | self._init_scale, mode="fan_in", distribution="normal"
42 | )
43 | self._widening_factor = widening_factor
44 | self.fc1 = Linear(in_dim, in_dim * widening_factor, w_init=initializer)
45 | self.fc2 = Linear(in_dim * widening_factor, in_dim, w_init=initializer)
46 |
47 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
48 | x = self.fc1(x)
49 | x = jax.nn.gelu(x)
50 | return self.fc2(x)
51 |
52 |
53 | class Transformer(Module):
54 | """A transformer stack."""
55 |
56 | layers: Sequence[Dict[str, Module]]
57 |
58 | def __init__(self, dim: int, num_heads: int, num_layers: int, dropout_rate: float):
59 | super().__init__()
60 | assert dim % num_heads == 0
61 | self._num_layers = num_layers
62 | self._num_heads = num_heads
63 | self._dropout_rate = dropout_rate
64 |
65 | self.rng_seq = RngSeq()
66 |
67 | init_scale = 2.0 / self._num_layers
68 | layers = []
69 | for _ in range(num_layers):
70 | layers.append(
71 | {
72 | "attention": CausalSelfAttention(
73 | num_heads=self._num_heads,
74 | key_size=dim // num_heads,
75 | w_init_scale=init_scale,
76 | ),
77 | "attn_layer_norm": LayerNorm(dim, -1, True, True),
78 | "dense_layer_norm": LayerNorm(dim, -1, True, True),
79 | "dense_block": DenseBlock(dim, init_scale),
80 | }
81 | )
82 | self.layers = layers
83 | self.layer_norm_output = LayerNorm(dim, -1, True, True)
84 |
85 | def __call__(
86 | self, h: jnp.ndarray, mask: Optional[jnp.ndarray] = None
87 | ) -> jnp.ndarray:
88 | """Connects the transformer.
89 | Args:
90 | h: Inputs, [B, T, H].
91 | mask: Padding mask, [B, T].
92 | Note: dropout is applied only in training mode (``self.training`` is True).
93 | Returns:
94 | Array of shape [B, T, H].
95 | """
96 |
97 | dropout_rate = self._dropout_rate if self.training else 0.0
98 | if mask is not None:
99 | mask = mask[:, None, None, :]
100 |
101 | # Note: names chosen to approximately match those used in the GPT-2 code;
102 | # see https://github.com/openai/gpt-2/blob/master/src/model.py.
103 | rngs = self.rng_seq.next_rng_key(self._num_layers * 2)
104 | for i in range(self._num_layers):
105 | h_norm = self.layers[i]["attn_layer_norm"](h)
106 | h_attn = self.layers[i]["attention"](h_norm, mask=mask)
107 | h_attn = dropout(rngs[i * 2 + 0], dropout_rate, h_attn)
108 | h = h + h_attn
109 | h_norm = self.layers[i]["dense_layer_norm"](h)
110 | h_dense = self.layers[i]["dense_block"](h_norm)
111 | h_dense = dropout(rngs[i * 2 + 1], dropout_rate, h_dense)
112 | h = h + h_dense
113 | h = self.layer_norm_output(h)
114 |
115 | return h
116 |
--------------------------------------------------------------------------------
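A hedged usage sketch for the decoder stack above. It relies on the `%` pure-call pattern from the LazyModule doctest earlier in this listing, assumed here to apply to any PAX module (the forward pass advances the internal `RngSeq` state, so a pure call is needed):

    import jax.numpy as jnp
    from pax.nets import Transformer

    net = Transformer(dim=32, num_heads=4, num_layers=2, dropout_rate=0.1)
    h = jnp.zeros((1, 8, 32))   # [B, T, H]
    net, y = net % h            # returns the updated module (new RNG state) and the output
    # y.shape == (1, 8, 32)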
/pax/_src/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules."""
2 |
3 | from .attention import MultiHeadAttention
4 | from .batch_norm import BatchNorm1D, BatchNorm2D
5 | from .conv import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose
6 | from .dropout import Dropout
7 | from .ema import EMA
8 | from .embed import Embed
9 | from .group_norm import GroupNorm
10 | from .identity import Identity
11 | from .lambda_module import Lambda
12 | from .layer_norm import LayerNorm
13 | from .linear import Linear
14 | from .pool import avg_pool, max_pool
15 | from .recurrent import GRU, LSTM, GRUState, LSTMState, VanillaRNN, VanillaRNNState
16 | from .rng_seq import RngSeq
17 | from .sequential import Sequential
18 |
--------------------------------------------------------------------------------
/pax/_src/nn/attention.py:
--------------------------------------------------------------------------------
1 | """Transformer self-attention module."""
2 |
3 | # Source: https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py
4 | from typing import Optional
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 | from ..core import Module
11 | from .linear import Linear
12 |
13 |
14 | class MultiHeadAttention(Module):
15 | """Multi-headed attention mechanism.
16 | As described in the vanilla Transformer paper:
17 | "Attention is all you need" https://arxiv.org/abs/1706.03762
18 | """
19 |
20 | num_heads: int
21 | key_size: int
22 | value_size: int
23 | model_size: int
24 |
25 | def __init__(
26 | self,
27 | num_heads: int,
28 | key_size: int,
29 | w_init_scale: float,
30 | ):
31 | super().__init__()
32 | self.num_heads = num_heads
33 | self.key_size = key_size
34 | self.value_size = key_size
35 | self.model_size = key_size * num_heads
36 | w_init = jax.nn.initializers.variance_scaling(
37 | w_init_scale, mode="fan_in", distribution="normal"
38 | )
39 | self.query_projection = Linear(
40 | self.model_size, self.model_size, w_init=w_init, name="qry_proj"
41 | )
42 | self.key_projection = Linear(
43 | self.model_size, self.model_size, w_init=w_init, name="key_proj"
44 | )
45 | self.value_projection = Linear(
46 | self.model_size, self.model_size, w_init=w_init, name="val_proj"
47 | )
48 | self.output_projection = Linear(
49 | self.model_size, self.model_size, w_init=w_init, name="out_proj"
50 | )
51 |
52 | def __call__(
53 | self,
54 | query: jnp.ndarray,
55 | key: jnp.ndarray,
56 | value: jnp.ndarray,
57 | mask: Optional[jnp.ndarray] = None,
58 | ) -> jnp.ndarray:
59 | """Compute (optionally masked) MHA with queries, keys & values."""
60 |
61 | query_heads = self.query_projection(query)
62 | key_heads = self.key_projection(key)
63 | value_heads = self.value_projection(value)
64 | (query_heads, key_heads, value_heads) = jax.tree_util.tree_map(
65 | lambda x, y: x.reshape(*y.shape[:-1], self.num_heads, self.key_size),
66 | (query_heads, key_heads, value_heads),
67 | (query, key, value),
68 | )
69 |
70 | attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
71 | sqrt_key_size = np.sqrt(self.key_size).astype(key.dtype)
72 | attn_logits = attn_logits / sqrt_key_size
73 | if mask is not None:
74 | # assert mask.shape == attn_logits.shape
75 | attn_logits = jnp.where(mask, attn_logits, -1e30)
76 |
77 | attn_weights = jax.nn.softmax(attn_logits)
78 | attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
79 | # Concatenate attention matrix of all heads into a single vector.
80 | attn_vec = jnp.reshape(attn, (*query.shape[:-1], -1))
81 | return self.output_projection(attn_vec)
82 |
83 | def __repr__(self, info=None) -> str:
84 | info = {"num_heads": self.num_heads, "key_size": self.key_size}
85 | return self._repr(info)
86 |
--------------------------------------------------------------------------------
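A minimal self-attention sketch for the module above; shapes are chosen for illustration, and the call itself is stateless, so it can be invoked directly:

    import jax.numpy as jnp
    from pax._src.nn.attention import MultiHeadAttention

    mha = MultiHeadAttention(num_heads=4, key_size=16, w_init_scale=1.0)  # model_size = 64
    x = jnp.zeros((2, 10, 64))   # [batch, time, model_size]
    y = mha(x, x, x)             # self-attention; y.shape == (2, 10, 64)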
/pax/_src/nn/batch_norm.py:
--------------------------------------------------------------------------------
1 | """BatchNorm modules."""
2 |
3 | from typing import Optional, Sequence
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from ..core import Module, parameters_method
9 | from .ema import EMA
10 |
11 |
12 | class BatchNorm(Module):
13 | """A Generic BatchNorm Module.
14 |
15 | Normalize a mini-batch of data by subtracting its mean and dividing by its standard deviation.
16 |
17 | Use EMA modules to track the averaged mean and averaged variance for later uses in `eval` mode.
18 | """
19 |
20 | scale: Optional[jnp.ndarray]
21 | offset: Optional[jnp.ndarray]
22 |
23 | parameters = parameters_method("scale", "offset")
24 |
25 | ema_mean: EMA
26 | ema_var: EMA
27 |
28 | reduced_axes: Sequence[int]
29 | create_offset: bool
30 | create_scale: bool
31 | eps: float
32 | data_format: Optional[str]
33 |
34 | def __init__(
35 | self,
36 | num_channels: int,
37 | create_scale: bool = True,
38 | create_offset: bool = True,
39 | decay_rate: float = 0.9,
40 | eps: float = 1e-5,
41 | data_format: Optional[str] = None,
42 | reduced_axes=None,
43 | param_shape=None,
44 | *,
45 | name: Optional[str] = None,
46 | ):
47 | """Create a new BatchNorm module.
48 |
49 | Arguments:
50 | num_channels: the number of filters.
51 | create_scale: create a trainable `scale` parameter.
52 | create_offset: create a trainable `offset` parameter.
53 | decay_rate: the decay rate for tracking the averaged mean and the averaged variance.
54 | eps: a small positive number to avoid divided by zero.
55 | data_format: the data format ["NHWC", "NCHW", "NWC", "NCW"].
56 | reduced_axes: list of axes that will be reduced in the `jnp.mean` computation.
57 | param_shape: the shape of parameters.
58 | """
59 | super().__init__(name=name)
60 | assert 0 <= decay_rate <= 1
61 |
62 | self.num_channels = num_channels
63 | self.data_format = data_format
64 | self.create_scale = create_scale
65 | self.create_offset = create_offset
66 | self.eps = eps
67 | self.decay_rate = decay_rate
68 |
69 | self.reduced_axes = tuple(reduced_axes)
70 |
71 | if create_scale:
72 | self.scale = jnp.ones(param_shape, dtype=jnp.float32)
73 | else:
74 | self.scale = None
75 | if create_offset:
76 | self.offset = jnp.zeros(param_shape, dtype=jnp.float32)
77 | else:
78 | self.offset = None
79 |
80 | # initial values do not matter because debias=True
81 | initial_mean = jnp.zeros(param_shape, dtype=jnp.float32)
82 | self.ema_mean = EMA(initial_mean, decay_rate, debias=True)
83 | initial_var = jnp.ones(param_shape, dtype=jnp.float32)
84 | self.ema_var = EMA(initial_var, decay_rate, debias=True)
85 |
86 | def __call__(self, x):
87 | if self.training:
88 | batch_mean = jnp.mean(x, axis=self.reduced_axes, keepdims=True)
89 | batch_mean_of_squares = jnp.mean(
90 | jnp.square(x), axis=self.reduced_axes, keepdims=True
91 | )
92 | batch_var = batch_mean_of_squares - jnp.square(batch_mean)
93 | self.ema_mean(batch_mean)
94 | self.ema_var(batch_var)
95 | else:
96 | batch_mean = self.ema_mean.averages
97 | batch_var = self.ema_var.averages
98 |
99 | if self.create_scale:
100 | scale = self.scale
101 | else:
102 | scale = 1.0
103 |
104 | if self.create_offset:
105 | offset = self.offset
106 | else:
107 | offset = 0.0
108 |
109 | inv = scale * jax.lax.rsqrt(batch_var + self.eps)
110 | x = (x - batch_mean) * inv + offset
111 | return x
112 |
113 | def __repr__(self):
114 | info = {
115 | "num_channels": self.num_channels,
116 | "create_scale": self.create_scale,
117 | "create_offset": self.create_offset,
118 | "data_format": self.data_format,
119 | "decay_rate": self.decay_rate,
120 | }
121 | return self._repr(info)
122 |
123 | def summary(self, return_list: bool = False):
124 | lines = super().summary(return_list=True)
125 | if return_list:
126 | return lines[:1]
127 | else:
128 | return lines[0]
129 |
130 |
131 | class BatchNorm1D(BatchNorm):
132 | """The 1D version of BatchNorm."""
133 |
134 | def __init__(
135 | self,
136 | num_channels: int,
137 | create_scale: bool = True,
138 | create_offset: bool = True,
139 | decay_rate: float = 0.9,
140 | eps: float = 1e-5,
141 | data_format: str = "NWC",
142 | *,
143 | name: Optional[str] = None,
144 | ):
145 | assert data_format in ["NWC", "NCW"], "expecting a correct `data_format`"
146 |
147 | param_shape = [1, 1, 1]
148 | if data_format == "NWC":
149 | axis = -1
150 | reduced_axes = [0, 1]
151 | else:
152 | axis = 1
153 | reduced_axes = [0, 2]
154 | param_shape[axis] = num_channels
155 |
156 | super().__init__(
157 | num_channels=num_channels,
158 | create_scale=create_scale,
159 | create_offset=create_offset,
160 | decay_rate=decay_rate,
161 | eps=eps,
162 | data_format=data_format,
163 | param_shape=param_shape,
164 | reduced_axes=reduced_axes,
165 | name=name,
166 | )
167 |
168 |
169 | class BatchNorm2D(BatchNorm):
170 | """The 2D version of BatchNorm."""
171 |
172 | def __init__(
173 | self,
174 | num_channels: int,
175 | create_scale: bool = True,
176 | create_offset: bool = True,
177 | decay_rate: float = 0.9,
178 | eps: float = 1e-5,
179 | data_format: str = "NHWC",
180 | *,
181 | name: Optional[str] = None,
182 | ):
183 | assert data_format in ["NHWC", "NCHW"], "expecting a correct `data_format`"
184 |
185 | param_shape = [1, 1, 1, 1]
186 | if data_format == "NHWC":
187 | axis = -1
188 | reduced_axes = [0, 1, 2]
189 | else:
190 | axis = 1
191 | reduced_axes = [0, 2, 3]
192 | param_shape[axis] = num_channels
193 |
194 | super().__init__(
195 | num_channels=num_channels,
196 | create_scale=create_scale,
197 | create_offset=create_offset,
198 | decay_rate=decay_rate,
199 | eps=eps,
200 | data_format=data_format,
201 | param_shape=param_shape,
202 | reduced_axes=reduced_axes,
203 | name=name,
204 | )
205 |
--------------------------------------------------------------------------------
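A hedged sketch of `BatchNorm2D` in train vs. eval mode. The `%` pure-call pattern from the LazyModule doctest is assumed, since the training branch updates the EMA state:

    import jax.numpy as jnp
    from pax._src.nn.batch_norm import BatchNorm2D

    bn = BatchNorm2D(num_channels=3, data_format="NHWC")
    x = jnp.ones((8, 16, 16, 3))
    bn, y = bn % x          # training mode (default): ema_mean / ema_var are updated
    y_eval = bn.eval()(x)   # eval mode reads the tracked statistics, no state update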
/pax/_src/nn/dropout.py:
--------------------------------------------------------------------------------
1 | """Dropout module."""
2 |
3 | from typing import Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from ..core import StateModule
9 | from ..core.rng import KeyArray, next_rng_key
10 |
11 |
12 | def dropout(rng_key: KeyArray, dropout_rate: float, x: jnp.ndarray) -> jnp.ndarray:
13 | """Dropout input `x` randomly.
14 |
15 | Scaling the input by ``1 / (1-dropout_rate)`` makes ``E[output] = input``.
16 | """
17 | assert 0 <= dropout_rate < 1.0
18 |
19 | if dropout_rate == 0.0:
20 | return x
21 | else:
22 | mask = jax.random.bernoulli(rng_key, dropout_rate, shape=x.shape)
23 | x = jnp.where(mask, 0.0, x / (1.0 - dropout_rate))
24 | return x
25 |
26 |
27 | class Dropout(StateModule):
28 | """A Dropout Module.
29 |
30 | Dropout module stores an internal state ``rng_key``.
31 | It refreshes ``rng_key`` whenever a forward pass is executed.
32 | """
33 |
34 | rng_key: KeyArray
35 | dropout_rate: float
36 |
37 | def __init__(self, dropout_rate: float, *, name: Optional[str] = None):
38 | """Create a dropout module.
39 |
40 | Arguments:
41 | dropout_rate: the probability of dropping an element.
42 | name: the module name.
43 | """
44 | super().__init__(name=name)
45 | assert 0 <= dropout_rate < 1.0
46 |
47 | self.dropout_rate = dropout_rate
48 | self.rng_key = next_rng_key()
49 |
50 | def __call__(self, x):
51 | """Dropout `x` randomly.
52 |
53 | Return the input `x` if in `eval` mode or `dropout_rate=0`.
54 | """
55 |
56 | if self.training and self.dropout_rate > 0:
57 | self.rng_key, rng_key = jax.random.split(self.rng_key)
58 | return dropout(rng_key, self.dropout_rate, x)
59 | else:
60 | return x
61 |
62 | def __repr__(self):
63 | return super()._repr({"dropout_rate": self.dropout_rate})
64 |
--------------------------------------------------------------------------------
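The functional form above is pure, so it can be exercised directly; a tiny sketch:

    import jax
    import jax.numpy as jnp
    from pax._src.nn.dropout import dropout

    key = jax.random.PRNGKey(0)
    x = jnp.ones((4, 4))
    y = dropout(key, 0.5, x)   # each entry dropped with p=0.5, survivors scaled by 2.0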
/pax/_src/nn/ema.py:
--------------------------------------------------------------------------------
1 | """EMA module."""
2 |
3 | from typing import Any, Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from ..core import StateModule
9 |
10 |
11 | def _has_integer_leaves(x):
12 | """check if there is any interger/bool leaves"""
13 | leaves = jax.tree_util.tree_leaves(x)
14 | return not all(jnp.issubdtype(leaf.dtype, jnp.floating) for leaf in leaves)
15 |
16 |
17 | class EMA(StateModule):
18 | """Exponential Moving Average (EMA) Module"""
19 |
20 | averages: Any
21 | decay_rate: float
22 | debias: Optional[jnp.ndarray]
23 | allow_int: bool
24 |
25 | def __init__(
26 | self,
27 | initial_value,
28 | decay_rate: float,
29 | debias: bool = False,
30 | allow_int: bool = False,
31 | ):
32 | """Create a new EMA module.
33 |
34 | If allow_int=True, integer leaves are updated to
35 | the newest values instead of averaging.
36 |
37 | Arguments:
38 | initial_value: the initial value.
39 | decay_rate: the decay rate.
40 | debias: ignore the initial value to avoid biased estimates.
41 | allow_int: allow integer values.
42 | """
43 | if not allow_int:
44 | if _has_integer_leaves(initial_value):
45 | raise ValueError(
46 | "There are integer arrays in the initial value.\n"
47 | "Use `allow_int=True` to allow this."
48 | )
49 |
50 | super().__init__()
51 | self.averages = initial_value
52 | self.decay_rate = decay_rate
53 | self.allow_int = allow_int
54 | if debias:
55 | # avoid integer ndarray for `jax.grad` convenience,
56 | # e.g., no need to pass `allow_int=True` to `jax.grad`.
57 | self.debias = jnp.array(0.0)
58 | else:
59 | self.debias = None
60 |
61 | def __call__(self, xs):
62 | """Return the ema of `xs`. Also, update internal states."""
63 | if not self.allow_int:
64 | if _has_integer_leaves(xs):
65 | raise ValueError(
66 | "There are integer arrays in the new value.\n"
67 | "Use `allow_int=True` to allow this."
68 | )
69 |
70 | if self.training:
71 | if self.debias is not None:
72 | cond = self.debias > 0
73 | debias_func = lambda a, x: jnp.where(cond, a, x)
74 | self.debias = jnp.array(1.0)
75 | else:
76 | debias_func = lambda a, _: a
77 |
78 | def update_fn(a, x):
79 | if jnp.issubdtype(a.dtype, jnp.floating):
80 | a = debias_func(a, x)
81 | return a * self.decay_rate + x * (1 - self.decay_rate)
82 | else:
83 | return x
84 |
85 | self.averages = jax.tree_util.tree_map(update_fn, self.averages, xs)
86 |
87 | return self.averages
88 |
--------------------------------------------------------------------------------
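A short EMA sketch; the `%` pure-call pattern is assumed because calling the module updates its `averages` state:

    import jax.numpy as jnp
    from pax._src.nn.ema import EMA

    ema = EMA(jnp.array(0.0), decay_rate=0.9)
    ema, avg = ema % jnp.array(1.0)   # avg ≈ 0.1: the average moves toward the new value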
/pax/_src/nn/embed.py:
--------------------------------------------------------------------------------
1 | """Embed module."""
2 |
3 | from typing import Callable, Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from ..core import ParameterModule
9 | from ..core.rng import KeyArray, next_rng_key
10 |
11 |
12 | class Embed(ParameterModule):
13 | """Embed module maps integer values to real vectors.
14 | The embedded vectors are trainable.
15 | """
16 |
17 | weight: jnp.ndarray
18 | vocab_size: int
19 | embed_dim: int
20 |
21 | def __init__(
22 | self,
23 | vocab_size: int,
24 | embed_dim: int,
25 | w_init: Optional[Callable] = None,
26 | *,
27 | rng_key: Optional[KeyArray] = None,
28 | name: Optional[str] = None
29 | ):
30 | """
31 | An embed module.
32 |
33 | Arguments:
34 | vocab_size: the number of embedded vectors.
35 | embed_dim: the size of embedded vectors.
36 | w_init: weight initializer. Default: ``jax.nn.initializers.normal()``.
37 | name: module name.
38 | """
39 |
40 | super().__init__(name=name)
41 |
42 | self.vocab_size = vocab_size
43 | self.embed_dim = embed_dim
44 | shape = [vocab_size, embed_dim]
45 |
46 | if w_init is None:
47 | w_init = jax.nn.initializers.normal()
48 |
49 | if rng_key is None:
50 | rng_key = next_rng_key()
51 |
52 | self.weight = w_init(rng_key, shape)
53 |
54 | def __call__(self, x: jnp.ndarray):
55 | """Return embedded vectors indexed by ``x``."""
56 | return self.weight[(x,)]
57 |
58 | def __repr__(self):
59 | info = {"vocab_size": self.vocab_size, "embed_dim": self.embed_dim}
60 | return self._repr(info)
61 |
--------------------------------------------------------------------------------
/pax/_src/nn/identity.py:
--------------------------------------------------------------------------------
1 | """Identity module."""
2 |
3 | from ..core import Module
4 |
5 |
6 | class Identity(Module):
7 | """Identity function as a module."""
8 |
9 | def __call__(self, x):
10 | """return x"""
11 | return x
12 |
--------------------------------------------------------------------------------
/pax/_src/nn/lambda_module.py:
--------------------------------------------------------------------------------
1 | """Lambda module."""
2 |
3 | from ..core.utility_modules import Lambda
4 |
5 | __all__ = ("Lambda",)
6 |
--------------------------------------------------------------------------------
/pax/_src/nn/layer_norm.py:
--------------------------------------------------------------------------------
1 | """LayerNorm Module."""
2 |
3 | # The implementation is almost identical to dm-haiku LayerNorm at:
4 | # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/layer_norm.py
5 | # deepmind/dm-haiku is licensed under the Apache License 2.0
6 | #
7 | # Differences:
8 | # 1. We need to input ``num_channels``, the size of the last dimension,
9 | # to initialize scale/offset parameters.
10 | # 2. We can input `rng_key` to seed the value of scale/offset parameters.
11 |
12 | import collections
13 | from typing import Callable, Optional, Sequence, Union
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import numpy as np
18 |
19 | from ..core import ParameterModule
20 | from ..core.rng import KeyArray, next_rng_key
21 |
22 |
23 | class LayerNorm(ParameterModule):
24 | """LayerNorm module.
25 | See: https://arxiv.org/abs/1607.06450.
26 | """
27 |
28 | scale: Optional[jnp.ndarray]
29 | offset: Optional[jnp.ndarray]
30 |
31 | def __init__(
32 | self,
33 | num_channels: int,
34 | axis: Union[int, Sequence[int], slice],
35 | create_scale: bool,
36 | create_offset: bool,
37 | eps: float = 1e-5,
38 | scale_init: Optional[Callable] = None,
39 | offset_init: Optional[Callable] = None,
40 | *,
41 | rng_key: Optional[KeyArray] = None,
42 | name: Optional[str] = None,
43 | ):
44 |
45 | """Constructs a LayerNorm module.
46 |
47 | Arguments:
48 | num_channels: Integer, size of the last dimension. The data format is ``[N, ..., C]``.
49 | axis: Integer, list of integers, or slice indicating which axes to normalize over.
50 | create_scale: Bool, defines whether to create a trainable scale
51 | per channel applied after the normalization.
52 | create_offset: Bool, defines whether to create a trainable offset
53 | per channel applied after normalization and scaling.
54 | eps: Small epsilon to avoid division by zero variance.
55 | Defaults ``1e-5``, as in the paper and Sonnet.
56 | scale_init: Optional initializer for gain (aka scale). By default, one.
57 | offset_init: Optional initializer for bias (aka offset). By default, zero.
58 | rng_key: RNG key.
59 | name: module name.
60 | """
61 | super().__init__(name=name)
62 | if not create_scale and scale_init is not None:
63 | raise ValueError("Cannot set `scale_init` if `create_scale=False`.")
64 | if not create_offset and offset_init is not None:
65 | raise ValueError("Cannot set `offset_init` if `create_offset=False`.")
66 |
67 | if isinstance(axis, slice):
68 | self.axis = axis
69 | elif isinstance(axis, int):
70 | self.axis = (axis,)
71 | elif isinstance(axis, collections.abc.Iterable) and all(
72 | isinstance(ax, int) for ax in axis
73 | ):
74 | self.axis = tuple(axis)
75 | else:
76 | raise ValueError("`axis` should be an int, slice or iterable of ints.")
77 |
78 | self.eps = eps
79 | self.create_scale = create_scale
80 | self.create_offset = create_offset
81 | self.scale_init = scale_init or jax.nn.initializers.ones
82 | self.offset_init = offset_init or jax.nn.initializers.zeros
83 | self.num_channels = num_channels
84 |
85 | param_shape = [num_channels]
86 | rng_key = next_rng_key() if rng_key is None else rng_key
87 | rng1, rng2 = jax.random.split(rng_key)
88 | if create_scale:
89 | self.scale = self.scale_init(rng1, param_shape)
90 | else:
91 | self.scale = None
92 | if create_offset:
93 | self.offset = self.offset_init(rng2, param_shape)
94 | else:
95 | self.offset = None
96 |
97 | def __call__(
98 | self,
99 | inputs: jnp.ndarray,
100 | scale: Optional[jnp.ndarray] = None,
101 | offset: Optional[jnp.ndarray] = None,
102 | ) -> jnp.ndarray:
103 | """Returns normalized inputs.
104 |
105 | Arguments:
106 | inputs: An array, where the data format is ``[N, ..., C]``.
107 | scale: An array up to n-D. The shape of this tensor must be broadcastable
108 | to the shape of ``inputs``. This is the scale applied to the normalized
109 | inputs. This cannot be passed in if the module was constructed with
110 | ``create_scale=True``.
111 | offset: An array up to n-D. The shape of this tensor must be broadcastable
112 | to the shape of ``inputs``. This is the offset applied to the normalized
113 | inputs. This cannot be passed in if the module was constructed with
114 | ``create_offset=True``.
115 |
116 | Returns:
117 | The array, normalized.
118 | """
119 | if self.create_scale and scale is not None:
120 | raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.")
121 | if self.create_offset and offset is not None:
122 | raise ValueError(
123 | "Cannot pass `offset` at call time if `create_offset=True`."
124 | )
125 |
126 | axis = self.axis
127 | if isinstance(axis, slice):
128 | axis = tuple(range(inputs.ndim)[axis])
129 |
130 | mean = jnp.mean(inputs, axis=axis, keepdims=True)
131 | variance = jnp.var(inputs, axis=axis, keepdims=True)
132 |
133 | # param_shape = inputs.shape[-1:]
134 | if self.create_scale:
135 | scale = self.scale
136 | elif scale is None:
137 | scale = np.array(1.0, dtype=inputs.dtype)
138 |
139 | if self.create_offset:
140 | offset = self.offset
141 | elif offset is None:
142 | offset = np.array(0.0, dtype=inputs.dtype)
143 |
144 | scale = jnp.broadcast_to(scale, inputs.shape)
145 | offset = jnp.broadcast_to(offset, inputs.shape)
146 | mean = jnp.broadcast_to(mean, inputs.shape)
147 |
148 | eps = jax.lax.convert_element_type(self.eps, variance.dtype)
149 | inv = scale * jax.lax.rsqrt(variance + eps)
150 | return inv * (inputs - mean) + offset
151 |
152 | def __repr__(self, info=None) -> str:
153 | info = {
154 | "num_channels": self.num_channels,
155 | "axis": self.axis,
156 | "create_scale": self.create_scale,
157 | "create_offset": self.create_offset,
158 | }
159 | return self._repr(info)
160 |
--------------------------------------------------------------------------------
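A minimal sketch of the LayerNorm above; the call is stateless, and an explicit `rng_key` is passed for deterministic initialization:

    import jax
    import jax.numpy as jnp
    from pax._src.nn.layer_norm import LayerNorm

    ln = LayerNorm(8, axis=-1, create_scale=True, create_offset=True,
                   rng_key=jax.random.PRNGKey(0))
    x = jnp.arange(16.0).reshape(2, 8)
    y = ln(x)   # each row normalized to zero mean / unit variance, then scaled and shifted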
/pax/_src/nn/linear.py:
--------------------------------------------------------------------------------
1 | """Linear module."""
2 |
3 | from typing import Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from ..core import ParameterModule
10 | from ..core.rng import KeyArray, next_rng_key
11 |
12 |
13 | class Linear(ParameterModule):
14 | """A linear transformation is applied over the last dimension of the input."""
15 |
16 | weight: jnp.ndarray
17 | bias: jnp.ndarray
18 |
19 | in_dim: int
20 | out_dim: int
21 | with_bias: bool
22 |
23 | def __init__(
24 | self,
25 | in_dim: int,
26 | out_dim: int,
27 | with_bias: bool = True,
28 | w_init=None,
29 | b_init=None,
30 | *,
31 | rng_key: KeyArray = None,
32 | name: Optional[str] = None,
33 | ):
34 | """
35 | Arguments:
36 | in_dim: the number of input features.
37 | out_dim: the number of output features.
38 | with_bias: whether to add a bias to the output (default: True).
39 | w_init: initializer function for the weight matrix.
40 | b_init: initializer function for the bias.
41 | rng_key: the key to generate initial parameters.
42 | name: module name.
43 | """
44 | super().__init__(name=name)
45 | self.in_dim = in_dim
46 | self.out_dim = out_dim
47 | self.with_bias = with_bias
48 |
49 | rng_key = next_rng_key() if rng_key is None else rng_key
50 | if w_init is None:
51 | w_init = jax.nn.initializers.normal(stddev=1.0 / np.sqrt(self.in_dim))
52 | if b_init is None:
53 | b_init = jax.nn.initializers.normal(stddev=1.0 / np.sqrt(self.in_dim))
54 | rng_key_w, rng_key_b = jax.random.split(rng_key)
55 | self.weight = w_init(rng_key_w, (in_dim, out_dim))
56 | if self.with_bias:
57 | self.bias = b_init(rng_key_b, (out_dim,))
58 |
59 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
60 | """Applies a linear transformation to the inputs along the last dimension.
61 |
62 | Arguments:
63 | x: The nd-array to be transformed.
64 | """
65 | assert len(x.shape) >= 2, "expecting an input of shape `N...C`"
66 | x = jnp.dot(x, self.weight)
67 | if self.with_bias:
68 | x = x + self.bias
69 | return x
70 |
71 | def __repr__(self):
72 | info = {
73 | "in_dim": self.in_dim,
74 | "out_dim": self.out_dim,
75 | "with_bias": self.with_bias,
76 | }
77 | return self._repr(info)
78 |
--------------------------------------------------------------------------------
/pax/_src/nn/pool.py:
--------------------------------------------------------------------------------
1 | # Source: https://raw.githubusercontent.com/deepmind/dm-haiku/main/haiku/_src/pool.py
2 | #
3 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Pooling Haiku modules."""
18 |
19 | import warnings
20 | from typing import Optional, Sequence, Tuple, Union
21 |
22 | import jax.numpy as jnp
23 | import numpy as np
24 | from jax import lax
25 |
26 |
27 | def _infer_shape(
28 | x: jnp.ndarray,
29 | size: Union[int, Sequence[int]],
30 | channel_axis: Optional[int] = -1,
31 | ) -> Tuple[int, ...]:
32 | """Infer shape for pooling window or strides."""
33 | if isinstance(size, int):
34 | if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
35 | raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
36 | if channel_axis and channel_axis < 0:
37 | channel_axis = x.ndim + channel_axis
38 | return (1,) + tuple(size if d != channel_axis else 1 for d in range(1, x.ndim))
39 | elif len(size) < x.ndim:
40 | # Assume additional dimensions are batch dimensions.
41 | return (1,) * (x.ndim - len(size)) + tuple(size)
42 | else:
43 | assert x.ndim == len(size)
44 | return tuple(size)
45 |
46 |
47 | _VMAP_SHAPE_INFERENCE_WARNING = (
48 | "When running under vmap, passing an `int` (except for `1`) for "
49 | "`window_shape` or `strides` will result in the wrong shape being inferred "
50 | "because the batch dimension is not visible to Haiku. Please update your "
51 | "code to specify a full unbatched size. "
52 | ""
53 | "For example if you had `pool(x, window_shape=3, strides=1)` before, you "
54 | "should now pass `pool(x, window_shape=(3, 3, 1), strides=1)`. "
55 | ""
56 | "Haiku will assume that any additional dimensions in your input are "
57 | "batch dimensions, and will pad `window_shape` and `strides` accordingly "
58 | "making your module support both batched and per-example inputs."
59 | )
60 |
61 |
62 | def _warn_if_unsafe(window_shape, strides):
63 | unsafe = lambda size: isinstance(size, int) and size != 1
64 | if unsafe(window_shape) or unsafe(strides):
65 | warnings.warn(_VMAP_SHAPE_INFERENCE_WARNING, DeprecationWarning)
66 |
67 |
68 | def max_pool(
69 | value: jnp.ndarray,
70 | window_shape: Union[int, Sequence[int]],
71 | strides: Union[int, Sequence[int]],
72 | padding: str,
73 | channel_axis: Optional[int] = -1,
74 | ) -> jnp.ndarray:
75 | """Max pool.
76 |
77 | Args:
78 | value: Value to pool.
79 | window_shape: Shape of the pooling window, an int or same rank as value.
80 | strides: Strides of the pooling window, an int or same rank as value.
81 | padding: Padding algorithm. Either ``VALID`` or ``SAME``.
82 | channel_axis: Axis of the spatial channels for which pooling is skipped,
83 | used to infer ``window_shape`` or ``strides`` if they are an integer.
84 |
85 | Returns:
86 | Pooled result. Same rank as value.
87 | """
88 | if padding not in ("SAME", "VALID"):
89 | raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
90 |
91 | _warn_if_unsafe(window_shape, strides)
92 | window_shape = _infer_shape(value, window_shape, channel_axis)
93 | strides = _infer_shape(value, strides, channel_axis)
94 |
95 | return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides, padding)
96 |
97 |
98 | def avg_pool(
99 | value: jnp.ndarray,
100 | window_shape: Union[int, Sequence[int]],
101 | strides: Union[int, Sequence[int]],
102 | padding: str,
103 | channel_axis: Optional[int] = -1,
104 | ) -> jnp.ndarray:
105 | """Average pool.
106 |
107 | Args:
108 | value: Value to pool.
109 | window_shape: Shape of the pooling window, an int or same rank as value.
110 | strides: Strides of the pooling window, an int or same rank as value.
111 | padding: Padding algorithm. Either ``VALID`` or ``SAME``.
112 | channel_axis: Axis of the spatial channels for which pooling is skipped,
113 | used to infer ``window_shape`` or ``strides`` if they are an integer.
114 |
115 | Returns:
116 | Pooled result. Same rank as value.
117 |
118 | Raises:
119 | ValueError: If the padding is not valid.
120 | """
121 | if padding not in ("SAME", "VALID"):
122 | raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
123 |
124 | _warn_if_unsafe(window_shape, strides)
125 | window_shape = _infer_shape(value, window_shape, channel_axis)
126 | strides = _infer_shape(value, strides, channel_axis)
127 |
128 | reduce_window_args = (0.0, lax.add, window_shape, strides, padding)
129 | pooled = lax.reduce_window(value, *reduce_window_args)
130 | if padding == "VALID":
131 | # Avoid the extra reduce_window.
132 | return pooled / np.prod(window_shape)
133 | else:
134 | # Count the number of valid entries at each input point, then use that for
135 | # computing average. Assumes that any two arrays of same shape will be
136 | # padded the same.
137 | window_counts = lax.reduce_window(jnp.ones_like(value), *reduce_window_args)
138 | assert pooled.shape == window_counts.shape
139 | return pooled / window_counts
140 |
--------------------------------------------------------------------------------
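A quick sketch of the pooling helpers above, using a full-rank window to avoid the vmap shape-inference warning:

    import jax.numpy as jnp
    from pax._src.nn.pool import max_pool

    x = jnp.ones((1, 8, 8, 3))   # NHWC
    y = max_pool(x, window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding="VALID")
    # y.shape == (1, 4, 4, 3)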
/pax/_src/nn/recurrent.py:
--------------------------------------------------------------------------------
1 | """Recurrent modules."""
2 |
3 | from typing import Callable, NamedTuple, Optional, Tuple
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from ..core import Module
9 | from ..core.rng import KeyArray, next_rng_key
10 | from .linear import Linear
11 |
12 |
13 | class LSTMState(NamedTuple):
14 | """LSTMState."""
15 |
16 | hidden: jnp.ndarray
17 | cell: jnp.ndarray
18 |
19 |
20 | class GRUState(NamedTuple):
21 | """GRUState."""
22 |
23 | hidden: jnp.ndarray
24 |
25 |
26 | class VanillaRNNState(NamedTuple):
27 | """VanillaRNNState."""
28 |
29 | hidden: jnp.ndarray
30 |
31 |
32 | class RNN(Module):
33 | """Base class for all recurrent modules."""
34 |
35 | def __init__(self, name: Optional[str] = None):
36 | super().__init__(name=name)
37 |
38 | def initial_state(self, batch_size):
39 | raise NotImplementedError()
40 |
41 |
42 | class VanillaRNN(RNN):
43 | """Basic recurrent neural network."""
44 |
45 | input_dim: int
46 | hidden_dim: int
47 | fc: Linear
48 |
49 | def __init__(
50 | self,
51 | input_dim: int,
52 | hidden_dim: int,
53 | *,
54 | rng_key: KeyArray = None,
55 | name: Optional[str] = None
56 | ):
57 | """Create a vanilla RNN module.
58 |
59 | Arguments:
60 | input_dim: input dimension.
61 | hidden_dim: hidden dimension.
62 | rng_key: random key.
63 | name: module name.
64 | """
65 | super().__init__(name=name)
66 | self.input_dim = input_dim
67 | self.hidden_dim = hidden_dim
68 | self.fc = Linear(
69 | input_dim + hidden_dim,
70 | hidden_dim,
71 | rng_key=rng_key,
72 | name="vanilla_rnn_fc",
73 | )
74 |
75 | def __call__(
76 | self, state: VanillaRNNState, x: jnp.ndarray
77 | ) -> Tuple[VanillaRNNState, jnp.ndarray]:
78 | """A single rnn step."""
79 | xh = jnp.concatenate((x, state.hidden), axis=-1)
80 | hidden = jnp.tanh(self.fc(xh))
81 | return VanillaRNNState(hidden), hidden
82 |
83 | def __repr__(self):
84 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim}
85 | return self._repr(info)
86 |
87 | def initial_state(self, batch_size) -> VanillaRNNState:
88 | shape = (batch_size, self.hidden_dim)
89 | hidden = jnp.zeros(shape=shape, dtype=jnp.float32)
90 | return VanillaRNNState(hidden=hidden)
91 |
92 |
93 | class LSTM(RNN):
94 | """Long Short Term Memory (LSTM) RNN module."""
95 |
96 | input_dim: int
97 | hidden_dim: int
98 |
99 | weight: jnp.ndarray
100 | bias: jnp.ndarray
101 |
102 | def __init__(
103 | self,
104 | input_dim: int,
105 | hidden_dim: int,
106 | w_init: Optional[Callable] = None,
107 | forget_gate_bias: float = 0.0,
108 | *,
109 | rng_key: KeyArray = None,
110 | name: Optional[str] = None
111 | ):
112 | """Create a LSTM module.
113 |
114 | Arguments:
115 | input_dim: The input dimension.
116 | hidden_dim: The number of LSTM cells.
117 | w_init: weight initializer.
118 | forget_gate_bias: additive bias for the forget-gate logits; positive values bias the gate toward keeping the cell state. Default `0`.
119 | rng_key: random key.
120 | name: module name.
121 | """
122 |
123 | super().__init__(name=name)
124 | self.input_dim = input_dim
125 | self.hidden_dim = hidden_dim
126 | self.forget_gate_bias = forget_gate_bias
127 |
128 | self.fc = Linear(
129 | (input_dim + hidden_dim),
130 | 4 * hidden_dim,
131 | rng_key=rng_key,
132 | name="lstm_fc",
133 | w_init=w_init,
134 | )
135 |
136 | def __call__(
137 | self,
138 | state: LSTMState,
139 | x: jnp.ndarray,
140 | ) -> Tuple[LSTMState, jnp.ndarray]:
141 | """Do a single lstm step.
142 |
143 |
144 | Arguments:
145 | state: The current LSTM state.
146 | x: The input.
147 | """
148 | xh = jnp.concatenate((x, state.hidden), axis=-1)
149 | gated = self.fc(xh)
150 | i, g, f, o = jnp.split(gated, 4, axis=-1)
151 | f = jax.nn.sigmoid(f + self.forget_gate_bias)
152 | c = f * state.cell + jax.nn.sigmoid(i) * jnp.tanh(g)
153 | h = jax.nn.sigmoid(o) * jnp.tanh(c)
154 | return LSTMState(h, c), h
155 |
156 | def __repr__(self):
157 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim}
158 | return self._repr(info)
159 |
160 | def initial_state(self, batch_size) -> LSTMState:
161 | shape = (batch_size, self.hidden_dim)
162 | hidden = jnp.zeros(shape=shape, dtype=jnp.float32)
163 | cell = jnp.zeros(shape=shape, dtype=jnp.float32)
164 | return LSTMState(hidden=hidden, cell=cell)
165 |
166 |
167 | class GRU(RNN):
168 | """This class implements the "fully gated unit" GRU.
169 |
170 | Reference: https://en.wikipedia.org/wiki/Gated_recurrent_unit
171 | """
172 |
173 | input_dim: int
174 | hidden_dim: int
175 |
176 | def __init__(
177 | self,
178 | input_dim: int,
179 | hidden_dim: int,
180 | *,
181 | rng_key: Optional[KeyArray] = None,
182 | name: Optional[str] = None
183 | ):
184 | """Create a GRU module.
185 |
186 | Arguments:
187 | input_dim: the input size.
188 | hidden_dim: the number of GRU cells.
189 | """
190 | super().__init__(name=name)
191 |
192 | self.input_dim = input_dim
193 | self.hidden_dim = hidden_dim
194 |
195 | if rng_key is None:
196 | rng_key = next_rng_key()
197 | rng_key_1, rng_key_2 = jax.random.split(rng_key, 2)
198 | self.xh_zr_fc = Linear(
199 | (input_dim + hidden_dim), hidden_dim * 2, name="xh_to_zr", rng_key=rng_key_1
200 | )
201 |
202 | self.xh_h_fc = Linear(
203 | (input_dim + hidden_dim), hidden_dim, name="xh_to_h", rng_key=rng_key_2
204 | )
205 |
206 | def initial_state(self, batch_size: int) -> GRUState:
207 | """Create an all zeros initial state."""
208 | return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))
209 |
210 | def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
211 | """Do a single gru step.
212 |
213 | Arguments:
214 | state: The current GRU state.
215 | x: The input.
216 | """
217 | hidden = state.hidden
218 | xh = jnp.concatenate((x, hidden), axis=-1)
219 | zr = jax.nn.sigmoid(self.xh_zr_fc(xh))
220 | z, r = jnp.split(zr, 2, axis=-1)
221 |
222 | xrh = jnp.concatenate((x, r * hidden), axis=-1)
223 | h_hat = jnp.tanh(self.xh_h_fc(xrh))
224 | h = (1 - z) * hidden + z * h_hat
225 | return GRUState(h), h
226 |
227 | def __repr__(self):
228 | info = {"input_dim": self.input_dim, "hidden_dim": self.hidden_dim}
229 | return self._repr(info)
230 |
--------------------------------------------------------------------------------
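A single-step sketch of the LSTM above; the step function does not mutate the module, so it can be called directly (and rolled out with `jax.lax.scan` if desired):

    import jax
    import jax.numpy as jnp
    from pax._src.nn.recurrent import LSTM

    lstm = LSTM(input_dim=3, hidden_dim=5, rng_key=jax.random.PRNGKey(0))
    state = lstm.initial_state(batch_size=2)
    state, out = lstm(state, jnp.ones((2, 3)))   # out.shape == (2, 5)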
/pax/_src/nn/rng_seq.py:
--------------------------------------------------------------------------------
1 | """RngSeq module."""
2 |
3 | from typing import Optional, Sequence, Union
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from ..core import StateModule, rng
10 |
11 |
12 | class RngSeq(StateModule):
13 | """A module which generates an infinite sequence of rng keys."""
14 |
15 | _rng_key: rng.KeyArray
16 |
17 | def __init__(
18 | self, seed: Optional[int] = None, rng_key: Optional[rng.KeyArray] = None
19 | ):
20 | """Initialize a random key sequence.
21 |
22 | **Note**: ``rng_key`` has a higher priority than ``seed``.
23 |
24 | Arguments:
25 | seed: an integer seed.
26 | rng_key: a jax random key.
27 | """
28 | super().__init__()
29 | if rng_key is not None:
30 | rng_key_ = rng_key
31 | elif seed is not None:
32 | rng_key_ = jax.random.PRNGKey(seed)
33 | else:
34 | rng_key_ = rng.next_rng_key()
35 |
36 | if isinstance(rng_key_, (np.ndarray, jnp.ndarray)):
37 | self._rng_key = rng_key_
38 | else:
39 | raise ValueError("Impossible")
40 |
41 | def next_rng_key(
42 | self, num_keys: int = 1
43 | ) -> Union[rng.KeyArray, Sequence[rng.KeyArray]]:
44 | """Return the next random key of the sequence.
45 |
46 | **Note**:
47 |
48 | * Return a key if ``num_keys`` is ``1``,
49 | * Return a list of keys if ``num_keys`` is greater than ``1``.
50 | * The sequence is deterministic only if the same ``num_keys`` values are used across calls; mixing different values changes the subsequent keys.
51 |
52 | Arguments:
53 | num_keys: return more than one key.
54 | """
55 | self._rng_key, *rng_keys = jax.random.split(self._rng_key, num_keys + 1)
56 | return rng_keys[0] if num_keys == 1 else rng_keys
57 |
--------------------------------------------------------------------------------
/pax/_src/nn/sequential.py:
--------------------------------------------------------------------------------
1 | """Sequential module."""
2 |
3 | from typing import Optional, Tuple, TypeVar
4 |
5 | from ..core import Module
6 | from .lambda_module import Lambda
7 |
8 | T = TypeVar("T", bound=Module)
9 |
10 |
11 | class Sequential(Module):
12 | """Execute layers in order.
13 |
14 | Support pax.Module (callable pytree) and any jax functions.
15 |
16 | For example:
17 |
18 | >>> net = pax.Sequential(
19 | ... pax.Linear(2, 32),
20 | ... jax.nn.relu,
21 | ... pax.Linear(32, 3)
22 | ... )
23 | >>> print(net.summary())
24 | Sequential
25 | ├── Linear(in_dim=2, out_dim=32, with_bias=True)
26 | ├── x => relu(x)
27 | └── Linear(in_dim=32, out_dim=3, with_bias=True)
28 | >>> x = jnp.empty((3, 2))
29 | >>> y = net(x)
30 | >>> y.shape
31 | (3, 3)
32 | """
33 |
34 | # Note: we cannot mix pax.Module and jax functions (e.g., jax.nn.relu) in the same list.
35 | # therefore, we have to convert a jax function to ``Lambda`` module first.
36 | modules: Tuple[Module, ...]
37 |
38 | def __init__(self, *layers, name: Optional[str] = None):
39 | """Create a Sequential module."""
40 | super().__init__(name=name)
41 | self.modules = tuple(
42 | (f if isinstance(f, Module) else Lambda(f)) for f in layers
43 | )
44 |
45 | def __call__(self, x):
46 | """Call layers in order."""
47 | for f in self.modules:
48 | x = f(x)
49 | return x
50 |
51 | def __getitem__(self, index: int) -> T:
52 | """Get an item from the `modules` list."""
53 | return self.modules[index]
54 |
55 | def set(self: T, index: int, value) -> T:
56 |         """Replace the item at ``index`` in the `modules` list, returning a new module."""
57 | if not isinstance(value, Module):
58 | value = Lambda(value)
59 |
60 | modules = list(self.modules)
61 | modules[index] = value
62 | return super().replace(modules=tuple(modules))
63 |
64 | def __rshift__(self, other: Module):
65 | return Sequential(*self.modules, other)
66 |
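
A short sketch of the two helpers the docstring above does not demonstrate (against the public API only; layer sizes are arbitrary): ``set`` returns a new ``Sequential`` with one layer replaced, and ``>>`` appends another layer or function.

    import jax
    import jax.numpy as jnp
    import pax

    net = pax.Sequential(pax.Linear(2, 32), jax.nn.relu, pax.Linear(32, 3))
    net2 = net.set(2, pax.Linear(32, 4))   # replace the last layer; returns a new module
    net3 = net2 >> jax.nn.softmax          # append a function; it is wrapped in a Lambda
    y = net3(jnp.empty((3, 2)))            # y.shape == (3, 4)
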
--------------------------------------------------------------------------------
/pax/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | """Experimental API"""
2 |
3 |
4 | from pax._src.core import Flattener, LazyModule, mutable
5 | from pax._src.utils import (
6 | apply_scaled_gradients,
7 | default_mp_policy,
8 | load_weights_from_dict,
9 | save_weights_to_dict,
10 | )
11 |
12 | from . import graph
13 |
14 | __all__ = (
15 | "apply_scaled_gradients",
16 | "default_mp_policy",
17 | "Flattener",
18 | "graph",
19 | "LazyModule",
20 | "load_weights_from_dict",
21 | "mutable",
22 | "save_weights_to_dict",
23 | )
24 |
--------------------------------------------------------------------------------
/pax/experimental/graph.py:
--------------------------------------------------------------------------------
1 | """Experimental graph API"""
2 |
3 |
4 | from pax._src.core.graph_module import GraphModule, InputNode, Node, build_graph_module
5 |
6 | __all__ = (
7 | "build_graph_module",
8 | "GraphModule",
9 | "InputNode",
10 | "Node",
11 | )
12 |
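
A brief sketch of how these pieces combine (mirroring tests/test_graph_module.py later in this repository): ``InputNode`` wraps an example input, ``>>`` chains modules and functions, ``|`` groups parents for a function of several arguments, and ``build_graph_module`` turns the graph-building function into a ``GraphModule``.

    import jax
    import jax.numpy as jnp
    import pax
    from pax.experimental.graph import build_graph_module

    def residual(x):
        y = x >> pax.Linear(3, 3) >> jax.nn.relu
        t = x >> pax.Linear(3, 3) >> jax.nn.tanh
        return (y | t) >> jax.lax.add     # add the two branches

    net = build_graph_module(residual)(jnp.empty((1, 3)))
    y = net(jnp.empty((1, 3)))            # shape (1, 3)
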
--------------------------------------------------------------------------------
/pax/nets.py:
--------------------------------------------------------------------------------
1 | """Public nets."""
2 |
3 | from pax._src.nets import (
4 | ResNet18,
5 | ResNet34,
6 | ResNet50,
7 | ResNet101,
8 | ResNet152,
9 | ResNet200,
10 | Transformer,
11 | )
12 |
13 | __all__ = (
14 | "ResNet18",
15 | "ResNet34",
16 | "ResNet50",
17 | "ResNet101",
18 | "ResNet152",
19 | "ResNet200",
20 | "Transformer",
21 | )
22 |
--------------------------------------------------------------------------------
/pax/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTT123/pax/13916cb86ede38c56750cf1bde3ac37c63674014/pax/py.typed
--------------------------------------------------------------------------------
/pax/utils.py:
--------------------------------------------------------------------------------
1 | """Public utility functions."""
2 |
3 | from pax._src.utils import build_update_fn, grad, scan
4 |
5 | __all__ = (
6 | "build_update_fn",
7 | "grad",
8 | "scan",
9 | )
10 |
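
A sketch of the typical ``build_update_fn`` training step, following the pattern used in tests/test_training.py (the ``opax`` optimizer library is an assumption carried over from those tests): the loss function returns ``(loss, (loss, model))`` and the built update function returns the updated model, optimizer, and loss.

    import jax.numpy as jnp
    import opax
    import pax

    def loss_fn(model: pax.Linear, x, y):
        y_hat = model(x)
        loss = jnp.mean(jnp.square(y - y_hat))
        return loss, (loss, model)

    update_fn = pax.utils.build_update_fn(loss_fn)
    net = pax.Linear(1, 1)
    optimizer = opax.adamw(1e-1)(net.parameters())
    x = jnp.zeros((32, 1))
    y = x * 2.5 - 3.1
    net, optimizer, loss = update_fn(net, optimizer, x, y)
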
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup PAX3 package."""
2 |
3 | from setuptools import find_namespace_packages, setup
4 |
5 |
6 | def _get_version():
7 | with open("pax/__init__.py", encoding="utf-8") as file:
8 | for line in file:
9 | if line.startswith("__version__"):
10 | _globals = {}
11 | exec(line, _globals) # pylint: disable=exec-used
12 | return _globals["__version__"]
13 | raise ValueError("`__version__` not defined in `pax/__init__.py`")
14 |
15 |
16 | __version__ = _get_version()
17 | URL = "https://github.com/ntt123/pax"
18 |
19 | install_requires = ["jax>=0.2.21", "jmp>=0.0.2"]
20 | setup_requires = []
21 | tests_requires = [
22 | "chex",
23 | "dm-haiku",
24 | "fire",
25 | "opax",
26 | "pytest",
27 | "pytype",
28 | "tqdm",
29 | ]
30 |
31 | setup(
32 | name="pax3",
33 | version=__version__,
34 | description="A stateful pytree library for training neural networks.",
35 | long_description=open("README.md", encoding="utf-8").read(),
36 | long_description_content_type="text/markdown",
37 | author="Thông Nguyễn",
38 | url=URL,
39 | keywords=[
40 | "deep-learning",
41 | "jax",
42 | ],
43 | install_requires=install_requires,
44 | setup_requires=setup_requires,
45 | tests_require=tests_requires,
46 | packages=find_namespace_packages(exclude=["examples", "tests"]),
47 | extras_require={"test": tests_requires},
48 | python_requires=">=3.7",
49 | include_package_data=True,
50 | zip_safe=False,
51 | )
52 |
--------------------------------------------------------------------------------
/tests/test_auto_modules.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pax
4 |
5 | # import pytest
6 |
7 |
8 | def test_scan_bug_param_module():
9 | class M(pax.ParameterModule):
10 | def __init__(self):
11 | super().__init__()
12 | self.a = jnp.array(0.0)
13 |
14 | # with pytest.raises(ValueError):
15 | _ = M()
16 |
17 |
18 | def test_scan_bug_state_module():
19 | class M(pax.StateModule):
20 | def __init__(self):
21 | super().__init__()
22 | self.a = jnp.array(0.0)
23 |
24 | # with pytest.raises(ValueError):
25 | _ = M()
26 |
27 |
28 | def test_auto_module():
29 | class M(pax.experimental.LazyModule):
30 | def __call__(self, x):
31 | x = self.get_or_create("fc", lambda: pax.Linear(1, 1))(x)
32 | x = jax.nn.relu(x)
33 | return x
34 |
35 | m = M()
36 | x = jnp.ones((2, 1))
37 | m, _ = pax.module_and_value(m)(x)
38 | print(m.summary())
39 |
--------------------------------------------------------------------------------
/tests/test_counter.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pax
4 |
5 |
6 | def test_counter():
7 | class Counter(pax.Module):
8 | counter: jnp.ndarray
9 | bias: jnp.ndarray
10 | parameters = pax.parameters_method("counter")
11 |
12 | def __init__(self, start_value: int = 0):
13 | super().__init__()
14 |
15 | self.counter = jnp.array(start_value, dtype=jnp.int32)
16 | self.bias = jnp.array(0.0)
17 |
18 | def __call__(self, x):
19 | self.counter = self.counter + 1
20 | return self.counter * x + self.bias
21 |
22 | @pax.pure
23 | def loss_fn(model: Counter, x: jnp.ndarray):
24 | y = model(x)
25 | loss = jnp.mean(jnp.square(x - y))
26 | return loss, (loss, model)
27 |
28 | grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True)
29 |
30 | net = Counter(3)
31 | x = jnp.array(10.0)
32 | grads, (loss, net) = grad_fn(net, x)
33 | assert grads.counter.dtype is jax.float0
34 | assert grads.bias.item() == 60.0
35 |
--------------------------------------------------------------------------------
/tests/test_deepscan.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | import pax
4 | import pytest
5 |
6 |
7 | def test_list_of_mod():
8 | class M(pax.Module):
9 | def __init__(self):
10 | super().__init__()
11 | self.a = [pax.Linear(3, 3)]
12 |
13 | m = M()
14 | # m.pax.name_to_kind["a"] == pax.PaxKind.MODULE
15 |
16 |
17 | def test_assigned_field_an_array():
18 | class M(pax.ParameterModule):
19 | def __init__(self):
20 | super().__init__()
21 | self.a = np.array([3.0, 1.0], dtype=np.float32)
22 |
23 | # no error because we will automatically assign `a` to kind PARAMETER
24 | m = M()
25 | # assert m.pax.name_to_kind["a"] == pax.PaxKind.PARAMETER
26 |
27 | class N(pax.Module):
28 | def __init__(self):
29 | super().__init__()
30 |
31 | n = N()
32 |
33 | n.scan_bugs()
34 |     # no error because we will automatically assign `b` to kind PARAMETER
35 | def mutate(n: N) -> N:
36 | n.b = jnp.array([1, 2, 3], dtype=jnp.float32)
37 | return n
38 |
39 | n = pax.pure(mutate)(n)
40 | assert "b" in n.pytree_attributes
41 |
42 | # assert n.pax.name_to_kind["b"] == pax.PaxKind.PARAMETER
43 |
44 |
45 | def test_assign_int_to_param():
46 | class M(pax.ParameterModule):
47 | def __init__(self):
48 | super().__init__()
49 | self.a = np.array([3, 1], dtype=np.int32)
50 |
51 | _ = M()
52 |
53 |
54 | def test_assign_int_to_param_deepscan():
55 | class M(pax.Module):
56 | def __init__(self):
57 | super().__init__()
58 | self.a = np.array([3, 1], dtype=np.int32)
59 |
60 | _ = M()
61 | # m = pax.freeze_parameters(m)
62 | # d = OrderedDict(m.name_to_kind)
63 | # d["a"] = pax.module.PaxKind.PARAMETER
64 | # m.__dict__["name_to_kind"] = MappingProxyType(d)
65 | # m = pax.scan_bugs(m)
66 |
67 |
68 | # def test_jit_():
69 | # class M(pax.Module):
70 | # def __init__(self):
71 | # super().__init__()
72 | # self.a_list = [pax.Linear(2, 2)]
73 |
74 | # def __call__(self, x):
75 | # self.a_list.append(0)
76 | # return x
77 |
78 | # m = M()
79 |
80 | # @pax.jit_
81 | # def fwd(m, x):
82 | # return m(x)
83 |
84 | # with pytest.raises(ValueError):
85 | # x = fwd(m, jnp.zeros((2, 2)))
86 |
--------------------------------------------------------------------------------
/tests/test_finetune.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import opax
7 | import pax
8 |
9 |
10 | def test_finetune():
11 | pax.seed_rng_key(42)
12 |
13 | class MLP(pax.Module):
14 | layers: List[pax.Linear]
15 |
16 | def __init__(self, dims: List[int]):
17 | super().__init__()
18 | layers = []
19 | for in_dim, out_dim in zip(dims[:-1], dims[1:]):
20 | layers.append(pax.Linear(in_dim, out_dim))
21 | self.layers = layers
22 |
23 | def __call__(self, x):
24 | for f in self.layers:
25 | x = f(x)
26 | x = jax.nn.sigmoid(x)
27 | return x
28 |
29 | net = MLP([10, 2, 2, 2, 10])
30 |
31 | @pax.pure
32 | def loss_fn(params: MLP, model: MLP, x):
33 | model = pax.update_parameters(model, params=params)
34 | y = model(x)
35 | loss = jnp.mean(jnp.square(x - y))
36 | return loss, (loss, model)
37 |
38 | x = jax.random.normal(pax.next_rng_key(), (1, 10))
39 |
40 | # make all layers non-trainable except the last layer.
41 | for i in range(len(net.layers) - 1):
42 | net.layers[i] = pax.freeze_parameters(net.layers[i])
43 |
44 | # net.layers[-1] = pax.Linear(2, 10)
45 | optimizer = opax.adam(1e-2)(net.parameters())
46 |
47 | @jax.jit
48 | def update_fn(model, optimizer, x):
49 | params = model.parameters()
50 | grads, (loss, model) = jax.grad(loss_fn, has_aux=True)(params, model, x)
51 | model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
52 | return model, optimizer, loss
53 |
54 | old_layers = net.layers
55 | for i in range(100):
56 | net, optimizer, loss = update_fn(net, optimizer, x)
57 | if i % 10 == 0:
58 | print(f"[step {i:03d}] loss {loss:.3f}")
59 | new_layers = net.layers
60 |
61 | for i in range(len(net.layers) - 1):
62 | np.testing.assert_array_equal(old_layers[i].weight, new_layers[i].weight)
63 |
64 | np.testing.assert_raises(
65 | AssertionError,
66 | np.testing.assert_array_equal,
67 | old_layers[-1].weight,
68 | new_layers[-1].weight,
69 | )
70 |
--------------------------------------------------------------------------------
/tests/test_freeze_unfreeze.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import pax
3 |
4 |
5 | def test_freeze_really_working():
6 | a = pax.Sequential(
7 | pax.Linear(3, 3),
8 | pax.Linear(5, 5),
9 | )
10 | b = pax.freeze_parameters(a)
11 | # assert b[0].pax.name_to_kind["weight"] == pax.PaxKind.STATE
12 | # assert a[0].pax.name_to_kind["weight"] == pax.PaxKind.PARAMETER
13 |
14 |
15 | def test_freeze_mapping_proxy():
16 | a = pax.Sequential(
17 | pax.Linear(3, 3),
18 | pax.Linear(5, 5),
19 | )
20 | b = pax.freeze_parameters(a)
21 | # assert isinstance(b.pax.name_to_kind, MappingProxyType), "expecting a proxy map"
22 |
23 |
24 | def test_freeze_twice():
25 | a = pax.Linear(2, 2)
26 | # with pytest.raises(ValueError):
27 | _ = pax.freeze_parameters(pax.freeze_parameters(a))
28 |
29 |
30 | # def test_freeze_unfreeze():
31 | # a = pax.Sequential(
32 | # pax.Linear(2, 2),
33 | # pax.Linear(3, 3),
34 | # pax.Linear(4, 4),
35 | # pax.Linear(5, 5),
36 | # )
37 |
38 | # b = pax.freeze_parameters(a)
39 | # c = pax.unfreeze_parameters(b, origin=a)
40 | #     # pylint: disable=protected-access
41 | # # assert a[0].pax.name_to_kind is c[0].pax.name_to_kind
42 |
43 |
44 | def test_copy():
45 | a = pax.Linear(1, 1, with_bias=False)
46 | b = pax.enable_eval_mode(a)
47 | assert jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b)
48 | c = pax.enable_train_mode(b)
49 | assert jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(c)
50 |
--------------------------------------------------------------------------------
/tests/test_graph_module.py:
--------------------------------------------------------------------------------
1 | """Test graph module"""
2 |
3 | import copy
4 | from functools import partial
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import pax
9 | import pytest
10 | from pax.experimental.graph import GraphModule, InputNode, build_graph_module
11 |
12 |
13 | def test_simple_graph():
14 | x = InputNode(jnp.zeros((3, 3)))
15 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
16 | assert y.value.shape == (3, 4)
17 |
18 |
19 | def test_cat_graph():
20 | x = InputNode(jnp.zeros((3, 3)))
21 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
22 | z = x & y
23 | t = z >> partial(jnp.concatenate, axis=-1)
24 | assert t.value.shape == (3, 7)
25 |
26 |
27 | def test_cat_merge_left():
28 | x = InputNode(jnp.zeros((3, 3)))
29 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
30 | q = y & y
31 | z = q & x
32 | assert z.parents == (y, y, x)
33 |
34 |
35 | def test_cat_merge_right():
36 | x = InputNode(jnp.zeros((3, 3)))
37 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
38 | q = y & y
39 | z = x & q
40 | assert z.parents == (x, y, y)
41 |
42 |
43 | def test_merge_2_cat():
44 | x = InputNode(jnp.zeros((3, 3)))
45 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
46 | q = y & y
47 | t = x & x
48 | k = q & t
49 | assert k.parents == (y, y, x, x)
50 |
51 |
52 | def test_3_cat_graph():
53 | x = InputNode(jnp.zeros((3, 3)))
54 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
55 | z = x & y & x
56 | t = z >> partial(jnp.concatenate, axis=-1)
57 | assert t.value.shape == (3, 10)
58 |
59 |
60 | def test_3_cat_graph_module():
61 | x = InputNode(jnp.zeros((3, 3)))
62 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
63 | z = x & y & y
64 | t = z >> partial(jnp.concatenate, axis=-1)
65 | _ = GraphModule((x,), t)
66 |
67 |
68 | def test_or_graph():
69 | x = InputNode(jnp.zeros((3, 3)))
70 | y = x >> pax.Linear(3, 3) >> jax.nn.relu
71 | z = (x | y) >> jax.lax.add
72 | assert z.value.shape == (3, 3)
73 |
74 |
75 | def test_merge_2_or():
76 | x = InputNode(jnp.zeros((3, 3)))
77 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
78 | q = y | y
79 | t = x | x
80 | k = t | q
81 | assert k.parents == (x, x, y, y)
82 |
83 |
84 | def test_or_merge_left():
85 | x = InputNode(jnp.zeros((3, 3)))
86 | y = x >> pax.Linear(3, 3) >> jax.nn.relu
87 | z = x | y
88 | t = z | x
89 | assert t.parents == (x, y, x)
90 |
91 |
92 | def test_or_merge_right():
93 | x = InputNode(jnp.zeros((3, 3)))
94 | y = x >> pax.Linear(3, 3) >> jax.nn.relu
95 | z = x | y
96 | t = x | z
97 | assert t.parents == (x, x, y)
98 |
99 |
100 | def test_cat_graph_merge():
101 | x = InputNode(jnp.zeros((3, 3)))
102 | y = x >> pax.Linear(3, 4) >> jax.nn.relu
103 | q = y | y
104 | z = x | q
105 | assert z.parents == (x, y, y)
106 |
107 |
108 | def test_binops():
109 | x = InputNode(jnp.ones((3, 3)))
110 | y = x.binary_ops(jax.lax.add, x)
111 | assert y.parents == (x, x)
112 | assert jnp.array_equal(y.fx((x.value, x.value)), jnp.ones((3, 3)) * 2)
113 | assert jnp.array_equal(y.value, jnp.ones((3, 3)) * 2)
114 |
115 |
116 | def test_type_shape():
117 | x = InputNode(jnp.ones((3, 3), dtype=jnp.int32))
118 | assert x.shape == (3, 3)
119 | assert x.dtype == jnp.int32
120 |
121 |
122 | def test_build_residual_net():
123 | def residual(x):
124 | y = x >> pax.Linear(3, 3) >> jax.nn.relu
125 | t = x >> pax.Linear(3, 3) >> jax.nn.tanh
126 | z = (y | t) >> jax.lax.add
127 | return z
128 |
129 | x = jnp.empty((1, 3))
130 | net = build_graph_module(residual)(x)
131 | y = net(x)
132 | assert y.shape == (1, 3)
133 |
134 |
135 | def test_reuse_module_error():
136 | def reuse(x):
137 | mod = pax.Linear(3, 3)
138 | y = x >> mod >> jax.nn.relu
139 | t = x >> mod
140 | z = (y | t) >> jax.lax.add
141 | return z
142 |
143 | x = jnp.empty((1, 3))
144 | with pytest.raises(ValueError):
145 | _ = build_graph_module(reuse)(x)
146 |
147 |
148 | def test_copy_error():
149 | x = InputNode(jnp.empty((3, 3)))
150 | with pytest.raises(TypeError):
151 | _ = copy.copy(x)
152 |
153 |
154 | def test_deepcopy_error():
155 | x = InputNode(jnp.empty((3, 3)))
156 | with pytest.raises(TypeError):
157 | _ = copy.deepcopy(x)
158 |
--------------------------------------------------------------------------------
/tests/test_immutability.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import pax
3 | import pytest
4 |
5 |
6 | def test_immutability():
7 | f = pax.Linear(3, 3)
8 | with pytest.raises(ValueError):
9 | f.c = 123
10 | g = pax.freeze_parameters(f)
11 | # k = pax.unfreeze_parameters(g, origin=f)
12 |
13 |
14 | def test_new_empty_attribute():
15 | class M(pax.Module):
16 | a = []
17 |
18 | m = M()
19 |
20 |
21 | def test_new_unregistered_array():
22 | class M(pax.Module):
23 | a = [jnp.zeros((3, 3))]
24 |
25 | with pytest.raises(ValueError):
26 | m = M()
27 |
28 |
29 | def test_new_unregistered_module():
30 | class M(pax.Module):
31 | a = pax.Linear(3, 3)
32 |
33 | with pytest.raises(ValueError):
34 | m = M()
35 |
--------------------------------------------------------------------------------
/tests/test_jax_transform.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pax
4 | import pytest
5 |
6 |
7 | def test_jit_immutability():
8 | class M(pax.Module):
9 | def __init__(self):
10 | self.x = pax.Linear(2, 2)
11 | self.counter = 2
12 |
13 | def __call__(self, x):
14 | self.counter = self.counter + 1
15 | return x
16 |
17 | m = M()
18 | x = jnp.zeros((1, 1))
19 | with pytest.raises(ValueError):
20 | y = jax.jit(lambda y: m(y))(x)
21 |
22 |
23 | def test_grad_deepscan():
24 | class M(pax.Module):
25 | def __init__(self):
26 | self.fc = pax.Linear(2, 2)
27 |
28 | def __call__(self, x):
29 | return self.fc(x)
30 |
31 | def loss_fn(params, model, inputs):
32 | model = pax.update_parameters(model, params=params)
33 | loss = jnp.mean(model(inputs))
34 | return loss, (loss, model)
35 |
36 | m = M()
37 | x = jnp.zeros((1, 2))
38 |     m = m.set_attribute("fc1", pax.Linear(2, 2))
39 | y = jax.grad(loss_fn, has_aux=True)(pax.select_parameters(m), m, x)
40 |
41 |
42 | def test_loss_fn_no_return_model():
43 | def loss_fn(params, model, inputs):
44 | model = pax.update_parameters(model, params=params)
45 | y = model(inputs)
46 | return jnp.sum(y)
47 |
48 | grad_fn = jax.grad(loss_fn)
49 | x = jnp.zeros((3, 3))
50 | net = pax.Linear(3, 3)
51 | y = grad_fn(net.parameters(), net, x)
52 |
53 |
54 | def test_jit__call__():
55 | class M(pax.Module):
56 | @jax.jit
57 | def __call__(self, x):
58 | return x, self
59 |
60 | x = jnp.zeros((3, 3))
61 | net = M()
62 | y = net(x)
63 |
64 | class M(pax.Module):
65 | @jax.jit
66 | def __call__(self, x):
67 | return x
68 |
69 | net = M()
70 | y = net(x)
71 |
--------------------------------------------------------------------------------
/tests/test_mixed_precision.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jmp
4 | import pax
5 | import pytest
6 | from pax import apply_mp_policy
7 |
8 | half = jmp.half_dtype()
9 | full = jnp.float32
10 |
11 |
12 | def test_wrap_unwrap_mp_policy():
13 | f = pax.Linear(3, 3)
14 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
15 |
16 | ff = pax.apply_mp_policy(f, mp_policy=my_policy)
17 | fff = pax.unwrap_mp_policy(ff)
18 | assert hasattr(ff, "_pax_mp_policy")
19 | assert not hasattr(fff, "_pax_mp_policy")
20 |
21 | x = jax.numpy.ones((3, 3))
22 | assert f(x).dtype == full
23 | assert ff(x).dtype == half
24 | assert fff(x).dtype == full # type: ignore
25 |
26 |
27 | def test_sequential_mixed_precision():
28 | f = pax.Sequential(
29 | pax.Linear(3, 3),
30 | pax.BatchNorm2D(3, True, True, 0.9),
31 | pax.Linear(3, 3),
32 | pax.BatchNorm2D(3, True, True, 0.9),
33 | )
34 | linear_policy = jmp.Policy(compute_dtype=half, param_dtype=half, output_dtype=half)
35 | batchnorm_policy = jmp.Policy(
36 | compute_dtype=full, param_dtype=full, output_dtype=half
37 | )
38 |
39 | def policy_fn(mod):
40 | if isinstance(mod, pax.Linear):
41 | return pax.apply_mp_policy(mod, mp_policy=linear_policy)
42 | elif isinstance(mod, pax.BatchNorm2D):
43 | return pax.apply_mp_policy(mod, mp_policy=batchnorm_policy)
44 | else:
45 | # unchanged
46 | return mod
47 |
48 | f_mp = f.apply(policy_fn)
49 | x = jnp.zeros((32, 5, 5, 3))
50 |
51 | @pax.pure
52 | def run(f_mp):
53 | return f_mp(x)
54 |
55 | y = run(f_mp)
56 | assert y.dtype == half
57 |
58 |
59 | def test_change_internal_state():
60 | class M(pax.Module):
61 | counter: jnp.ndarray
62 |
63 | def __init__(self):
64 | super().__init__()
65 | self.counter = jnp.array(0)
66 |
67 | def __call__(self, x):
68 | self.counter = self.counter + 1
69 | return x * self.counter
70 |
71 | m = M()
72 | mp = jmp.Policy(
73 | compute_dtype=jnp.float16, param_dtype=jnp.float32, output_dtype=jnp.float16
74 | )
75 | mm = m.apply(
76 | lambda x: (pax.apply_mp_policy(x, mp_policy=mp) if isinstance(x, M) else x)
77 | )
78 | x = jnp.array(0.0)
79 | assert mm.counter.item() == 0
80 | mm, y = pax.module_and_value(mm)(x)
81 | assert mm.counter.item() == 1
82 | assert m.counter.item() == 0
83 |
84 |
85 | def test_change_tree_def():
86 | class M(pax.Module):
87 | counter: jnp.ndarray
88 | count: int
89 |
90 | def __init__(self):
91 | super().__init__()
92 | self.counter = jnp.array(0)
93 | self.count = 0
94 |
95 | def __call__(self, x):
96 | self.counter = self.counter + 1
97 | self.count = self.count + 1
98 | return x * self.counter
99 |
100 | m = M()
101 | mp = jmp.Policy(
102 | compute_dtype=jnp.float16, param_dtype=jnp.float32, output_dtype=jnp.float16
103 | )
104 | mm = m.apply(
105 | lambda x: (pax.apply_mp_policy(x, mp_policy=mp) if isinstance(x, M) else x)
106 | )
107 | x = jnp.array(0.0)
108 | assert mm.counter.item() == 0
109 | with pytest.raises(ValueError):
110 | y = mm(x)
111 | assert mm.counter.item() == 0
112 | assert m.counter.item() == 0
113 |
114 |
115 | def test_wrap_wrap_mixed_precision():
116 | f = pax.Linear(3, 3)
117 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
118 |
119 | f = pax.apply_mp_policy(f, mp_policy=my_policy)
120 | with pytest.raises(ValueError):
121 | f = pax.apply_mp_policy(f, mp_policy=my_policy)
122 |
123 | f = pax.unwrap_mp_policy(f)
124 | f = pax.apply_mp_policy(f, mp_policy=my_policy)
125 |
126 | with pytest.raises(ValueError):
127 | f = pax.apply_mp_policy(f, mp_policy=my_policy)
128 |
129 |
130 | def test_mixed_precision_clone():
131 | f = pax.BatchNorm1D(3)
132 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
133 |
134 | ff = pax.apply_mp_policy(f, mp_policy=my_policy)
135 |
136 | f = f.set_attribute("new_fc", pax.Linear(1, 1))
137 | # assert "new_fc" not in ff.pax.name_to_kind
138 |
139 |
140 | def test_mixed_precision_unwrap_clone():
141 | f = pax.BatchNorm1D(3)
142 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
143 |
144 | ff = pax.apply_mp_policy(f, mp_policy=my_policy)
145 | f = pax.unwrap_mp_policy(ff)
146 | f = f.set_attribute("new_fc", pax.Linear(1, 1))
147 | # assert "new_fc" not in ff.pax.name_to_kind
148 |
149 |
150 | def test_mixed_precision_no_method_name():
151 | f = pax.Linear(3, 3)
152 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
153 |
154 | # with pytest.raises(TypeError):
155 | _ = pax.apply_mp_policy(f, mp_policy=my_policy)
156 |
157 |
158 | def test_mp_call_classmethod():
159 | class M(pax.Module):
160 | def __init__(self):
161 | super().__init__()
162 | self.fc = pax.Linear(3, 3)
163 |
164 | @classmethod
165 | def t(cls, y):
166 | return y
167 |
168 | m = M()
169 | x = jnp.zeros((3, 3))
170 | y = m.t(x)
171 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
172 | m = apply_mp_policy(m, mp_policy=my_policy)
173 | # with pytest.raises(ValueError):
174 | y = m.t(x)
175 |
176 |
177 | def test_mp_call_staticmethod():
178 | class M(pax.Module):
179 | def __init__(self):
180 | super().__init__()
181 | self.fc = pax.Linear(3, 3)
182 |
183 | @staticmethod
184 | def t(_, y):
185 | return y
186 |
187 | m = M()
188 | x = jnp.zeros((3, 3))
189 | y = m.t(x, x)
190 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
191 | m = apply_mp_policy(m, mp_policy=my_policy)
192 | # with pytest.raises(ValueError):
193 | y = m.t(x, x)
194 |
195 |
196 | @pax.pure
197 | def test_mp_call_function():
198 | class M(pax.Module):
199 | def __init__(self):
200 | super().__init__()
201 | self.fc = pax.Linear(3, 3)
202 |
203 | m = M()
204 | x = jnp.zeros((3, 3))
205 |
206 | def mutate(m):
207 | m.q = lambda x: x
208 | return m
209 |
210 | m = pax.pure(mutate)(m)
211 | my_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)
212 | m = apply_mp_policy(m, mp_policy=my_policy)
213 | # with pytest.raises(ValueError):
214 | m.q(x)
215 |
--------------------------------------------------------------------------------
/tests/test_multithread.py:
--------------------------------------------------------------------------------
1 | """
2 | Test PAX in a multithreaded environment.
3 | """
4 | import queue
5 | import threading
6 | import time
7 |
8 | import jax.numpy as jnp
9 | import pax
10 |
11 |
12 | class DelayedCounter(pax.Module):
13 | def __init__(self):
14 | super().__init__()
15 | self.counter = jnp.array(0)
16 |
17 | def __call__(self):
18 | time.sleep(1)
19 | self.counter += 1
20 | time.sleep(1)
21 | return self.counter
22 |
23 |
24 | def test_multithread():
25 | @pax.pure
26 | def update(c: DelayedCounter, q):
27 | o = c()
28 | q.put(o)
29 |
30 | c1 = DelayedCounter()
31 | c2 = DelayedCounter()
32 | q = queue.Queue()
33 | x = threading.Thread(target=update, args=(c1, q))
34 | y = threading.Thread(target=update, args=(c2, q))
35 | x.start()
36 | y.start()
37 | x.join()
38 | y.join()
39 | q.get(timeout=1)
40 | q.get(timeout=1)
41 |
--------------------------------------------------------------------------------
/tests/test_nets.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import pax
3 |
4 |
5 | def test_run_resnet():
6 | resnet = pax.nets.ResNet18(3, 1)
7 | x = jax.numpy.zeros((1, 3, 18, 18))
8 | y = pax.pure(resnet)(x)
9 | assert y.shape == (1, 1)
10 |
11 |
12 | def test_run_transformer():
13 | transformer = pax.nets.Transformer(8, 2, 2, 0.1)
14 | x = jax.numpy.zeros((1, 15, 8))
15 | y = pax.pure(transformer)(x)
16 | assert y.shape == (1, 15, 8)
17 |
--------------------------------------------------------------------------------
/tests/test_optim.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import opax
4 | import pax
5 | import pytest
6 |
7 |
8 | def test_optim_model_update_state():
9 |     # a module updates its internal `count` value in the forward pass.
10 |
11 | class MyModule(pax.Module):
12 | count: int = 0
13 | fc: pax.Module
14 |
15 | def __init__(self):
16 | super().__init__()
17 | self.fc = pax.Linear(2, 2)
18 | self.count = 0
19 |
20 | def __call__(self, x):
21 | self.count = self.count + 1
22 | x = self.fc(x)
23 | return x
24 |
25 | net = MyModule()
26 |
27 | def loss_fn(model: MyModule, x):
28 | y = model(x)
29 | loss = jnp.mean(jnp.square(x - y))
30 | return loss, (loss, model)
31 |
32 | update_fn = pax.utils.build_update_fn(loss_fn=loss_fn)
33 | optimizer = opax.adamw()(net.parameters())
34 | x = jnp.zeros((2, 2), dtype=jnp.float32)
35 |
36 | with pytest.raises(ValueError):
37 | net, optimizer, loss = update_fn(net, optimizer, x)
38 |
39 |
40 | def test_sgd():
41 | class SGD(pax.Module):
42 | velocity: pax.Module
43 | learning_rate: float
44 | momentum: float
45 |
46 | def __init__(self, params, learning_rate: float = 1e-2, momentum: float = 0.9):
47 | super().__init__()
48 | self.momentum = momentum
49 | self.learning_rate = learning_rate
50 | self.velocity = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), params)
51 |
52 | def step(self, grads: pax.Module, params: pax.Module):
53 | self.velocity = jax.tree_util.tree_map(
54 | lambda v, g: v * self.momentum + g * self.learning_rate,
55 | self.velocity,
56 | grads,
57 | )
58 | new_params = jax.tree_util.tree_map(
59 | lambda p, v: p - v, params, self.velocity
60 | )
61 | return new_params
62 |
63 | f = pax.Linear(2, 2)
64 |     sgd = SGD(f, learning_rate=1e-4, momentum=0.9)
65 | pax.pure(sgd.step)(f, f)
66 |
--------------------------------------------------------------------------------
/tests/test_performance.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import jax
4 | import numpy as np
5 | import pax
6 |
7 |
8 | def test_perf_transformer_flatten_unflatten():
9 | class MyTransformer(pax.Module):
10 | def __init__(self, num_layers: int):
11 | super().__init__()
12 | self.layers = [
13 | pax.MultiHeadAttention(8, 512 // 8, 1.0) for i in range(num_layers)
14 | ]
15 |
16 | f = MyTransformer(16)
17 |
18 | start = time.perf_counter()
19 | n_iters = 100_000
20 | for _ in range(n_iters):
21 | leaves, treedef = jax.tree_util.tree_flatten(f)
22 | f = jax.tree_util.tree_unflatten(treedef, leaves)
23 | end = time.perf_counter()
24 | iters_per_second = n_iters / (end - start)
25 | print(iters_per_second, "iters/second")
26 | assert iters_per_second > 2500
27 |
28 |
29 | def test_perf_resnet200_flatten_unflatten():
30 |
31 | f = pax.nets.ResNet200(3, 100)
32 |
33 | start = time.perf_counter()
34 | n_iters = 1000
35 | for _ in range(n_iters):
36 | leaves, treedef = jax.tree_util.tree_flatten(f)
37 | f = jax.tree_util.tree_unflatten(treedef, leaves)
38 | end = time.perf_counter()
39 | iters_per_second = n_iters / (end - start)
40 | print(iters_per_second, "iters/second")
41 | assert iters_per_second > 100
42 |
43 |
44 | def test_perf_flattenmodule_resnet200_flatten_unflatten():
45 |
46 | x = jax.random.normal(jax.random.PRNGKey(42), (1, 3, 64, 64))
47 | f = pax.nets.ResNet200(3, 100)
48 | y = f.eval()(x)
49 | f = pax.experimental.Flattener(net=f.eval())
50 | y1 = pax.pure(f.net)(x)
51 | np.testing.assert_array_equal(y, y1)
52 |
53 | start = time.perf_counter()
54 | n_iters = 10000
55 | for _ in range(n_iters):
56 | leaves, treedef = jax.tree_util.tree_flatten(f)
57 | f = jax.tree_util.tree_unflatten(treedef, leaves)
58 | end = time.perf_counter()
59 | iters_per_second = n_iters / (end - start)
60 | print(iters_per_second, "iters/second")
61 | assert iters_per_second > 4000
62 |
--------------------------------------------------------------------------------
/tests/test_pure.py:
--------------------------------------------------------------------------------
1 | import weakref
2 | from functools import partial
3 | from typing import Any
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import pax
8 | import pytest
9 | from numpy.testing import assert_array_equal
10 |
11 |
12 | def test_rng_unchanged():
13 | pax.seed_rng_key(41)
14 | pax.next_rng_key()
15 |
16 | @jax.jit
17 | @pax.pure
18 | def fn():
19 | return pax.next_rng_key()
20 |
21 | def f1():
22 | pax.seed_rng_key(41)
23 | pax.next_rng_key()
24 | return pax.next_rng_key()
25 |
26 | def f2():
27 | pax.seed_rng_key(41)
28 | pax.next_rng_key()
29 | fn()
30 | return pax.next_rng_key()
31 |
32 | r1 = f1()
33 | r2 = f2()
34 | assert not jnp.array_equal(r1, r2)
35 |
36 | r3 = fn()
37 | _ = pax.next_rng_key()
38 | r4 = fn()
39 | assert_array_equal(r3, r4)
40 |
41 |
42 | def test_deepcopy():
43 | class C(object):
44 | c: int
45 |
46 | def __init__(self):
47 | self.c = 0
48 |
49 | @pax.pure
50 | def mutate(x):
51 | x.c.c += 1
52 | return x
53 |
54 | class M(pax.Module):
55 | c: C
56 |
57 | def __init__(self):
58 | self.c = C()
59 |
60 | m = M()
61 | assert m.c.c == 0
62 | m1 = mutate(m)
63 | assert m.c.c == 1
64 | assert m1.c.c == 1
65 |
66 |
67 | def test_deep_compare_1():
68 | class C(object):
69 | c: int
70 |
71 | def __init__(self):
72 | self.c = 0
73 |
74 | @pax.pure
75 | def mutate(x):
76 | return x
77 |
78 | class M(pax.Module):
79 | c: C
80 |
81 | def __init__(self):
82 | self.c = C()
83 |
84 | m = M()
85 | m1 = mutate(m)
86 | # with pytest.raises(AssertionError):
87 | pax.assert_structure_equal(m, m1)
88 |
89 |
90 | def test_deep_compare_2():
91 | class C(object):
92 | c: int
93 |
94 | def __init__(self):
95 | self.c = 0
96 |
97 | def __eq__(self, o) -> bool:
98 | return self.c == o.c
99 |
100 | @pax.pure
101 | def mutate(x):
102 | return x
103 |
104 | class M(pax.Module):
105 | f: Any
106 | g: Any
107 | j: Any
108 | c: C
109 |
110 | def __init__(self):
111 | self.f = jax.nn.relu
112 | self.g = jax.nn.sigmoid
113 | self.j = partial(jax.nn.leaky_relu, negative_slope=0.2)
114 | self.h = jnp.tanh
115 | self.c = C()
116 |
117 | m = M()
118 | m1 = mutate(m)
119 | pax.assert_structure_equal(m, m1)
120 |
121 |
122 | def test_module_weak_ref():
123 | mod = pax.Linear(3, 3)
124 | mod_ref = weakref.ref(mod)
125 | assert mod_ref() is mod
126 | del mod
127 | assert mod_ref() is None
128 |
129 |
130 | def test_abstraction_level_checking():
131 | def mutate(f):
132 | @jax.jit
133 | def g():
134 | f.a = "hello"
135 |
136 | g()
137 |
138 | fc = pax.Linear(3, 3)
139 | with pytest.raises(ValueError):
140 | pax.pure(mutate)(fc)
141 |
142 |
143 | def test_decorate_method_with_module_and_value():
144 | class M(pax.StateModule):
145 | def __init__(self):
146 | self.c = jnp.array(0)
147 |
148 | @pax.module_and_value
149 | def step(self):
150 | self.c += 1
151 |
152 | m = M()
153 | assert m.c.item() == 0
154 | m, _ = m.step()
155 | assert m.c.item() == 1
156 |
--------------------------------------------------------------------------------
/tests/test_summary.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import pax
3 |
4 |
5 | def test_linear_summary():
6 | fc = pax.Linear(3, 3)
7 | assert fc.summary() == "Linear(in_dim=3, out_dim=3, with_bias=True)"
8 |
9 |
10 | def test_sequential_summary():
11 | f = pax.Sequential(pax.Linear(3, 32), jax.nn.sigmoid, pax.Linear(32, 64))
12 | f1 = pax.BatchNorm1D(3)
13 | f1 = f1.set_attribute("T", f)
14 | print(f1.summary())
15 |
--------------------------------------------------------------------------------
/tests/test_training.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import opax
6 | import pax
7 |
8 |
9 | def test_train_linear_regression_1():
10 | x = jax.random.normal(jax.random.PRNGKey(42), (32, 1), dtype=jnp.float32)
11 | noise = jax.random.normal(jax.random.PRNGKey(43), (32, 1), dtype=jnp.float32) * 0.2
12 | y = x * 2.5 - 3.1 + noise
13 |
14 | def loss_fn(model: pax.Linear, x, y):
15 | y_hat = model(x)
16 | loss = jnp.mean(jnp.square(y - y_hat))
17 | return loss, (loss, model)
18 |
19 | update_fn = pax.utils.build_update_fn(loss_fn)
20 | net = pax.Linear(1, 1)
21 | optimizer = opax.adamw(1e-1)(net.parameters())
22 | for step in range(100):
23 | net, optimizer, loss = update_fn(net, optimizer, x, y)
24 | print(f"[step {step}] loss {loss:.3f}")
25 |
26 |
27 | def test_train_linear_regression_2():
28 | x = jax.random.normal(jax.random.PRNGKey(42), (32, 1), dtype=jnp.float32)
29 | noise = jax.random.normal(jax.random.PRNGKey(43), (32, 1), dtype=jnp.float32) * 0.2
30 | y = x * 2.5 - 3.1 + noise
31 |
32 | class M(pax.Module):
33 | def __init__(self):
34 | super().__init__()
35 | self.fc1 = pax.Linear(1, 32)
36 | self.fc2 = pax.Linear(32, 1)
37 |
38 | def __call__(self, x):
39 | x = self.fc1(x)
40 | x = jax.nn.relu(x)
41 | x = self.fc2(x)
42 | return x
43 |
44 | def loss_fn(model: M, x, y):
45 | y_hat = model(x)
46 | loss = jnp.mean(jnp.square(y - y_hat))
47 | return loss, (loss, model)
48 |
49 | update_fn = pax.utils.build_update_fn(loss_fn)
50 | net = M()
51 | optimizer = opax.adamw(1e-1)(net.parameters())
52 | for step in range(100):
53 | net, optimizer, loss = update_fn(net, optimizer, x, y)
54 | print(f"[step {step}] loss {loss:.3f}")
55 |
--------------------------------------------------------------------------------
/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import jmp
3 | import pax
4 |
5 |
6 | def test_mutate_new_module_list():
7 | a = pax.Linear(3, 3)
8 | b = a.copy()
9 |
10 | def mutate(b):
11 | b.lst = [pax.Linear(4, 4)]
12 | return b
13 |
14 | b = pax.pure(mutate)(b)
15 | # pylint: disable=protected-access
16 | # assert b.pax.name_to_kind["lst"] == pax.PaxKind.MODULE
17 |
18 |
19 | def test_mp_policy_method_name():
20 | class M(pax.Module):
21 | def __init__(self):
22 | super().__init__()
23 | self.f = pax.Linear(3, 3)
24 |
25 | def __call__(self, x):
26 | return self.f(x)
27 |
28 | def inference(self, x):
29 | return self.f(x) + 1.0
30 |
31 | m = M()
32 | half = jmp.half_dtype()
33 | full = jnp.float32
34 |
35 | p = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=full)
36 |
37 | m = pax.apply_mp_policy(m, mp_policy=p)
38 | x = jnp.zeros((4, 3))
39 | _ = m(x) # ok
40 |
41 | _ = m.inference(x)
42 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import opax
5 | import pax
6 | from pax import EMA, RngSeq
7 |
8 |
9 | def test_grad():
10 | def loss_fn(model: pax.Linear, inputs):
11 | x, target = inputs
12 | y = model(x)
13 | loss = jnp.mean(jnp.square(y - target))
14 | return loss, (loss, model)
15 |
16 | @jax.jit
17 | def update_fn(model, optimizer, inputs):
18 | grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model, inputs)
19 |         model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
20 | return model, optimizer, loss
21 |
22 | net = pax.Linear(2, 1)
23 | opt = opax.adamw(learning_rate=1e-2)(net.parameters())
24 | x = np.random.normal(size=(32, 2))
25 | y = np.random.normal(size=(32, 1))
26 | print()
27 | for step in range(5):
28 | net, opt, loss = update_fn(net, opt, (x, y))
29 | print(f"step {step} loss {loss:.3f}")
30 |
31 |
32 | def test_value_and_grad():
33 | def loss_fn(model: pax.Linear, inputs):
34 | x, target = inputs
35 | y = model(x)
36 | loss = jnp.mean(jnp.square(y - target))
37 | return loss, model
38 |
39 | @jax.jit
40 | def update_fn(model, optimizer, inputs):
41 | (loss, model), grads = pax.value_and_grad(loss_fn, has_aux=True)(model, inputs)
42 |         model, optimizer = opax.apply_gradients(model, optimizer, grads)
43 | return model, optimizer, loss
44 |
45 | net = pax.Linear(2, 1)
46 | opt = opax.adamw(learning_rate=1e-2)(net.parameters())
47 | x = np.random.normal(size=(32, 2))
48 | y = np.random.normal(size=(32, 1))
49 | print()
50 | for step in range(5):
51 | net, opt, loss = update_fn(net, opt, (x, y))
52 | print(f"step {step} loss {loss:.3f}")
53 |
54 |
55 | def test_util_update_fn():
56 | def loss_fn(model: pax.Linear, x, target):
57 | y = model(x)
58 | loss = jnp.mean(jnp.square(y - target))
59 | return loss, (loss, model)
60 |
61 | net = pax.Linear(2, 1)
62 | opt = opax.adamw(learning_rate=1e-1)(net.parameters())
63 | update_fn = jax.jit(pax.utils.build_update_fn(loss_fn, scan_mode=True))
64 | x = np.random.normal(size=(32, 2))
65 | y = np.random.normal(size=(32, 1))
66 | print()
67 | for step in range(3):
68 | (net, opt), loss = update_fn((net, opt), x, y)
69 | print(f"step {step} loss {loss:.3f}")
70 |
71 |
72 | def test_Rng_Seq():
73 | rng_seq = RngSeq(seed=42)
74 | assert rng_seq._rng_key.tolist() == [0, 42]
75 |
76 | rng_seq, r1 = pax.module_and_value(rng_seq.next_rng_key)()
77 | assert r1.shape == (2,)
78 | h1 = rng_seq._rng_key
79 | rng_seq, rs = pax.module_and_value(rng_seq.next_rng_key)(2)
80 | h2 = rng_seq._rng_key
81 | assert len(rs) == 2
82 | assert r1.tolist() != rs[0].tolist()
83 | assert h1.tolist() != h2.tolist(), "update internal state in `train` mode"
84 |
85 | rng_seq = pax.enable_eval_mode(rng_seq)
86 | rng_seq, r3 = pax.module_and_value(rng_seq.next_rng_key)()
87 | rng_seq, r4 = pax.module_and_value(rng_seq.next_rng_key)()
88 | assert r3.tolist() != r4.tolist()
89 | h3 = rng_seq._rng_key
90 | assert h2.tolist() != h3.tolist(), "update internal state even in `eval` mode"
91 |
92 |
93 | def test_ema_debias():
94 | ema = EMA(jnp.array(1.0), 0.9, True)
95 | assert ema.debias.item() == False
96 | assert ema.averages.item() == 1.0
97 |
98 | ema, _ = pax.purecall(ema, jnp.array(2.0))
99 | assert ema.averages.item() == 2.0
100 | assert ema.debias.item() == True
101 |
102 | ema, _ = pax.purecall(ema, jnp.array(1.0))
103 | np.testing.assert_almost_equal(ema.averages.item(), 0.9 * 2.0 + 0.1 * 1.0)
104 |
105 |
106 | def test_ema_bias():
107 | ema = EMA(jnp.array(1.0), 0.9, False)
108 | assert ema.debias is None
109 | assert ema.averages.item() == 1.0
110 |
111 | ema, _ = pax.purecall(ema, jnp.array(2.0))
112 | np.testing.assert_almost_equal(ema.averages.item(), 0.1 * 2.0 + 0.9 * 1.0)
113 |
114 |
115 | def test_scan_fn_not_time_major():
116 | def loop(prev_state, x):
117 | next_state = prev_state + x
118 | return next_state, next_state
119 |
120 | h0 = jnp.zeros((1,))
121 | xs = jnp.arange(0, 10).reshape((1, -1))
122 | _, ys = pax.scan(loop, h0, xs, time_major=False)
123 | assert ys[0, -1].item() == 45
124 |
125 |
126 | def test_scan_fn_not_time_major_pytree():
127 | def loop(prev_state, x):
128 | next_state = prev_state + x[0] + x[1]
129 | return next_state, (next_state, next_state)
130 |
131 | h0 = jnp.zeros((1,))
132 | xs = jnp.arange(0, 10).reshape((1, -1))
133 | _, (ys1, ys2) = pax.scan(loop, h0, (xs, xs), time_major=False)
134 | assert ys1[0, -1].item() == 90
135 |
136 |
137 | def test_scan_fn_time_major():
138 | def loop(prev_state, x):
139 | next_state = prev_state + x
140 | return next_state, next_state
141 |
142 | h0 = jnp.zeros((1,))
143 | xs = jnp.arange(0, 10).reshape((-1, 1))
144 | _, ys = pax.scan(loop, h0, xs, time_major=True)
145 | assert ys[-1, 0].item() == 45
146 |
--------------------------------------------------------------------------------