├── .coveragerc
├── .flake8
├── .github
│   └── workflows
│       ├── publish_dev.yaml
│       ├── run_pre_commit.yaml
│       ├── run_quick_levanter_tests.yaml
│       └── run_tests.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── api.md
│   ├── broadcasting.md
│   ├── cheatsheet.md
│   ├── css
│   │   ├── material.css
│   │   └── mkdocstrings.css
│   ├── faq.md
│   ├── figures
│   │   ├── data_parallel_mesh.png
│   │   ├── data_parallel_mesh_replicated.png
│   │   ├── device_mesh_1d.png
│   │   ├── device_mesh_1d_zero.png
│   │   ├── device_mesh_2d.png
│   │   ├── device_mesh_2d_batch_partitioned.png
│   │   ├── device_mesh_2d_data_replicated.png
│   │   ├── device_mesh_2d_data_replicated_mlp_partitioned.png
│   │   ├── device_mesh_2d_intermediate_fully_partitioned.png
│   │   └── device_mesh_2d_zero.png
│   ├── fp8.md
│   ├── index.md
│   ├── indexing.md
│   ├── matmul.md
│   ├── nn.md
│   ├── partitioning.md
│   ├── rearrange.ipynb
│   ├── rearrange.md
│   ├── requirements.txt
│   ├── scan.md
│   ├── state-dict.md
│   ├── tutorial.md
│   └── vmap.md
├── mkdocs.yml
├── pyproject.toml
├── src
│   └── haliax
│       ├── __about__.py
│       ├── __init__.py
│       ├── _src
│       │   ├── __init__.py
│       │   ├── compile_utils.py
│       │   ├── dot.py
│       │   ├── einsum.py
│       │   ├── fp8.py
│       │   ├── parsing.py
│       │   ├── rearrange.py
│       │   ├── scan.py
│       │   ├── state_dict.py
│       │   └── util.py
│       ├── axis.py
│       ├── core.py
│       ├── debug.py
│       ├── hof.py
│       ├── jax_utils.py
│       ├── nn
│       │   ├── __init__.py
│       │   ├── activations.py
│       │   ├── attention.py
│       │   ├── conv.py
│       │   ├── dropout.py
│       │   ├── embedding.py
│       │   ├── linear.py
│       │   ├── loss.py
│       │   ├── mlp.py
│       │   ├── normalization.py
│       │   ├── pool.py
│       │   └── scan.py
│       ├── ops.py
│       ├── partitioning.py
│       ├── quantization.py
│       ├── random.py
│       ├── specialized_fns.py
│       ├── state_dict.py
│       ├── tree_util.py
│       ├── types.py
│       ├── util.py
│       └── wrap.py
└── tests
    ├── core_test.py
    ├── test_attention.py
    ├── test_axis.py
    ├── test_conv.py
    ├── test_debug.py
    ├── test_dot.py
    ├── test_einsum.py
    ├── test_fp8.py
    ├── test_hof.py
    ├── test_int8.py
    ├── test_nn.py
    ├── test_ops.py
    ├── test_parsing.py
    ├── test_partitioning.py
    ├── test_pool.py
    ├── test_random.py
    ├── test_rearrange.py
    ├── test_scan.py
    ├── test_specialized_fns.py
    ├── test_state_dict.py
    ├── test_tree_util.py
    └── test_utils.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | exclude_lines =
3 | pragma: not covered
4 | pragma: no cover
5 | @overload
6 | @typing.overload
7 | @abc.abstractmethod
8 | raise NotImplementedError
9 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | exclude = .git
3 | max-line-length = 120
4 | ignore = E203, E501, W503, W605, F821, E266, E402, E731, F401, F403, F40
5 | per-file-ignores =
6 | */__init__.py: F401
7 | examples/*.py: E402
8 |
--------------------------------------------------------------------------------
/.github/workflows/publish_dev.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Dev Build
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Run Tests"]
6 | types:
7 | - completed
8 | branches: [main]
9 | workflow_dispatch:
10 |
11 | jobs:
12 | build-package:
13 | runs-on: ubuntu-latest
14 | if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success'}}
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v4
18 | with:
19 | fetch-depth: 0
20 | - name: Set up Python
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: '3.x'
24 |
25 | - name: Calculate Version and Build Number
26 | run: |
27 | PROJECT_VERSION=$(sed -n 's/__version__ = "\(.*\)"/\1/p' src/haliax/__about__.py)
28 | BUILD_NUMBER=$(git rev-list --count HEAD)
29 | FULL_VERSION="${PROJECT_VERSION}.dev${BUILD_NUMBER}"
30 | echo "FULL_VERSION=${FULL_VERSION}" >> $GITHUB_ENV
31 | echo "Calculated version with build number: $FULL_VERSION"
32 | - name: Update pyproject.toml version
33 | run: |
34 | # replace the version in __about__.py
35 | echo "Updating version in __about__.py to $FULL_VERSION"
36 | sed -i "s/__version__ = \".*\"/__version__ = \"$FULL_VERSION\"/g" src/haliax/__about__.py
37 | - name: Build package
38 | run: |
39 | python -m pip install --upgrade pip
40 | pip install build
41 | python -m build
42 |
43 | - name: Upload package
44 | uses: actions/upload-artifact@v4
45 | with:
46 | name: package
47 | path: dist/
48 |
49 |
50 | # cf https://test.pypi.org/manage/project/haliax/settings/publishing/
51 | publish-dev:
52 | runs-on: ubuntu-latest
53 | needs:
54 | - build-package
55 | permissions:
56 | id-token: write
57 | steps:
58 | - name: Retrieve release distributions
59 | uses: actions/download-artifact@v4
60 | with:
61 | name: package
62 | path: dist/
63 |
64 | - name: Publish release distributions to PyPI Test
65 | uses: pypa/gh-action-pypi-publish@release/v1
66 |
67 |
68 |
--------------------------------------------------------------------------------
/.github/workflows/run_pre_commit.yaml:
--------------------------------------------------------------------------------
1 | name: Pre-Commit
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.10.11"]
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install flake8 pytest pre-commit
23 | pip install .
24 | # - name: Lint with flake8
25 | # run: |
26 | # # stop the build if there are Python syntax errors or undefined names
27 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
28 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
29 | # flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics
30 | - name: "Run Pre-commit"
31 | run: |
32 | pre-commit run --all-files --show-diff-on-failure
33 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/run_quick_levanter_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run Levanter Tests
2 |
3 | on: [pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v3
12 | - name: Set up Python 3.10.11
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: 3.10.11
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install flake8 pytest
20 | pip install "jax[cpu]==0.5.3" "jaxlib[cpu]==0.5.3" .[dev]
21 |
22 | - name: Install Levanter from source
23 | run: |
24 | cd ..
25 | git clone https://github.com/stanford-crfm/levanter.git
26 | cd levanter
27 | pip install -e .[tests]
28 | pip install -r tests/requirements.txt
29 | # i don't know why this is necessary
30 | pip install tensorboardX
31 | - name: Install Haliax on top
32 | run: |
33 | # install second since levanter will install a built version of haliax
34 | cd ../haliax
   | pip install -e .
35 | - name: Test levanter with pytest
36 | run: |
37 | cd ../levanter
38 | XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:../src pytest tests -m "not entry and not slow"
39 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v3
12 | - name: Set up Python 3.10.11
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: 3.10.11
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install flake8 pytest
20 | pip install jax==0.4.35 jaxlib==0.4.35 .[dev]
21 | - name: Test with pytest
22 | run: |
23 | XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests
24 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /scratch
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # JetBrains
134 | .idea/
135 |
136 |
137 | # Wandb stuff
138 | /wandb
139 |
140 | # dataset cache files
141 | *.parquet
142 | ledger.json
143 |
144 | /checkpoints
145 | *.jaxpr
146 |
147 | # local execution commands
148 | local_*.sh
149 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | exclude: ".git"
4 | default_stages:
5 | - commit
6 | fail_fast: true
7 |
8 | repos:
9 | - repo: https://github.com/pre-commit/pre-commit-hooks
10 | rev: v4.0.1
11 | hooks:
12 | - id: trailing-whitespace
13 | - id: end-of-file-fixer
14 | - id: check-yaml
15 | - id: check-toml
16 | - id: check-merge-conflict
17 | - id: check-added-large-files
18 |
19 | - repo: https://github.com/psf/black
20 | rev: 22.3.0
21 | hooks:
22 | - id: black
23 |
24 | - repo: https://github.com/timothycrosley/isort
25 | rev: 5.11.5
26 | hooks:
27 | - id: isort
28 |
29 | - repo: https://github.com/PyCQA/flake8
30 | rev: 3.9.2
31 | hooks:
32 | - id: flake8
33 | additional_dependencies: [flake8-isort]
34 |
35 | - repo: https://github.com/pre-commit/mirrors-mypy
36 | rev: 'v1.5.1'
37 | hooks:
38 | - id: mypy
39 | args: [--ignore-missing-imports, --check-untyped-defs]
40 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for MkDocs projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 | # Required
4 | version: 2
5 | # Set the version of Python and other tools you might need
6 | build:
7 | os: ubuntu-22.04
8 | tools:
9 | python: "3.11"
10 |
11 | mkdocs:
12 | configuration: mkdocs.yml
13 |
14 | # Optionally declare the Python requirements required to build your docs
15 | python:
16 | install:
17 | - requirements: docs/requirements.txt
18 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 |
4 | Levanter is a growing code base, and we are excited for other folks to get involved. The instructions below walk you through our dev setup and how to submit a PR.
5 |
6 | Dev Installation
7 | ----------------
8 |
9 | First follow the same instructions as provided for the [Levanter README](README.md) to install Levanter.
10 |
11 | The main addition for a dev environment is to install [`pre-commit`](https://pre-commit.com/):
12 |
13 | pre-commit install
14 |
15 | This will set up git hook scripts that ensure your code is formatted in a manner consistent with
16 | the repo. If any problems are found, the appropriate files will be updated with fixes. You will
17 | need to review and commit the fixed files.
18 |
19 | Forking The Repo
20 | ----------------
21 |
22 | To submit changes, you will need to work off of a fork of the repo and issue a pull request.
23 |
24 | There are two easy ways to fork the repo.
25 |
26 | If you have installed the [GitHub CLI](https://cli.github.com/) you can issue this command:
27 |
28 | gh repo fork stanford-crfm/levanter --clone=true
29 |
30 | This will create the fork and clone the repo into your current directory.
31 |
32 | Alternatively you can fork the repo in your browser. While logged in to your GitHub account,
33 | go to the [Levanter repo](https://github.com/stanford-crfm/levanter) and click on the Fork
34 | button in the upper right-hand corner.
35 |
36 | You can then clone your forked version of the Levanter repo like any other GitHub repo.
37 |
38 | Create A Branch For Your Submission
39 | -----------------------------------
40 |
41 | You will generally need to create a branch of `main` for your code changes. In general every submission
42 | should be focused on a specific set of bug fixes or new features that are coherently
43 | related. Changes that are not related belong in different submissions. So you should
44 | be able to give your branch an informative name such as `checkpointer-time-bugfix`.
45 |
46 | You can create a branch off of `main` with this command:
47 |
48 | git checkout -b checkpointer-time-bugfix main
49 |
50 | Implement Your Changes
51 | ----------------------
52 |
53 | As you implement your changes in your feature branch, the git hook scripts will check your
54 | code for proper formatting as you make commits. Make sure you have run `pre-commit install`
55 | before you start making commits.
56 |
57 | You can also check all files in the current branch with this command:
58 |
59 | pre-commit run --all-files
60 |
61 | When your changes are operational you should verify that the current tests are passing.
62 |
63 | Set up your environment for running the tests:
64 |
65 | export PYTHONPATH=/path/to/levanter/src:/path/to/levanter/tests:$PYTHONPATH
66 | wandb offline
67 |
68 | You can run the tests with this command:
69 |
70 | pytest tests
71 |
72 | You should add tests for any functionality you have added consistent with the [pytest](https://docs.pytest.org/en/6.2.x/) format
73 | of the existing tests.
74 |
75 | Submit Pull Request
76 | -------------------
77 |
78 | When your feature branch is ready you should submit a pull request.
79 |
80 | Detailed instructions for submitting a pull request from a fork can be found in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
81 |
82 | The steps basically are:
83 |
84 | 1. While logged in to your GitHub, go to the original [Levanter repo pull request page](https://github.com/stanford-crfm/levanter/pulls)
85 | 2. Click on the highlighted text stating "compare across forks".
86 | 3. Set the base repository to `stanford-crfm/levanter` and the base branch to `main`.
87 | 4. Set the head repository to `your-org/levanter` and the compare branch to `your-feature-branch`.
88 | 5. Click on the "Create pull request" button and complete the pull request form.
89 |
90 | When submitting your pull request, you should provide a detailed description of what you've done.
91 |
92 | The following is a useful template:
93 |
94 | ## Description
95 | A brief and concise description of what your pull request is trying to accomplish.
96 |
97 | ## Fixes Issues
98 | A list of issues/bugs with # references. (e.g., #123)
99 |
100 | ## Unit test coverage
101 | Are there unit tests in place to make sure your code is functioning correctly?
102 |
103 | ## Known breaking changes/behaviors
104 | Does this break anything in Levanter's existing user interface? If so, what is it and how is it addressed?
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Haliax
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | > *Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.*
18 | > — Patrick Rothfuss, *The Name of the Wind*
19 |
20 | Haliax is a [JAX](https://github.com/google/jax) library for building neural networks with named tensors, in the tradition of Alexander Rush's [Tensor Considered Harmful](https://nlp.seas.harvard.edu/NamedTensor).
21 | Named tensors improve the **legibility** and **compositionality** of tensor programs by using named axes instead of positional indices
22 | as typically used in NumPy, PyTorch, etc.
23 |
24 | Despite the focus on legibility, Haliax
25 | is also **fast**, typically about as fast as "pure" JAX code.
26 | Haliax is also built to be **scalable**: it
27 | can support [Fully-Sharded Data Parallelism (FSDP)](https://engineering.fb.com/2021/07/15/open-source/fsdp/) and Tensor Parallelism with [just a few lines of code](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). Haliax powers [Levanter](https://github.com/stanford-crfm/levanter),
28 | our companion library for training large language models and other foundation models, with scale proven up to 70B parameters
29 | and up to TPU v4-2048.
30 |
31 | ## Example: Attention
32 |
33 | Here's a minimal attention module implementation in Haliax. For a more detailed introduction,
34 | please see the [Haliax tutorial](https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC).
35 | (We use the excellent [Equinox](https://github.com/patrick-kidger/equinox) library for its module system and tree transformations.)
36 |
37 | ```python
38 | import equinox as eqx
39 | import jax
40 | import jax.numpy as jnp
41 | import haliax as hax
42 | import haliax.nn as hnn
43 |
44 | Pos = hax.Axis("position", 1024) # sequence length
45 | KPos = Pos.alias("key_position")
46 | Head = hax.Axis("head", 8) # number of attention heads
47 | Key = hax.Axis("key", 64) # key size
48 | Embed = hax.Axis("embed", 512) # embedding size
49 |
50 | # alternatively:
51 | # Pos, KPos, Head, Key, Embed = hax.make_axes(position=1024, key_position=1024, head=8, key=64, embed=512)
52 |
53 |
54 | def attention_scores(Key, KPos, query, key, mask):
55 | # how similar is each query to each key
56 | scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)
57 |
58 | if mask is not None:
59 | scores -= 1E9 * (1.0 - mask)
60 |
61 | # convert to probabilities
62 | scores = hnn.softmax(scores, KPos)
63 | return scores
64 |
65 |
66 | def attention(Key, KPos, query, key, value, mask):
67 | scores = attention_scores(Key, KPos, query, key, mask)
68 | answers = hax.dot(scores, value, axis=KPos)
69 |
70 | return answers
71 |
72 |
73 | # Causal Mask means that if pos >= key_pos, then pos can attend to key_pos
74 | causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos)
75 |
76 |
77 | class Attention(eqx.Module):
78 | proj_q: hnn.Linear # [Embed] -> [Head, Key]
79 | proj_k: hnn.Linear # [Embed] -> [Head, Key]
80 | proj_v: hnn.Linear # [Embed] -> [Head, Key]
81 | proj_answer: hnn.Linear # output projection from [Head, Key] -> [Embed]
82 |
83 | @staticmethod
84 | def init(Embed, Head, Key, *, key):
85 | k_q, k_k, k_v, k_ans = jax.random.split(key, 4)
86 | proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q)
87 | proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k)
88 | proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v)
89 | proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans)
90 | return Attention(proj_q, proj_k, proj_v, proj_answer)
91 |
92 | def __call__(self, x, mask=None):
93 | q = self.proj_q(x)
94 | # Rename "position" to "key_position" for self attention
95 | k = self.proj_k(x).rename({"position": "key_position"})
96 | v = self.proj_v(x).rename({"position": "key_position"})
97 |
98 | answers = attention(Key, KPos, q, k, v, causal_mask)
99 |
100 | x = self.proj_answer(answers)
101 | return x
102 | ```
103 |
104 | Haliax was created by the research engineering team at [Stanford's Center for Research on Foundation Models (CRFM)](https://crfm.stanford.edu/).
105 | You can find us in the #levanter channel on the unofficial [Jax LLM Discord](https://discord.gg/FkRGNX3ND).
106 |
107 |
108 |
109 | ## Documentation
110 |
111 | ### Tutorials
112 |
113 | These are some tutorials to get you started with Haliax. They are available as Colab notebooks:
114 |
115 |
116 |
117 | * [Introduction to Haliax with Transformers](https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC)
118 | * [Distributed Training in Haliax](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz) (including FSDP)
119 | * [Tensor Parallelism in Haliax](https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi)
120 | * [Mixed Precision with `jmp`](https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing) (This one is really a tutorial for [jmp](https://github.com/deepmind/jmp) but it's how to use it with Haliax...)
121 |
122 |
123 | ### API Reference
124 |
125 | Haliax's API documentation is available at [haliax.readthedocs.io](https://haliax.readthedocs.io/en/latest/).
126 |
127 | ## Contributing
128 |
129 | We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information.
130 | We also have a list of [good first issues](https://github.com/stanford-crfm/haliax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
131 | to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!)
132 |
133 | ## License
134 |
135 | Haliax is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text.
136 |
--------------------------------------------------------------------------------
/docs/broadcasting.md:
--------------------------------------------------------------------------------
1 | # Broadcasting
2 |
3 | One area where Haliax's treatment of named axes differs substantially from Numpy-esque positional code is in broadcasting. In traditional positional code, [broadcasting works like this](https://numpy.org/doc/stable/user/basics.broadcasting.html):
4 |
5 | ```python
6 | import numpy as np
7 |
8 | # compute the outer product of two arrays
9 | a = np.arange(5)
10 | b = np.arange(4)
11 |
12 | c = a.reshape((-1, 1)) * b.reshape((1, -1))
13 | print(c.shape)
14 | print(c)
15 |
16 | # alternatively
17 | c2 = a[:, np.newaxis] * b
18 | ```
19 |
20 | This prints:
21 | ```
22 | (5, 4)
23 | [[ 0 0 0 0]
24 | [ 0 1 2 3]
25 | [ 0 2 4 6]
26 | [ 0 3 6 9]
27 | [ 0 4 8 12]]
28 | ```
29 |
30 | To quote the NumPy documentation, for positional arrays, "in order to broadcast, the size of the trailing axes for both arrays in an operation must either be the same size or one of them must be one."
31 |
32 | I have found this to be a source of bugs: it is easy to accidentally have an array of size [batch_size, 1] and combine it with an array of size [batch_size], yielding an array of [batch_size, batch_size].
33 |
34 | In Haliax, broadcasting is done by matching names. The same operation in Haliax might look like this:
35 |
36 | ```python
   | import haliax as hax
   |
37 | M = hax.Axis("M", 5)
38 | N = hax.Axis("N", 4)
39 |
40 | a = hax.arange(M)
41 | b = hax.arange(N)
42 |
43 | c = a.broadcast_axis(N) * b
44 | print(c.axes)
45 | print(c.array)
46 | ```
47 |
48 | ```
49 | (Axis(name='N', size=4), Axis(name='M', size=5))
50 | [[ 0 0 0 0 0]
51 | [ 0 1 2 3 4]
52 | [ 0 2 4 6 8]
53 | [ 0 3 6 9 12]]
54 | ```
55 |
56 | Haliax aims to be "order-independent" as much as possible (while still letting you choose the order for performance or compatibility with positional code).
57 | Its semantics are: "in order to broadcast, identically named Axes of the arrays must have the same size.
58 | In addition, they must share at least one named axis in common, unless one is a scalar."
59 | The second condition is there to avoid bugs: we want to be sure that the arrays have something in common.
60 | To satisfy the second condition, it is not uncommon to use [haliax.broadcast_axis][], like we did above.
61 | This method takes one or more axes and adds them to the array.
62 |
63 | ## Explicit Broadcasting Functions
64 |
65 | * [haliax.broadcast_axis][]
66 | * [haliax.broadcast_to][]
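
Here is a small sketch of both functions. (This is a hedged example: the exact call signature of `broadcast_to` is assumed from its name and the API reference above.)

```python
import haliax as hax

M = hax.Axis("M", 5)
N = hax.Axis("N", 4)

a = hax.arange(M)

b = a.broadcast_axis(N)          # add a new axis N; b now carries both N and M
c = hax.broadcast_to(a, (N, M))  # broadcast a to an explicit set of target axes
```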
67 |
68 |
69 |
78 |
79 | ## A note on performance
80 |
81 | Under the hood, Haliax will automatically broadcast and permute axes so that the underlying positional code produces the correct result.
82 | This is usually not a substantial performance hit because XLA is pretty good about picking optimal shapes,
83 | but if you're doing repeated operations you may want to use [haliax.rearrange][] to change the order of axes.
84 | As an example, in Levanter's GPT-2 implementation, we found using `rearrange` led to a 5% speedup for small models. This
85 | isn't huge, but it's not nothing either.
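
For example, here is a hedged sketch of pinning the axis order once with [haliax.rearrange][] (the axis-sequence form; the string syntax is covered in [rearrange.md](rearrange.md)):

```python
# `c`, `M`, and `N` are from the broadcasting example above; after broadcasting, c's axes are (N, M).
c_ordered = hax.rearrange(c, (M, N))  # explicitly order the axes as (M, N)
```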
86 |
--------------------------------------------------------------------------------
/docs/css/material.css:
--------------------------------------------------------------------------------
1 | .md-main__inner {
2 | margin-bottom: 1.5rem;
3 | }
4 |
5 | /* Custom admonition: preview */
6 | :root {
7 | --md-admonition-icon--preview: url('data:image/svg+xml;charset=utf-8,');
8 | }
9 |
10 | .md-typeset .admonition.preview,
11 | .md-typeset details.preview {
12 | border-color: rgb(220, 139, 240);
13 | }
14 |
15 | .md-typeset .preview>.admonition-title,
16 | .md-typeset .preview>summary {
17 | background-color: rgba(142, 43, 155, 0.1);
18 | }
19 |
20 | .md-typeset .preview>.admonition-title::before,
21 | .md-typeset .preview>summary::before {
22 | background-color: rgb(220, 139, 240);
23 | -webkit-mask-image: var(--md-admonition-icon--preview);
24 | mask-image: var(--md-admonition-icon--preview);
25 | }
26 |
--------------------------------------------------------------------------------
/docs/css/mkdocstrings.css:
--------------------------------------------------------------------------------
1 | /* Indentation. */
2 | div.doc-contents {
3 | padding-left: 25px;
4 | border-left: .05rem solid var(--md-typeset-table-color);
5 | }
6 |
7 |
8 | div.doc-class:not(.doc-contents .doc-contents)::after {
9 | content: "";
10 | display: block;
11 | width: 100%;
12 | height: 1px; /* Adjust thickness */
13 | background-color: black; /* Adjust color */
14 | margin: 10px 0; /* Adjust spacing */
15 | }
16 |
17 | /* Mark external links as such. */
18 | a.external::after,
19 | a.autorefs-external::after {
20 | /* https://primer.style/octicons/arrow-up-right-24 */
21 | mask-image: url('data:image/svg+xml,');
22 | content: ' ';
23 |
24 | display: inline-block;
25 | vertical-align: middle;
26 | position: relative;
27 |
28 | height: 1em;
29 | width: 1em;
30 | background-color: var(--md-typeset-a-color);
31 | }
32 |
33 | a.external:hover::after,
34 | a.autorefs-external:hover::after {
35 | background-color: var(--md-accent-fg-color);
36 | }
37 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # Tips and FAQ
2 |
3 | See also the [Equinox FAQ](https://docs.kidger.site/equinox/faq/)
4 |
5 | ## Tip 1: `hax.debug.diagnose_common_issues`
6 |
7 | `hax.debug.diagnose_common_issues` is a function that will raise an exception if it detects problems with your module.
8 | Currently, we diagnose:
9 |
10 | * Reuse of arrays or NamedArrays in a field. [Equinox modules must be trees.](https://docs.kidger.site/equinox/faq/#a-module-saved-in-two-places-has-become-two-independent-copies)
11 | * Use of arrays or NamedArrays in a static field. Static data in JAX/Equinox must be hashable, and arrays are not hashable.
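
A minimal usage sketch (assuming the function simply takes the module you want to check):

```python
import jax
import haliax as hax

In = hax.Axis("in", 8)
Out = hax.Axis("out", 4)
module = hax.nn.Linear.init(In=In, Out=Out, key=jax.random.PRNGKey(0))

# Raises an exception if it detects one of the problems listed above.
hax.debug.diagnose_common_issues(module)
```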
12 |
--------------------------------------------------------------------------------
/docs/figures/data_parallel_mesh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/data_parallel_mesh.png
--------------------------------------------------------------------------------
/docs/figures/data_parallel_mesh_replicated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/data_parallel_mesh_replicated.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_1d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_1d.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_1d_zero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_1d_zero.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d_batch_partitioned.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_batch_partitioned.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d_data_replicated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_data_replicated.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d_intermediate_fully_partitioned.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_intermediate_fully_partitioned.png
--------------------------------------------------------------------------------
/docs/figures/device_mesh_2d_zero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_zero.png
--------------------------------------------------------------------------------
/docs/fp8.md:
--------------------------------------------------------------------------------
1 | # Quantized Training
2 |
3 | !!! warning
4 |
5 | FP8 and Int8 training in Haliax is currently experimental and may change in the future.
6 |
7 | Haliax supports training with FP8 and Int8. This is useful for training on hardware that is optimized for these formats,
8 | such as the H100 (FP8), the A100 (Int8), and TPU v5 and newer (Int8).
9 |
10 | ## TL;DR
11 |
12 | Using FP8 with Haliax is actually pretty straightforward. To enable FP8, do this:
13 |
14 | ```python
15 | import haliax.quantization as haxq
16 | # setup
17 | module = haxq.quantize_linear_layers(module, haxq.QuantizationConfig(fp8=True))
18 |
19 | # if using optax. This saves a tiny amount of memory so you can skip it if you want
20 | _, trainable_module = haxq.partition_for_grad_overwrite(module)
21 | opt_state = opt.init(trainable_module)
22 |
23 | # train step
24 | grads = eqx.filter_grad(loss_fn)(module, data)
25 | overwrite, grads = haxq.partition_for_grad_overwrite(grads)
26 | updates, opt_state = opt.update(grads, opt_state, params=module) # or however you update your optimizer
27 | module = haxq.apply_updates(module, updates, overwrite)
28 | ```
29 |
30 | And train your model like normal.
31 |
32 | Similarly, you can enable Int8 by setting `int8=True` in the `QuantizationConfig` object.
33 |
34 |
35 |
36 | ## What is FP8?
37 |
38 | FP8 refers to 8-bit floating point numbers. FP8 is a massively reduced precision compared to the 32-bit floating point numbers
39 | or 16-bit floating point numbers that are typically used in deep learning: there are only 256 possible values in FP8, compared to
40 | the (almost) 2^32 in 32-bit and 2^16 in 16-bit. However, FP8 is still useful for training deep learning models, especially on
41 | hardware that is optimized for FP8. In particular, it can massively accelerate training on hardware that is optimized for FP8:
42 | H100 has 2x FP8 FLOPS compared to FP16 FLOPS and almost 60x(!) compared to FP32 FLOPS.
43 |
44 | The FP8 in Haliax is currently designed to optimize throughput on FP8-enabled devices (currently H100) rather
45 | than to save memory. In particular, Haliax's FP8 support is not designed to quantize a model to FP8 for deployment,
46 | though this shouldn't be that hard to add for models that were trained using this functionality.
47 | We would be happy to accept contributions that add this functionality,
48 | and are happy to work with you to do so.
49 |
50 | See this [FP8 Primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for more information on FP8.
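
As a concrete illustration of how coarse FP8 is, you can cast an array to one of JAX's FP8 dtypes and back. (This is plain JAX, not Haliax-specific, and is only meant to show the rounding behavior.)

```python
import jax.numpy as jnp

x = jnp.array([0.1, 1.0, 3.14159, 100.0], dtype=jnp.float32)
x8 = x.astype(jnp.float8_e4m3fn)   # one of the common FP8 formats (e4m3)
print(x8.astype(jnp.float32))      # values are rounded to the nearest representable FP8 number
```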
51 |
52 | ## What is Int8?
53 |
54 | Int8 refers to 8-bit integers. Int8 has the same number of bits as FP8, but the interpretation is different: instead of
55 | exponentially spaced numbers, Int8 has linearly spaced numbers.
56 |
57 | In Haliax, we support Int8 training through Google's [AQT](https://github.com/google/aqt) library. AQT (for
58 | "Accurate Quantized Training") is a library that allows you to train models with quantization-aware training (QAT).
59 |
60 | ## How to use FP8 or Int8 in Haliax
61 |
62 | To use quantized training in Haliax, you need to do two things:
63 |
64 | * Enable FP8 (or int8) for the layers you want
65 | * Modify your training step to be compatible
66 |
67 | Each of these is just a couple of lines of code.
68 |
69 | ```python
70 | import haliax as hax
71 | import equinox as eqx
72 | import jax
73 |
74 | In = hax.Axis("In", 32)
75 | Mid = hax.Axis("Mid", 128)
76 | Out = hax.Axis("Out", 16)
77 | Hidden = hax.Axis("Hidden", 64)
78 |
79 |
80 | class MyModule(eqx.Module):
81 | up_proj: hax.nn.Linear
82 | down_proj: hax.nn.Linear
83 |
84 | @staticmethod
85 | def init(*, key):
87 | k_up, k_down = jax.random.split(key)
88 | return MyModule(
89 | up_proj=hax.nn.Linear.init(In, Mid, key=k_up),
90 | down_proj=hax.nn.Linear.init(Mid, Out, key=k_down),
91 | )
92 |
93 | def __call__(self, x):
94 | x = self.up_proj(x)
95 | x = hax.nn.relu(x)
96 | x = self.down_proj(x)
97 | return x
98 |
99 | module = MyModule.init(key=jax.random.PRNGKey(0))
100 |
101 | # Enable FP8
102 | module = hax.quantization.quantize_linear_layers(module, hax.quantization.QuantizationConfig(fp8=True))
103 |
104 | # Enable FP8 for a specific layer
105 | from haliax.quantization import QuantizationConfig
106 |
107 | config = QuantizationConfig(targets=["up_proj"], fp8=True)
108 | module = hax.quantization.quantize_linear_layers(module, config)
109 |
110 | # Train step
111 | grads = eqx.filter_grad(loss_fn)(module, data)
112 | overwrite, grads = hax.quantization.partition_for_grad_overwrite(grads)
113 | updates, opt_state = opt.update(grads, opt_state, params=module) # or however you update your optimizer
114 | module = hax.quantization.apply_updates(module, updates, overwrite)
115 | ```
116 |
117 | That's it! Just a few lines of code to enable FP8. The `quantize_linear_layers` function will transform your module to use
118 | quantization-aware training for linear layers (or a subset, if you want), and the combination of `partition_for_grad_overwrite` and `apply_updates` applies the updates to the module
119 | in a way that is compatible with FP8.
120 |
121 | ## How FP8 works
122 |
123 | For an overview of the FP8, see the [FP8 Primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html).
124 | You don't need to understand it though. Haliax's FP8 integration is more or less plug and play, as shown above.
125 | The implementation of FP8 in Haliax is more or less a straightforward port (including some copy and paste) of the
126 | [FP8 implementation in Flax](https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py).
127 |
128 | FP8 in JAX (as well as INT8) is typically implemented using "`dot_general` injection", where you pass
129 | a custom implementation of `dot_general` to functions and modules like [haliax.dot][] and [haliax.nn.Linear][].
130 | The `dot_general` for FP8 is implemented by scaling
131 | the inputs, projecting the inputs to FP8, performing the computation in FP8, and then
132 | scaling the result back to the original precision.
133 | The subtle part of FP8 is that the scaling is a parameter that is trained based on a history of the inputs to the layer
134 | (as well as gradients coming in from backward). This means that the FP8 `dot_general` needs to maintain state.
135 | In Equinox, this means that the `dot_general` is actually a `Module` that packages together the state and the
136 | computation. (Unlike [equinox.nn.StatefulLayer][] which returns a state object you pass back into the module, the FP8 `dot_general`
137 | module hijacks the gradient computation to update its state. This is necessary because the FP8 scaling factors
138 | depend on the gradients.)
139 |
140 | The way this happens is by "hijacking" the gradient computation. When you call `eqx.filter_grad(loss_fn)(module, data)`,
141 | you will get the gradient computation as normal, but you'll also get the updated state of the FP8 `dot_general` module.
142 | This updated state needs to directly replace the state in the module (rather than be used for a gradient step), which is
143 | why you need to use `partition_for_grad_overwrite` and `apply_updates`, as shown above.
144 |
145 | The FP8 `dot_general` module is implemented in [haliax.quantization.Fp8DotGeneralOp][]. It's actually not that complicated:
146 |
147 | 1) It holds a scaling factor and history of maximum values for each of (lhs, rhs, output) and updates them based on the
148 | gradients.
149 | 2) When invoked, it scales the inputs, projects them to FP8, performs the computation, and scales the result back to the
150 | original precision. It remembers the maximum absolute value for each of the inputs.
151 | 3) For the gradients, it scales the gradients, projects them to FP8, does the backward computation,
152 | and scales the gradients back to the original precision. It remembers the maximum absolute value for the incoming
153 | gradient and stores it in the gradient.
154 |
155 | ## How Int8 works
156 |
157 | Int8 is in principle the same, though the details differ. AQT is a much more flexible library than the FP8 implementation,
158 | which also means it can be a bit more finicky. We use AQT directly, and we recommend you look at the
159 | [AQT documentation](https://github.com/google/aqt?tab=readme-ov-file#how-aqt-works-internally) for more
160 | information on how it works.
161 |
162 | # API Reference
163 |
164 | ## Functions
165 |
166 | ::: haliax.quantization.quantize_linear_layers
167 | ::: haliax.quantization.partition_for_grad_overwrite
168 | ::: haliax.quantization.apply_updates
169 |
170 |
171 | ## Interfaces
172 | ::: haliax.quantization.DotGeneralOp
173 | ::: haliax.quantization.OverwriteWithGradient
174 |
175 | ## Modules
176 |
177 |
178 | ::: haliax.quantization.DefaultDotGeneralOp
179 | ::: haliax.quantization.Fp8DotGeneralOp
180 | ::: haliax.quantization.Int8DotGeneralOp
181 |
182 | ## Configuration
183 |
184 | ::: haliax.quantization.QuantizationConfig
185 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | {%
2 | include-markdown "../README.md"
3 | start=""
4 | end=""
5 | %}
6 |
7 |
8 | The code is released on GitHub: [Haliax Repository](https://github.com/stanford-crfm/haliax/).
9 |
10 |
11 | To contribute, please refer to the [Contributing Guide](https://github.com/stanford-crfm/haliax/blob/main/CONTRIBUTING.md).
12 |
--------------------------------------------------------------------------------
/docs/matmul.md:
--------------------------------------------------------------------------------
1 | ## Matrix Multiplication
2 |
3 | Haliax has two ways to do matrix multiplication (and tensor contractions more generally):
4 | [haliax.dot][] and [haliax.einsum][]. [haliax.dot][] and [haliax.einsum][]
5 | can both express any tensor contraction, though in different situations one or the other may be
6 | more suitable for a particular contraction. In general:
7 |
8 | - Use [haliax.dot][] when you want to express a simple matrix multiplication over one or a few axes.
9 | - Use [haliax.einsum][] when you want to express a more complex tensor contraction.
10 |
11 | See also the API reference for [haliax.dot][] and [haliax.einsum][] and the
12 | [cheat sheet section](cheatsheet.md#matrix-multiplication).
13 |
14 | ### `haliax.dot`
15 |
16 | With [haliax.dot][], you specify the axes to contract over, without needing to write out the
17 | axes you want to keep (though you can if you want):
18 |
19 | ```python
20 | import haliax as hax
21 |
22 | H = hax.Axis("H", 3)
23 | W = hax.Axis("W", 4)
24 | D = hax.Axis("D", 5)
25 |
26 | x = hax.ones((H, W, D))
27 | w = hax.ones((D,))
28 | y = hax.dot(x, w, axis=D) # shape is (H, W), equivalent to np.einsum("hwd,d->hw", x, w)
29 | ```
30 |
31 | [haliax.dot][] is at its best when you want to express a simple matrix multiplication over one or a few axes.
32 | Syntactically, [haliax.dot][] is similar to reduction operations like [haliax.sum][] and [haliax.mean][].
33 |
34 | The [cheat sheet](cheatsheet.md) has a section on [matrix multiplication](cheatsheet.md#matrix-multiplication)
35 | that gives a few examples. Here are several more:
36 |
37 | ```python
38 | import haliax as hax
39 |
40 | H = hax.Axis("H", 3)
41 | W = hax.Axis("W", 4)
42 | D = hax.Axis("D", 5)
43 | C = hax.Axis("C", 6)
44 |
45 | x = hax.arange((H, W, D, C))
46 | w = hax.arange((D, C))
47 | c = hax.arange((C,))
48 |
49 | y = hax.dot(x, c, axis=C) # shape is (H, W, D), equivalent to jnp.dot(x, c)
50 |
51 | y = hax.dot(x, w, axis=(D, C)) # shape is (H, W), equivalent to np.einsum("...dc,dc->...", x, w)
52 | y = hax.dot(x, w, axis=(D, C), out_axes=(W, H)) # shape is (W, H) instead of (H, W)
53 | y = hax.dot(x, w, c, axis=(D, C)) # shape is (H, W), equivalent to np.einsum("...dc,dc,c->...", x, w, c)
54 | y = hax.dot(x, c, axis=(H, D, C)) # shape is (W,), equivalent to np.einsum("hwdc,c->w", x, c)
55 | s = hax.dot(x, w, axis=None) # scalar output, equivalent to np.einsum("hwdc,dc->", x, w)
56 | y = hax.dot(x, w, c, axis=()) # shape is (H, W, D, C), equivalent to np.einsum("hwdc,dc,c->hwdc", x, w, c)
57 | y = hax.dot(x, w, c, axis=(), out_axes=(D, ..., H)) # shape is (D, W, C, H), equivalent to np.einsum("hwdc,dc,c->dwch", x, w, c)
58 | ```
59 |
60 | ### `haliax.einsum`
61 |
62 | [haliax.einsum][] is at its best when you want to express a more complex tensor contraction.
63 | It is similar to [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
64 | or [einops.einsum](https://einops.rocks/api/einsum/) in terms of syntax and behavior,
65 | but extended to work with named axes, including the added flexibility that named axes provide.
66 | Our "flavor" of `einsum` is most similar to `einops.einsum`'s flavor, in that
67 | it supports long names for axes (like `"batch h w, h w channel -> batch channel"`)
68 | rather than the compact notation of `numpy.einsum` (like `"bhwc,hwc->bc"`).
69 |
70 | Haliax's version of `einsum` comes in three modes: "ordered", "unordered", and "output axes".
71 | These modes are all accessible through the same function without any flags: the syntax
72 | of the `einsum` string determines which mode is used.
73 |
74 | The syntax for Haliax's `einsum` is similar to [`haliax.rearrange`](rearrange.md), which
75 | is in turn similar to [einops.rearrange](https://einops.rocks/api/rearrange/).
76 |
77 | #### Ordered Mode
78 |
79 | Haliax's `einsum` has an "ordered" mode that is similar to `einops.einsum`'s behavior.
80 | In this mode, the axes in the input arrays are matched to the axes in the `einsum` string in order.
81 | It supports ellipses in the same way as `einops.einsum`. The names in the einsum string need not
82 | match the names of the axes in the input arrays, but the order of the axes must match.
83 |
84 | ```python
85 | import haliax as hax
86 |
87 | H = hax.Axis("H", 3)
88 | W = hax.Axis("W", 4)
89 | D = hax.Axis("D", 5)
90 |
91 | x = hax.ones((H, W, D))
92 | w = hax.ones((D,))
93 | y = hax.einsum("h w d, d -> h w", x, w) # shape is (H, W), equivalent to jnp.einsum("hwd,d->hw", x, w)
94 | y = hax.einsum("... d, d -> ...", x, w) # same as above
95 | ```
96 |
97 | The `...` syntax is used to indicate that the axes in the input arrays that are not mentioned in the `einsum` string
98 | should be preserved in the output. This should be the same as `einops.einsum`'s behavior, with the exception
99 | that axes matched to the same label in the string must also have the same name in the input arrays.
100 |
101 | (If you notice any differences between Haliax's `einsum`'s ordered syntax and `einops.einsum`, please let us know!)
102 |
103 | #### Unordered Mode
104 |
105 | In "unordered mode," the axes in the input arrays are matched to the axes in the `einsum` string by name,
106 | using similar rules to [haliax.rearrange][]. Names involved in the operation are specified inside `{}`
107 | on the left hand side of the `->` in the `einsum` string. Axes not specified are implicitly preserved.
108 |
109 | ```python
110 | import haliax as hax
111 |
112 | H = hax.Axis("H", 3)
113 | W = hax.Axis("W", 4)
114 | D = hax.Axis("D", 5)
115 |
116 | x = hax.ones((H, W, D))
117 | w = hax.ones((D,))
118 |
119 | y = hax.einsum("{H W D} -> H W", x) # shape is (H, W)
120 | y = hax.einsum("{D} -> ", x) # shape is (H, W)
121 | y = hax.einsum("{...} -> ", x) # shape is ()
122 | y = hax.einsum("{H ...} -> H", x) # shape is (H,)
123 | y = hax.einsum("{H ...} -> ...", x) # shape is (W, D)
124 | ```
125 |
126 | This mode is most similar to [haliax.dot][]'s behavior, though it's a bit more expressive.
127 |
128 | You can also use axis aliases in the `einsum` string, which can be useful for expressing contractions
129 | in library code or just for shortening the string:
130 |
131 | ```python
132 | Height = hax.Axis("Height", 3)
133 | Width = hax.Axis("Width", 4)
134 | Depth = hax.Axis("Depth", 5)
135 |
136 | x = hax.ones((Height, Width, Depth))
137 | w = hax.ones((Depth,))
138 |
139 | y = hax.einsum("{H W D} -> H W", x, H=Height, W=Width, D=Depth) # shape is (Height, Width)
140 | y = hax.einsum("{D} -> ", x, D=Depth) # shape is (Height, Width)
141 | ```
142 |
143 |
144 | #### Output Axes Mode
145 |
146 | In "output axes" mode, you only specify the axes that should be in the output. All other
147 | axes are implicitly contracted over. This mode is a bit "dangerous" in that it's easy to
148 | accidentally contract over axes you didn't mean to, but it can be useful for expressing
149 | certain contractions concisely.
150 |
151 | ```python
152 | import haliax as hax
153 |
154 | H = hax.Axis("H", 3)
155 | W = hax.Axis("W", 4)
156 | D = hax.Axis("D", 5)
157 |
158 | x = hax.ones((H, W, D))
159 | w = hax.ones((D,))
160 |
161 | y = hax.einsum("-> H W", x) # shape is (H, W)
162 | y = hax.einsum("-> D", w) # shape is (D,)
163 | ```
164 |
165 | We don't recommend using this mode except in cases when you're sure of the full shape of the input arrays
166 | or you are sure you don't want to let users implicitly batch over any axes.
167 |
168 | Output axes mode also supports axis aliases:
169 |
170 | ```python
171 | Height = hax.Axis("Height", 3)
172 | Width = hax.Axis("Width", 4)
173 | Depth = hax.Axis("Depth", 5)
174 |
175 | x = hax.ones((Height, Width, Depth))
176 | w = hax.ones((Depth,))
177 | y = hax.einsum("-> Height Width", x, Height=Height, Width=Width, Depth=Depth) # shape is (Height, Width)
178 | y = hax.einsum("-> Depth", w, Depth=Depth) # shape is (Depth,)
179 | ```
180 |
--------------------------------------------------------------------------------
/docs/nn.md:
--------------------------------------------------------------------------------
1 | # Neural Networks
2 |
3 |
4 | ## Modules
5 |
6 | Haliax provides a small number of neural network modules that are compatible with Equinox, though
7 | they naturally all use [haliax.NamedArray][]. (We welcome PRs for more modules! Nothing too exotic though.)
8 |
9 | The most interesting of these modules is [haliax.nn.Stacked][], which allows you to create homogeneous "stacks"
10 | of the same module (e.g. transformer blocks), which is a common pattern in deep learning.
11 |
12 | ### Linear
13 |
14 | ::: haliax.nn.Embedding
15 | ::: haliax.nn.Linear
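
A short usage sketch for [haliax.nn.Linear][], mirroring the attention example in the README (the axis names here are just placeholders):

```python
import jax
import haliax as hax
import haliax.nn as hnn

Batch = hax.Axis("batch", 8)
Embed = hax.Axis("embed", 64)
Hidden = hax.Axis("hidden", 128)

linear = hnn.Linear.init(In=Embed, Out=Hidden, key=jax.random.PRNGKey(0))
x = hax.ones((Batch, Embed))
y = linear(x)  # y has axes (batch, hidden); batch is carried through by name
```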
16 |
17 | ### Dropout
18 | ::: haliax.nn.Dropout
19 |
20 | ### Normalization
21 |
22 | ::: haliax.nn.normalization.LayerNormBase
23 | ::: haliax.nn.LayerNorm
24 | ::: haliax.nn.RmsNorm
25 |
26 | ### Meta
27 |
28 | ::: haliax.nn.MLP
29 |
30 | ### Stacked
31 |
32 | See the full documentation of [Stacked](scan.md#stacked).
33 |
34 | ### Convolution
35 |
36 | Unlike other frameworks, Haliax doesn't have separate modules for 1D, 2D, 3D, and general convolutions. Instead, we have
37 | a single [haliax.nn.Conv][] module that can be used for all of these, depending on the number of axes
38 | provided. Similarly, for transposed convolutions, we have [haliax.nn.ConvTranspose][].
39 |
40 | ::: haliax.nn.Conv
41 | ::: haliax.nn.ConvTranspose
42 |
43 | ### Pooling
44 |
45 | As with convolutions, we don't distinguish between 1D, 2D, and 3D pooling, and instead have a single
46 | pooling operation for each of the kinds of reductions:
47 |
48 | ::: haliax.nn.max_pool
49 | ::: haliax.nn.mean_pool
50 | ::: haliax.nn.min_pool
51 |
52 | ## Attention
53 |
54 | We don't provide an explicit attention module, but we do provide an attention function and several related functions:
55 |
56 | ::: haliax.nn.attention.dot_product_attention
57 | ::: haliax.nn.attention.dot_product_attention_weights
58 | ::: haliax.nn.attention.self_attention
59 |
60 | ### Masks
61 | ::: haliax.nn.attention.causal_mask
62 | ::: haliax.nn.attention.prefix_lm_mask
63 | ::: haliax.nn.attention.combine_masks_and
64 | ::: haliax.nn.attention.combine_masks_or
65 | ::: haliax.nn.attention.forgetful_causal_mask
66 |
67 | ### Biases
68 |
69 | ::: haliax.nn.attention.mask_to_bias
70 | ::: haliax.nn.attention.alibi_attention_bias
71 |
72 | ## Functions
73 |
74 | These functions wrap the equivalent in [jax.nn][]:
75 |
76 | ::: haliax.nn.relu
77 | ::: haliax.nn.relu6
78 | ::: haliax.nn.sigmoid
79 | ::: haliax.nn.softplus
80 | ::: haliax.nn.soft_sign
81 | ::: haliax.nn.silu
82 | ::: haliax.nn.swish
83 | ::: haliax.nn.log_sigmoid
84 | ::: haliax.nn.leaky_relu
85 | ::: haliax.nn.hard_sigmoid
86 | ::: haliax.nn.hard_silu
87 | ::: haliax.nn.hard_swish
88 | ::: haliax.nn.hard_tanh
89 | ::: haliax.nn.elu
90 | ::: haliax.nn.celu
91 | ::: haliax.nn.selu
92 | ::: haliax.nn.gelu
93 | ::: haliax.nn.quick_gelu
94 | ::: haliax.nn.glu
95 | ::: haliax.nn.logsumexp
96 | ::: haliax.nn.log_softmax
97 | ::: haliax.nn.softmax
98 | ::: haliax.nn.standardize
99 | ::: haliax.nn.one_hot
100 |
101 | ### Loss Functions
102 |
103 | ::: haliax.nn.cross_entropy_loss
104 | ::: haliax.nn.cross_entropy_loss_and_log_normalizers
105 |
--------------------------------------------------------------------------------
/docs/partitioning.md:
--------------------------------------------------------------------------------
1 | # Partitioning
2 |
3 | Partitioning refers to the process of splitting arrays and computation across multiple devices. Haliax provides a number
4 | of functions for partitioning arrays and computation across multiple devices.
5 |
6 |
7 | ## Tutorial
8 | An introduction to using Haliax's partitioning functions to scale a transformer can be found here: [Distributed Training in Haliax](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz).
9 |
10 | This page is designed to be more of a reference than a tutorial, and we assume you've read the tutorial before reading this page.
11 |
12 |
13 | ## Device Meshes in JAX
14 |
15 | See also JAX's tutorial [Distributed Arrays and Automatic Parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
16 | for more details.
17 |
18 | One of the main ways JAX provides distributed parallelism is via the [jax.sharding.Mesh][].
19 | A mesh is a logical n-dimensional array of devices. Meshes in JAX are represented as a Numpy ndarray (note: not `jax.numpy`)
20 | of devices and a tuple of axis names. For example, a 2D mesh of 16 devices might look like this:
21 |
22 | ```python
23 | import jax
24 | import jax.numpy as jnp
   | import numpy as np
25 |
26 | from jax.sharding import Mesh
27 |
28 | devices = jax.devices()
29 | mesh = Mesh(np.array(devices).reshape((-1, 2)), ("data", "model"))
30 | ```
31 |
32 | 
33 |
34 | The mesh above has two axes, `data` and `model`. In JAX's mesh parallelism, arrays are distributed by overlaying axes of
35 | the array on top of the axes of the mesh. For example, if we have a batch of 32 sequences we might do something like this:
36 |
37 | ```python
38 | from jax.sharding import NamedSharding, PartitionSpec
39 |
40 | batch_size = 32
41 | seqlen = 512
42 |
43 | batch = jnp.zeros((batch_size, seqlen), dtype=jnp.float32)
44 | batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data", None)))
45 | ```
46 |
47 | This specifies that the first axis of `batch` should be distributed across the `data` axis of the mesh. The `None` in the
48 | `PartitionSpec` indicates that the second axis of `batch` is not distributed, which means that the data is replicated
49 | so that one copy of the data is partitioned across each row of the mesh.
50 |
51 | 
52 |
53 | What's nice about this approach is that JAX will automatically schedule computations so that operations are distributed
54 | in the way you would expect: you don't have to explicitly manage communication between devices.
55 |
56 | However, JAX cannot always infer how you want your arrays partitioned. For those cases, JAX provides
57 | [jax.lax.with_sharding_constraint][], which lets you explicitly specify the sharding of arrays produced inside a
58 | computation. This function can only be used inside `jit`.
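59 | For example, inside a jitted function you can pin an intermediate to a particular sharding. This is a minimal
60 | sketch reusing the `mesh` and `batch` defined above:
61 |
62 | ```python
63 | @jax.jit
64 | def scale(batch):
65 |     # keep the intermediate sharded along the "data" mesh axis, replicated along "model"
66 |     hidden = jax.lax.with_sharding_constraint(batch * 2, NamedSharding(mesh, PartitionSpec("data", None)))
67 |     return hidden
68 | ```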
59 |
60 | ## Haliax Partitioning in a Nutshell
61 |
62 | As you might imagine, it gets tedious and error-prone to specify the partitioning of every array you create. Haliax provides
63 | routines that handle the mapping of [haliax.NamedArray][]s onto the mesh automatically, based on their axis names.
64 |
65 | ```python
66 | import haliax as hax
67 | import jax.numpy as jnp
68 | Batch = hax.Axis("batch", 32)
69 | SeqLen = hax.Axis("seqlen", 512)
70 |
71 | axis_mapping = {"batch": "data", }
72 |
73 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
74 | batch = hax.shard(batch, axis_mapping)
75 |
76 | # we also have "auto_sharded", and we support a context manager for setting the axis mapping:
77 | with hax.axis_mapping({"batch": "data"}):
78 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
79 | batch = hax.shard(batch)
80 | ```
81 |
82 | Unlike JAX, which has separate APIs for partitioning arrays inside and outside of `jit`, Haliax has a single API:
83 | `hax.shard` works both inside and outside of `jit`. Haliax automatically
84 | chooses which JAX function to use based on context.
85 |
86 |
87 | ## Axis Mappings
88 |
89 | The core data structure we use to represent partitioning is the [haliax.partitioning.ResourceMapping][], which
90 | is just an alias for `Dict[str, str | Sequence[str]]`. The keys in this dictionary are the names of "logical" axes in NamedArrays,
91 | and the values are the names of axes in the mesh. (In theory you can partition a single logical axis across multiple axes of the mesh,
92 | but we don't use this functionality.)
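93 | For example, a mapping that shards any logical axis named `"batch"` over the `data` mesh axis and any axis named
94 | `"mlp"` over the `model` mesh axis looks like this (the logical names are illustrative; a value may also be a
95 | sequence of mesh axis names, though we rarely use that):
96 |
97 | ```python
98 | axis_mapping = {"batch": "data", "mlp": "model"}
99 | ```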
93 |
94 | ::: haliax.partitioning.ResourceMapping
95 |
96 | A context manager can be used to specify an axis mapping for the current thread for the duration of the context:
97 |
98 | ```python
99 | with hax.axis_mapping({"batch": "data"}):
100 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
101 | batch = hax.auto_sharded(batch)
102 | ```
103 |
104 | ::: haliax.partitioning.axis_mapping
105 |
106 | ## Partitioning Functions
107 |
108 | ### Sharding Arrays and PyTrees
109 |
110 | These functions are used to shard arrays and PyTrees of arrays, e.g. Modules.
111 | This is the main function you will use to shard arrays:
112 |
113 | ::: haliax.shard
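114 | Because `shard` works on arbitrary PyTrees of NamedArrays, it can shard a whole parameter tree at once. A minimal
115 | sketch (the axis names and mapping here are illustrative):
116 |
117 | ```python
118 | import haliax as hax
119 |
120 | Embed = hax.Axis("embed", 128)
121 | Mlp = hax.Axis("mlp", 512)
122 |
123 | params = {"w_in": hax.zeros((Embed, Mlp)), "w_out": hax.zeros((Mlp, Embed))}
124 | # shards every NamedArray in the PyTree whose axis names appear in the mapping
125 | params = hax.shard(params, {"embed": "data", "mlp": "model"})
126 | ```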
114 |
115 | This function is like `shard` but does not issue a warning if there is no context axis mapping.
116 | It's useful for library code where there may or may not be a context mapping:
117 |
118 | ::: haliax.auto_sharded
119 |
120 | This is an older function that is being deprecated in favor of `shard`. It is functionally equivalent to `shard`:
121 |
122 | ::: haliax.shard_with_axis_mapping
123 |
124 | ### `named_jit` and friends
125 |
126 | ::: haliax.named_jit
127 | ::: haliax.fsdp
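128 | A minimal sketch of the typical call pattern for `named_jit`, using a context axis mapping (see the generated
129 | documentation above for the full set of parameters):
130 |
131 | ```python
132 | import jax
133 |
134 | def init_batch(key):
135 |     return hax.random.normal(key, (Batch, SeqLen))
136 |
137 | with hax.axis_mapping({"batch": "data"}):
138 |     # the result is sharded according to the context axis mapping
139 |     batch = hax.named_jit(init_batch)(jax.random.PRNGKey(0))
140 | ```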
128 |
129 |
130 | ### Querying the Mesh and Axis Mapping
131 |
132 |
133 | ::: haliax.partitioning.round_axis_for_partitioning
134 | ::: haliax.partitioning.physical_axis_name
135 | ::: haliax.partitioning.physical_axis_size
136 | ::: haliax.partitioning.sharding_for_axis
137 |
--------------------------------------------------------------------------------
/docs/rearrange.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "initial_id",
7 | "metadata": {
8 | "collapsed": true,
9 | "ExecuteTime": {
10 | "end_time": "2023-10-24T19:27:47.968939Z",
11 | "start_time": "2023-10-24T19:27:47.349182Z"
12 | }
13 | },
14 | "outputs": [
15 | {
16 | "name": "stderr",
17 | "output_type": "stream",
18 | "text": [
19 | "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
20 | "I0000 00:00:1698175667.814623 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n"
21 | ]
22 | }
23 | ],
24 | "source": [
25 | "import haliax as hax\n",
26 | "from jax.random import PRNGKey\n",
27 | "\n",
28 | "N = hax.Axis(\"N\", 8)\n",
29 | "C = hax.Axis(\"C\", 3)\n",
30 | "H = hax.Axis(\"H\", 64)\n",
31 | "W = hax.Axis(\"W\", 64)\n",
32 | "\n",
33 | "x = hax.random.uniform(PRNGKey(0), (N, C, H, W))\n",
34 | "\n"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 9,
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/plain": "{'N': 8, 'C': 3, 'P': 64, 'H': 8, 'W': 8}"
44 | },
45 | "execution_count": 9,
46 | "metadata": {},
47 | "output_type": "execute_result"
48 | }
49 | ],
50 | "source": [
51 | "# split into patches\n",
52 | "hax.rearrange(x, \"N C (ph H) (pw W) -> N C (P: ph pw) H W\", ph=8, pw=8)\n",
53 | "# order agnostic\n",
54 | "hax.rearrange(x, \"{(H: ph H) (W: pw W)} -> ... (P: ph pw) H W\", ph=8, pw=8)"
55 | ],
56 | "metadata": {
57 | "collapsed": false,
58 | "ExecuteTime": {
59 | "end_time": "2023-10-24T19:30:14.757668Z",
60 | "start_time": "2023-10-24T19:30:14.753994Z"
61 | }
62 | },
63 | "id": "f0178b7a92a41783"
64 | }
65 | ],
66 | "metadata": {
67 | "kernelspec": {
68 | "display_name": "Python 3",
69 | "language": "python",
70 | "name": "python3"
71 | },
72 | "language_info": {
73 | "codemirror_mode": {
74 | "name": "ipython",
75 | "version": 2
76 | },
77 | "file_extension": ".py",
78 | "mimetype": "text/x-python",
79 | "name": "python",
80 | "nbconvert_exporter": "python",
81 | "pygments_lexer": "ipython2",
82 | "version": "2.7.6"
83 | }
84 | },
85 | "nbformat": 4,
86 | "nbformat_minor": 5
87 | }
88 |
--------------------------------------------------------------------------------
/docs/rearrange.md:
--------------------------------------------------------------------------------
1 | # Rearrange
2 |
3 | ## Introduction
4 |
5 | Haliax strives to be "order independent": the order of axes should not impact the correctness of the program. However,
6 | when interfacing with non-named APIs (e.g. the JAX NumPy API or Equinox), you need to know exactly what the
7 | order of axes is. In addition, the order of axes can impact performance. To that end, Haliax provides a `rearrange`
8 | function that allows you to specify the order of axes in a tensor.
9 |
10 | In addition, it is sometimes necessary to split and merge axes: turning images into patches,
11 | or turning a batch of images into a single image. Without `rearrange`, this is a bit clunky.
12 |
13 | `rearrange` comes in two flavors: sequence syntax and einops-style syntax. Sequence
14 | syntax is just for transposing axes, while [einops-style](https://einops.rocks/) syntax is more
15 | powerful and can also split and merge axes.
16 |
17 | ## Sequence Syntax
18 |
19 | The sequence syntax is very simple: you just provide a sequence of axis names, and the tensor
20 | will be rearranged to match that sequence. For example:
21 |
22 | ```python
23 | import haliax as hax
24 | import jax.random as jrandom
25 |
26 | N = hax.Axis("N", 32)
27 | C = hax.Axis("C", 3)
28 | H = hax.Axis("H", 64)
29 | W = hax.Axis("W", 64)
30 |
31 | x = hax.random.normal(jrandom.PRNGKey(0), (N, C, H, W))
32 |
33 | y = hax.rearrange(x, (N, H, W, C))
34 |
35 | # at most one ellipsis is allowed
36 | z = hax.rearrange(x, (N, ..., C))
37 |
38 | # you can use strings instead of axis objects
39 | z = hax.rearrange(x, ("N", ..., "C"))
40 | ```
41 |
42 | As we said before, almost all Haliax operations are order-agnostic, so (this version of) `rearrange` only affects
43 | performance and interoperability: it lets you hand a tensor to other libraries that need the axes in a specific
44 | order.
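45 | For example, once the order is fixed you can safely hand the underlying array to a positional API:
46 |
47 | ```python
48 | # continuing the example above: fix the order, then drop down to the raw array
49 | nhwc = hax.rearrange(x, (N, H, W, C))
50 | raw = nhwc.array  # a plain JAX array with shape (32, 64, 64, 3)
51 | ```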
45 |
46 | ## Einops-style Syntax
47 |
48 | [einops](https://einops.rocks/) is a powerful library for manipulating tensor shapes, generalizing
49 | `reshape`, `transpose`, and other shape-manipulation operations. Haliax provides a subset of its functionality
50 | (specifically `rearrange` and not `repeat` or `reduce`, which are less useful in named code). The syntax has been generalized to named
51 | tensors.
52 |
53 | If you're used to einops, the syntax should be familiar. The main differences are that merged axes must be given names
54 | and that there is an additional "unordered" syntax for selecting dimensions by name.
55 |
56 | !!! warning
57 |
58 | This syntax is fairly new. It is pretty well-tested, but it is possible that there are bugs.
59 |
60 | ### Examples
61 |
62 | Examples are probably the best way to get a feel for the syntax:
63 |
64 | ```python
65 | import haliax as hax
66 | import jax.random as jrandom
67 |
68 | N = hax.Axis("N", 32)
69 | C = hax.Axis("C", 3)
70 | H = hax.Axis("H", 64)
71 | W = hax.Axis("W", 64)
72 |
73 | x = hax.random.normal(jrandom.PRNGKey(0), (N, C, H, W))
74 |
75 | # transpose/permute axes
76 | y = hax.rearrange(x, "N C H W -> N H W C")
77 | # names don't have to match with positional syntax
78 | z = hax.rearrange(x, "num ch h w -> num h w ch")
79 | # ellipsis can be used to specify the rest of the dimensions
80 | z = hax.rearrange(x, "N C ... -> N ... C")
81 |
82 | # unordered patterns allow you to match a subset of dimensions by name, rather than using positional matching
83 | # transpose last two dimensions using the unordered syntax
84 | y = hax.rearrange(x, "{H W} -> ... W H")
85 |
86 | # don't know the order? use an unordered pattern
87 | y = hax.rearrange(x, "{W C H N} -> N H W C")
88 |
89 | # split dims as in einops
90 | y = hax.rearrange(x, "(step microbatch) ... -> step microbatch ...", step=4)
91 | # splitting dims can be done using unordered syntax, similar to positional syntax
92 | y = hax.rearrange(x, "{(N: step microbatch) ...} -> step microbatch ...", step=4)
93 |
94 | # merging dims requires a name
95 | x = hax.rearrange(y, "step microbatch ... -> (N: step microbatch) ...")
96 |
97 | # you can partially specify the order by using two or more ellipses on the rhs
98 | y = hax.rearrange(x, "{H W} -> ... (F: H W) ...")
99 | y = hax.rearrange(x, "{H W C} -> ... (F: H W) ... C") # ensures C is the last dimension
100 |
101 |
102 | # some fancier examples
103 |
104 | # split into patches
105 | y = hax.rearrange(x, "N C (patch_h H) (patch_w W) -> N (P: patch_h patch_w) C H W", H=4, W=4)
106 | # unordered version
107 | y = hax.rearrange(x, "{(H: patch_h H) (W: patch_w W) C } -> ... (P: patch_h patch_w) C H W", H=4, W=4)
108 |
109 | # split into patches, then merge patches and channels
110 | y = hax.rearrange(x, "N C (patch_h H) (patch_w W) -> N (P: patch_h patch_w) (C: C H W)", H=4, W=4)
111 | # unordered version
112 | y = hax.rearrange(x, "{(H: patch_h H) (W: patch_w W) C } -> ... (P: patch_h patch_w) (C: C H W)", H=4, W=4)
113 | ```
114 |
115 | ### Bindings: Aliasing and Sizing
116 |
117 | In the above examples we used keyword arguments to give sizes to split axes, which is the same
118 | as in einops. However, we can also use bindings to alias axes. For example:
119 |
120 | ```python
121 | # this produces the same result as the previous example
122 | y2 = hax.rearrange(x, "N C (patch_h foo) (patch_w bar) -> N (P: patch_h patch_w) (C: C foo bar)", foo=hax.Axis("H", 4), bar=hax.Axis("W", 4))
123 | assert y.axes == y2.axes
124 | ```
125 |
126 | This example is a bit contrived, but the point is that this syntax lets us use shorter or different names in the string,
127 | which is occasionally useful.
128 |
129 | You can also pass a string alias instead of an axis object, and it will be converted to an axis object for you.
130 | For instance, if we wanted "P" to actually be called "patch", but wanted to keep the short syntax, we could do:
131 |
132 | ```python
133 | y3 = hax.rearrange(x, "N C (nh ph) (nw pw) -> N (P: nh nw) (C: C ph pw)", P="patch", pw=4, ph=4)
134 | ```
135 |
136 |
137 | ### Differences from einops
138 |
139 | As you may have noticed, there are some differences from einops:
140 |
141 | * Merged axes must have a name: `(C: C H W)` instead of `(C H W)`.
142 | * The unordered syntax with `{ }` is new: you select dimensions by name instead of by position.
143 | * As discussed immediately above, you can use bindings to specify axis objects for names as well as sizes.
144 |
145 | ### Syntax
146 |
147 | Semiformally, the syntax is an `lhs -> rhs` pair, where the `lhs` is either ordered or unordered, and the `rhs` is always ordered.
148 | For the `lhs`:
149 |
150 | * An *ordered lhs* is a sequence of axis variables (e.g. `x`) or (named or anonymous) split axes (e.g. `(x y)`), separated by spaces or commas, and up to one ellipsis
151 | * An *unordered lhs* is a sequence of axis names (e.g. `x`, where `x` is an axis name in the input array) or named split axes (e.g. `(x: y z)`), surrounded by `{}`, separated by spaces or commas.
152 |
153 | * An *axis variable* is an identifier (that need not correspond to an axis name in the input or output.)
154 | * An *axis name* is an identifier that corresponds to an axis name in the input or output.
155 | * An *anonymous split axis* is a parenthesized expression of the form `(ident*)`, e.g. `(N C)`.
156 | * A *named split axis* is a parenthesized expression of the form `(name: ident*)`, where `name` is the name of an axis and the `ident` are axis variable names, e.g. `(N: s mb)`
157 |
158 | A note on "axis variable" vs "axis name": the former is an identifier that can refer to any axis and is matched
159 | by **position** in the input, while the latter is an identifier that refers to a specific axis and is matched by **name** in the input
160 | (or used to name an axis in the output).
161 |
162 | The `rhs` is always ordered. Its syntax is similar to an ordered `lhs`, except that merged axes must always be named and there may be more than one ellipsis.
163 |
164 | * *rhs* is a sequence of axis variable names or named merged axes, separated by spaces or commas, and some number of ellipses.
165 |
166 | * *Named merged axes* are parenthesized expressions of the form `(name: ident ident ...)`, where `name` is an axis name and `ident` is an identifier.
167 | The name is used to name the merged axis in the output, and the `ident` are axis variable names that are merged from the input.
168 |
169 | Identifiers in the `rhs` must be "bound" by an identifier in the `lhs`, that is, they must appear in the `lhs` as an *axis variable name*.
170 |
171 | As in einops, split and merged axes are processed in "C-order": the first listed axis is the most significant, and the
172 | last is the least significant.
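173 | To make the C-order convention concrete, here is a small sketch of a split: the first listed factor varies slowest.
174 |
175 | ```python
176 | import jax.numpy as jnp
177 | import haliax as hax
178 |
179 | a = hax.named(jnp.arange(6), hax.Axis("N", 6))
180 | # split 6 into (step=2) x (microbatch=3); element N=3 lands at step=1, microbatch=0
181 | b = hax.rearrange(a, "(N: step microbatch) -> step microbatch", step=2)
182 | # b.axes == (Axis("step", 2), Axis("microbatch", 3))
183 | ```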
173 |
174 |
175 | ## Error Handling
176 |
177 | `rearrange` attempts to be as helpful as possible when it encounters errors. For example:
178 |
179 | ```python
180 | y = hax.rearrange(x, "N C H W Z -> N H W C")
181 | # ValueError: Error while parsing:
182 | # N C H W Z -> N H W C
183 | # ^
184 | # Too many axes in lhs
185 | ```
186 |
187 | In general, it will try to give you a helpful error message that points to the problem in the string.
188 |
189 |
190 | ## API Documentation
191 |
192 | See [haliax.rearrange][].
193 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs==1.6.1
2 | mkdocs-autorefs==1.4.1
3 | mkdocs-bootswatch==1.1
4 | mkdocs-gen-files==0.5.0
5 | mkdocs-get-deps==0.2.0
6 | mkdocs-include-markdown-plugin==7.1.5
7 | mkdocs-literate-nav==0.6.1
8 | mkdocs-macros-plugin==1.3.7
9 | mkdocs-material==9.6.7
10 | mkdocs-material-extensions==1.3.1
11 | mkdocstrings==0.29.0
12 | mkdocstrings-python==1.16.4
13 |
--------------------------------------------------------------------------------
/docs/tutorial.md:
--------------------------------------------------------------------------------
1 | Haliax's tutorials are hosted on Google Colab. You can run the tutorials in your browser without installing anything.
2 |
3 | Here are our current tutorials:
4 |
5 | {%
6 | include-markdown "../README.md"
7 | start=""
8 | end=""
9 | %}
10 |
--------------------------------------------------------------------------------
/docs/vmap.md:
--------------------------------------------------------------------------------
1 | ## Vectorization
2 |
3 | (This is a work in progress. Please contact dlwh for more information.)
4 |
5 | ::: haliax.vmap
6 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Haliax
2 | repo_url: https://github.com/stanford-crfm/haliax/
3 | edit_uri: blob/main/docs/
4 | theme:
5 | name: material
6 | highlightjs: false
7 | features:
8 | - content.code.copy
9 | markdown_extensions:
10 | - attr_list
11 | - admonition
12 | #- callouts
13 | - footnotes
14 | - codehilite
15 | - pymdownx.details # Allowing hidden expandable regions denoted by ???
16 | - pymdownx.magiclink
17 | - pymdownx.superfences
18 | - pymdownx.arithmatex: # Render LaTeX via MathJax
19 | generic: true
20 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
21 | - pymdownx.snippets: # Include one Markdown file into another
22 | base_path: docs
23 | - pymdownx.inlinehilite
24 | - pymdownx.snippets:
25 | check_paths: true
26 | - pymdownx.superfences
27 | - toc:
28 | permalink: "¤"
29 | toc_depth: "2-3"
30 |
31 | plugins:
32 | - search
33 | - autorefs
34 | - mkdocstrings:
35 | handlers:
36 | python:
37 | # setup_commands:
38 | # - import pytkdocs_tweaks
39 | # - pytkdocs_tweaks.main()
40 | paths: [src]
41 | import:
42 | - https://docs.python.org/3/objects.inv
43 | - https://jax.readthedocs.io/en/latest/objects.inv
44 | - https://docs.kidger.site/equinox/objects.inv
45 | - https://einops.rocks/objects.inv
46 | options:
47 | docstring_options:
48 | ignore_init_summary: true
49 | show_source: false
50 | filters:
51 | - "!^_"
52 | heading_level: 4
53 | inherited_members: true
54 | members_order: source
55 | merge_init_into_class: true
56 | parameter_headings: true
57 | separate_signature: false
58 | load_external_modules: true
59 | preload_modules: [haliax, haliax.core]
60 | show_if_no_docstring: true
61 | show_root_heading: true
62 | show_root_full_path: false
63 | show_signature_annotations: true
64 | docstring_section_style: list
65 | show_symbol_type_heading: true
66 | show_symbol_type_toc: false
67 | signature_crossrefs: true
68 | line_length: 100
69 | summary: true
70 |
71 | - include-markdown
72 | extra_css:
73 | - css/material.css
74 | - css/mkdocstrings.css
75 |
76 |
77 | watch:
78 | - src
79 | - docs
80 | nav:
81 | - Home: 'index.md'
82 | - Tutorials:
83 | - "Introduction to Haliax": https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC
84 | - "Distributed Training and FSDP": https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz
85 | - "Tensor Parallelism": https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi
86 | - "Mixed Precision with `jmp`": https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing
87 | - Cheatsheet: 'cheatsheet.md'
88 | - Named Arrays:
89 | - Broadcasting: 'broadcasting.md'
90 | - Indexing and Slicing: 'indexing.md'
91 | - Rearrange: 'rearrange.md'
92 | - Matrix Multiplication: 'matmul.md'
93 | - Higher Order Functions:
94 | - Scan and Fold: 'scan.md'
95 | - Vectorization: 'vmap.md'
96 | - Neural Networks: 'nn.md'
97 | - Partitioning: 'partitioning.md'
98 | - FP8: 'fp8.md'
99 | - Serialization: 'state-dict.md'
100 | - API Reference: 'api.md'
101 | - FAQ: 'faq.md'
102 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "haliax"
7 | # Set for release builds
8 | #version = "1.3"
9 | authors = [
10 | { name="David Hall", email="dlwh@cs.stanford.edu" },
11 | ]
12 | description = "Named Tensors for Legible Deep Learning in JAX"
13 | readme = "README.md"
14 | requires-python = ">=3.10"
15 | classifiers = [
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: Apache Software License",
18 | "Operating System :: POSIX :: Linux",
19 | "Operating System :: MacOS :: MacOS X",
20 | "Development Status :: 4 - Beta",
21 | "Intended Audience :: Science/Research",
22 | ]
23 | dependencies = [
24 | # we require that you install jax yourself, since the extras vary by system.
25 | # jax = {version = ">=0.4.19,<0.5.0"}
26 | "equinox>=0.10.6",
27 | "jaxtyping>=0.2.20",
28 | "jmp>=0.0.4",
29 | "safetensors>=0.4.3",
30 | "aqtp>=0.8.2",
31 | ]
32 | dynamic =[ "version" ]
33 |
34 | [project.optional-dependencies]
35 | dev=["pytest >= 7.4.0", "mypy >= 0.910", "mkdocs >= 1.4.3", "mkdocs-material >= 7.3.3", "mkdocstrings >= 0.22.0",
36 | "mkdocs-literate-nav >= 0.6.0", "mkdocs-macros-plugin >= 0.7.0", "mkdocstrings-python >= 1.1.2",
37 | "mkdocs-include-markdown-plugin",
38 | "pymdown-extensions",
39 | "pygments",
40 | "pymdown-extensions",
41 | "chex>=0.1.86"
42 | ]
43 |
44 |
45 | [tool.hatch.version]
46 | path = "src/haliax/__about__.py"
47 |
48 | [options]
49 | packages = ["haliax", "haliax.*"]
50 |
51 | [options.package_data]
52 | haliax = ["src/haliax/*"]
53 |
54 | [tool.black]
55 | line-length = 119
56 | target-version = ["py310"]
57 | preview = true
58 |
59 | [tool.isort]
60 | profile = "black"
61 | multi_line_output = 3
62 | lines_after_imports = 2
63 | include_trailing_comma = true
64 | force_grid_wrap = 0
65 | use_parentheses = true
66 | ensure_newline_before_comments = true
67 | line_length = 119
68 | src_paths = ["src", "tests"]
69 |
70 | [project.urls]
71 | "Homepage" = "https://github.com/stanford-crfm/haliax"
72 | "Bug Tracker" = "https://github.com/stanford-crfm/haliax/issues/"
73 | "Documentation" = "https://haliax.readthedocs.io/en/latest/"
74 |
--------------------------------------------------------------------------------
/src/haliax/__about__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.4"
2 |
--------------------------------------------------------------------------------
/src/haliax/_src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/src/haliax/_src/__init__.py
--------------------------------------------------------------------------------
/src/haliax/_src/compile_utils.py:
--------------------------------------------------------------------------------
1 | # This whole file is copied from Equinox.
2 | # (c) 2023, Google LLC. and/or Patrick Kidger. Apache 2.0 licensed.
3 | # Patrick doesn't like that I depend on Equinox internals, so I copied this stuff
4 | import functools as ft
5 | import types
6 | import warnings
7 | import weakref
8 | from typing import Any
9 |
10 | import jax.tree_util as jtu
11 |
12 |
13 | def _strip_wrapped_partial(fun):
14 | if hasattr(fun, "__wrapped__"): # ft.wraps
15 | return _strip_wrapped_partial(fun.__wrapped__)
16 | if isinstance(fun, ft.partial):
17 | return _strip_wrapped_partial(fun.func)
18 | return fun
19 |
20 |
21 | internal_caches = [] # type: ignore
22 | internal_lru_caches = [] # type: ignore
23 |
24 |
25 | def clear_caches():
26 | """Clears internal Equinox caches.
27 |
28 | Best used before calling `jax.clear_caches()` or `jax.clear_backends()`.
29 |
30 | **Arguments:**
31 |
32 | None.
33 |
34 | **Returns:**
35 |
36 | None.
37 | """
38 | for cache in internal_caches:
39 | cache.clear()
40 | for cache in internal_lru_caches:
41 | cache.cache_clear()
42 |
43 |
44 | def _default_cache_key():
45 | assert False
46 |
47 |
48 | def compile_cache(cacheable_fn):
49 | cache = weakref.WeakKeyDictionary() # type: ignore
50 | internal_caches.append(cache)
51 |
52 | def cached_fn_impl(leaves, treedef):
53 | user_fn_names, args, kwargs = jtu.tree_unflatten(treedef, leaves)
54 | return cacheable_fn(user_fn_names, *args, **kwargs)
55 |
56 | @ft.wraps(cacheable_fn)
57 | def wrapped_cacheable_fn(user_fn, *args, **kwargs):
58 | user_fn = _strip_wrapped_partial(user_fn)
59 | # Best-effort attempt to clear the cache of old and unused entries.
60 | cache_key: Any
61 | if type(user_fn) is types.FunctionType: # noqa: E721
62 | cache_key = user_fn
63 | else:
64 | cache_key = _default_cache_key
65 |
66 | try:
67 | user_fn_names = user_fn.__name__, user_fn.__qualname__
68 | except AttributeError:
69 | user_fn_names = type(user_fn).__name__, type(user_fn).__qualname__
70 | leaves, treedef = jtu.tree_flatten((user_fn_names, args, kwargs))
71 | leaves = tuple(leaves)
72 |
73 | try:
74 | cached_fn = cache[cache_key]
75 | except KeyError:
76 | cached_fn = cache[cache_key] = ft.lru_cache(maxsize=None)(cached_fn_impl)
77 | return cached_fn(leaves, treedef)
78 |
79 | def delete(user_fn):
80 | user_fn = _strip_wrapped_partial(user_fn)
81 | if type(user_fn) is types.FunctionType: # noqa: E721
82 | try:
83 | del cache[user_fn]
84 | except KeyError:
85 | warnings.warn(f"Could not delete cache for function {user_fn}. Has it already been deleted?")
86 | else:
87 | warnings.warn("Could not delete non-function from cache.")
88 |
89 | wrapped_cacheable_fn.delete = delete # type: ignore
90 | return wrapped_cacheable_fn
91 |
--------------------------------------------------------------------------------
/src/haliax/_src/dot.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | import typing
3 | import warnings
4 | from typing import Dict, Optional, Tuple
5 |
6 | import jax
7 | import jax.numpy as jnp
8 |
9 | import haliax
10 | from haliax.axis import (
11 | Axis,
12 | AxisSelection,
13 | PartialAxisSpec,
14 | axis_name,
15 | eliminate_axes,
16 | rearrange_for_partial_order,
17 | union_axes,
18 | )
19 | from haliax.core import NamedArray
20 | from haliax.jax_utils import _jittable_dg_einsum
21 | from haliax.types import DTypeLike, PrecisionLike
22 | from haliax.util import ensure_tuple
23 |
24 |
25 | # deprecated overload
26 | @typing.overload
27 | def dot(
28 | axis: Optional[AxisSelection],
29 | *arrays: NamedArray,
30 | precision: PrecisionLike = None,
31 | preferred_element_type: Optional[DTypeLike] = None,
32 | out_axes: Optional[PartialAxisSpec] = ...,
33 | dot_general=jax.lax.dot_general,
34 | ) -> NamedArray:
35 | ...
36 |
37 |
38 | @typing.overload
39 | def dot(
40 | *arrays: NamedArray,
41 | axis: Optional[AxisSelection],
42 | precision: PrecisionLike = None,
43 | preferred_element_type: Optional[DTypeLike] = None,
44 | out_axes: Optional[PartialAxisSpec] = ...,
45 | dot_general=jax.lax.dot_general,
46 | ) -> NamedArray:
47 | ...
48 |
49 |
50 | def dot(
51 | *arrays,
52 | precision: PrecisionLike = None,
53 | preferred_element_type: Optional[DTypeLike] = None,
54 | out_axes: Optional[PartialAxisSpec] = None,
55 | dot_general=jax.lax.dot_general,
56 | **kwargs,
57 | ) -> NamedArray:
58 |     """Returns the tensor product (contraction) of the given NamedArrays. The axes `axis` are contracted over,
59 | and any other axes that are shared between the arrays are batched over. Non-contracted Axes in one
60 | that are not in the other are preserved.
61 |
62 | Note that if `axis` is None, the result will be a scalar, not a NamedArray. The semantics of `axis=None` are
63 | similar to e.g. how `sum` and other reduction functions work in numpy. If `axis=()`, then the result will be
64 | an "outer product" of the arrays, i.e. a tensor with shape equal to the concatenation of the shapes of the arrays.
65 |
66 | By default, the order of output axes is determined by the order of the input axes, such that each output axis
67 | occurs in the same order as it first occurs in the concatenation of the input axes.
68 |
69 | If `out_axes` is provided, the output will be transposed to match the provided axes. `out_axes` may be a partial
70 | specification of the output axes (using ellipses), in which case the output will be rearranged to be consistent
71 | with the partial specification. For example, if `out_axes=(..., Height, Width)` and the output axes are
72 | `(Width, Height, Depth)`, the output will be transposed to `(Depth, Height, Width)`. Multiple ellipses
73 | are supported, in which case axes will be inserted according to a greedy heuristic that prefers to place
74 | unconstrained axes as soon as all prior axes in the "natural" order are covered.
75 |
76 | Args:
77 | *arrays (NamedArray): The arrays to contract.
78 | axis (AxisSelection): The axes to contract over.
79 | precision (PrecisionLike, optional): The precision to use. Defaults to None. This argument is passed to `jax.numpy.einsum`,
80 | which in turn passes it to jax.lax.dot_general.
81 | preferred_element_type (DTypeLike, optional): The preferred element type of the result. Defaults to None.
82 | This argument is passed to `jax.numpy.einsum`.
83 | out_axes (Optional[PartialAxisSpec], optional): a potentially partial specification of the output axes.
84 | If provided, the output will be transposed to match the provided axes. Defaults to None.
85 |
86 |
87 | Returns:
88 | NamedArray: The result of the contraction.
89 | """
90 | if len(arrays) == 0:
91 | raise ValueError("Must provide at least one array to dot")
92 |
93 | if "axis" in kwargs:
94 | axis = kwargs["axis"]
95 | else:
96 | axis = arrays[0]
97 | arrays = arrays[1:]
98 | if isinstance(axis, NamedArray):
99 | raise ValueError("Must provide an axis to dot")
100 |
101 | warnings.warn("Axis has been changed to a keyword argument. Please update your code.", DeprecationWarning)
102 |
103 | _ensure_no_mismatched_axes(*arrays)
104 |
105 | # to call dot_general we need two things:
106 | # list of contractions and list of arrays
107 |
108 | all_axes: Tuple[Axis, ...] = ft.reduce(union_axes, (a.axes for a in arrays), ()) # type: ignore
109 | output_axes: Tuple[Axis, ...]
110 | if axis is None:
111 | # we want to contract over all the axes
112 | output_axes = ()
113 | else:
114 | output_axes = eliminate_axes(all_axes, axis)
115 |
116 | if out_axes is not None:
117 | output_axes = rearrange_for_partial_order(out_axes, output_axes)
118 |
119 | array_specs = []
120 |
121 | next_index = 0
122 | axis_mappings: Dict[str, int] = {}
123 |
124 | for a in arrays:
125 | spec = ""
126 | for ax in a.axes:
127 | if ax.name in axis_mappings:
128 | spec += f"{axis_mappings[ax.name]} "
129 | else:
130 | axis_mappings[ax.name] = next_index
131 | spec += f"{next_index} "
132 | next_index += 1
133 |
134 | array_specs.append(spec)
135 |
136 | # now compute the output axes:
137 | output_spec = " ".join(str(axis_mappings[ax.name]) for ax in output_axes)
138 |
139 | # get a name for jax so it's easier to interpret logs
140 | if axis is None:
141 | jax_str = f"contract {', '.join(axis_name(ax) for ax in all_axes)} -> "
142 | else:
143 | axis = ensure_tuple(axis)
144 | jax_str = f"contract {', '.join(axis_name(ax) for ax in axis)} -> {', '.join(a.name for a in output_axes)}"
145 |
146 | with jax.named_scope(jax_str):
147 | output = _jittable_dg_einsum(
148 | ", ".join(array_specs) + "-> " + output_spec,
149 | *[a.array for a in arrays],
150 | precision=precision,
151 | preferred_element_type=preferred_element_type,
152 | _dot_general=dot_general,
153 | )
154 |
155 | out = NamedArray(output, output_axes)
156 | return haliax.auto_sharded(out)
157 |
158 |
159 | def _ensure_no_mismatched_axes(*arrays: NamedArray):
160 | """Ensure that all the arrays have no axes with the same name but different sizes"""
161 | if len(arrays) <= 1:
162 | return
163 |
164 | known_sizes: dict[str, int] = {}
165 | for a in arrays:
166 | for ax in a.axes:
167 | if ax.name in known_sizes:
168 | if known_sizes[ax.name] != ax.size:
169 | raise ValueError(f"Axis {ax.name} has multiple sizes: {known_sizes[ax.name]} and {ax.size}")
170 | else:
171 | known_sizes[ax.name] = ax.size
172 |
--------------------------------------------------------------------------------
/src/haliax/_src/fp8.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import partial
3 |
4 | from jax import custom_jvp, custom_vjp, lax
5 | from jax import numpy as jnp
6 |
7 |
8 | # All of this is copy paste from flax/linen/fp8_ops.py
9 | # (Until we get to the module)
10 |
11 | # Copyright 2024 The Flax Authors.
12 | #
13 | # Licensed under the Apache License, Version 2.0 (the "License");
14 | # you may not use this file except in compliance with the License.
15 | # You may obtain a copy of the License at
16 | #
17 | # http://www.apache.org/licenses/LICENSE-2.0
18 | #
19 | # Unless required by applicable law or agreed to in writing, software
20 | # distributed under the License is distributed on an "AS IS" BASIS,
21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | # See the License for the specific language governing permissions and
23 | # limitations under the License.
24 |
25 |
26 | def quantize_dequantize(x, q_dtype, scale, compute_dtype):
27 | qx = quantize(x, q_dtype, scale, compute_dtype)
28 | return dequantize(qx, x.dtype, scale)
29 |
30 |
31 | def get_fp8_max(fp8_dtype, out_dtype):
32 | assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
33 | return jnp.finfo(fp8_dtype).max.astype(out_dtype)
34 |
35 |
36 | def quantize(x, q_dtype, scale, compute_dtype):
37 | # Explicitly cast the max values to the compute dtype to avoid unnecessary
38 | # casting to FP32 during the subsequent math operations."
39 | dtype_max = get_fp8_max(q_dtype, compute_dtype)
40 | scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)
41 | clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
42 | return clipped_x.astype(q_dtype)
43 |
44 |
45 | def dequantize(x, dq_dtype, scale):
46 | return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)
47 |
48 |
49 | def compute_scale(amax, scale, fp8_max, margin=0):
50 | # The algorithm for computing the new scale is sourced from
51 | # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas
52 | # wherein the `original_scale` corresponds to the reciprocal of the `scale`
53 | # passed in this function.
54 | scale = 1.0 / scale
55 |
56 | sf = (fp8_max / amax) / (2**margin)
57 | sf = jnp.where(amax > 0.0, sf, scale)
58 | sf = jnp.where(jnp.isfinite(amax), sf, scale)
59 |
60 | return 1.0 / sf
61 |
62 |
63 | def compute_amax_history(x, amax_history):
64 | amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
65 | new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
66 | return new_history
67 |
68 |
69 | def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
70 | dtype_max = get_fp8_max(q_dtype, jnp.float32)
71 | amax_from_history = jnp.max(amax_history, axis=0)
72 | new_scale = compute_scale(amax_from_history, scale, dtype_max)
73 |
74 | qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
75 |
76 | new_history = compute_amax_history(x, amax_history)
77 |
78 | return qx, new_scale, new_history
79 |
80 |
81 | @partial(custom_vjp, nondiff_argnums=(0,))
82 | def in_qdq(compute_dtype, inp, scale, amax_history):
83 | qin, _, _ = qdq_and_return(inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype)
84 | return qin
85 |
86 |
87 | def in_qdq_fwd(compute_dtype, inp, scale, amax_history):
88 | qin, new_scale, new_history = qdq_and_return(inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype)
89 | return qin, (new_scale, new_history)
90 |
91 |
92 | def in_qdq_bwd(compute_dtype, res, g):
93 | new_scale, new_history = res
94 | q_g = g
95 | return q_g, new_scale, new_history
96 |
97 |
98 | in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd)
99 |
100 |
101 | @partial(custom_vjp, nondiff_argnums=(0,))
102 | def out_qdq(compute_dtype, out, scale, amax_history):
103 | return out
104 |
105 |
106 | def out_qdq_fwd(compute_dtype, out, scale, amax_history):
107 | return out, (scale, amax_history)
108 |
109 |
110 | def out_qdq_bwd(compute_dtype, res, g):
111 | scale, amax_history = res
112 | q_g, new_scale, new_history = qdq_and_return(g, jnp.float8_e5m2, scale, amax_history, compute_dtype)
113 | return q_g, new_scale, new_history
114 |
115 |
116 | out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
117 |
118 |
119 | @partial(custom_jvp, nondiff_argnums=(2, 3, 4))
120 | def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
121 | if precision is not None or preferred_element_type is not None:
122 | # einsum sets preferred_element_type and so this is just noisy
123 | # warnings.warn(
124 | # "The function dot_general_with_precision will set the "
125 | # "precision/preferred_element_type and disregard any provided "
126 | # "values."
127 | # )
128 | pass
129 | return lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT)
130 |
131 |
132 | @dot_general_with_precision.defjvp
133 | def dot_general_with_precision_jvp(dimension_numbers, precision, preferred_element_type, primals, tangents):
134 | lhs, rhs = primals
135 | lhs_dot, rhs_dot = tangents
136 |
137 | out = lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT)
138 | grad_out = lax.dot_general(lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST) + lax.dot_general(
139 | lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
140 | )
141 | return out, grad_out
142 |
--------------------------------------------------------------------------------
/src/haliax/_src/util.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, MutableMapping, Sequence, TypeAlias, TypeVar
2 |
3 |
4 | T = TypeVar("T")
5 | U = TypeVar("U")
6 | py_slice = slice
7 | slice_t: TypeAlias = slice
8 |
9 |
10 | def index_where(pred: Callable[[T], bool], xs: Sequence[T], start: int = 0) -> int:
11 | for i in range(start, len(xs)):
12 | if pred(xs[i]):
13 | return i
14 | raise ValueError("No element satisfies predicate")
15 |
16 |
17 | class IdentityMap(MutableMapping[T, U]):
18 | """Map that compares keys by identity.
19 |
20 | This is a map that compares keys by identity instead of equality. It is
21 | useful for storing objects that are not hashable or that should be compared
22 | by identity.
23 |
24 | This is a mutable mapping, but it does not support the ``__hash__`` method
25 | and therefore cannot be used as a dictionary key or as an element of another
26 | set.
27 | """
28 |
29 | def __init__(self, iterable=None):
30 | self._data = {}
31 | if iterable is not None:
32 | self.update(iterable)
33 |
34 | def __contains__(self, key):
35 | return id(key) in self._data
36 |
37 | def __getitem__(self, key):
38 | return self._data[id(key)][1]
39 |
40 | def __setitem__(self, key, value):
41 | self._data[id(key)] = [key, value]
42 |
43 | def __delitem__(self, key):
44 | del self._data[id(key)]
45 |
46 | def __iter__(self):
47 | return (x[0] for x in self._data.values())
48 |
49 | def __len__(self):
50 | return len(self._data)
51 |
52 | def __repr__(self):
53 | return f"IdentityMap({list(repr(x) for x in self._data.values())})"
54 |
55 | def __str__(self):
56 | return f"IdentityMap({list(str(x) for x in self._data.values())})"
57 |
--------------------------------------------------------------------------------
/src/haliax/debug.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import List, Tuple, Union
3 |
4 | import equinox as eqx
5 | import jax.numpy as jnp
6 | import jax.tree_util as jtu
7 |
8 | from haliax.core import NamedArray
9 | from haliax.util import is_jax_or_hax_array_like
10 |
11 | from ._src.util import IdentityMap
12 |
13 |
14 | ArrayLike = Union[jnp.ndarray, NamedArray]
15 |
16 |
17 | def describe_array(arr):
18 | if isinstance(arr, NamedArray):
19 | return f"NamedArray(axes={arr.axes}, dtype={arr.dtype})"
20 | else:
21 | return f"ndarray(shape={arr.shape}, dtype={arr.dtype})"
22 |
23 |
24 | class ModuleProblems(Exception):
25 | def __init__(self):
26 | self.reused_arrays: List[Tuple[ArrayLike, List]] = []
27 | self.static_arrays: List[str] = []
28 |
29 | def __bool__(self):
30 | return bool(self.reused_arrays or self.static_arrays)
31 |
32 | def __str__(self):
33 | if not self:
34 | return "No problems found"
35 | else:
36 | return "\n".join(
37 | [
38 | "Found some problems with your module:",
39 | *self._format_reused_arrays(),
40 | *self._format_static_arrays(),
41 | ]
42 | )
43 |
44 | def _format_reused_arrays(self):
45 | return [f" Reused array {describe_array(arr)} at paths {paths}" for arr, paths in self.reused_arrays]
46 |
47 | def _format_static_arrays(self):
48 | return [f" Static array at field {field}" for field in self.static_arrays]
49 |
50 |
51 | def diagnose_common_issues(module: eqx.Module):
52 | """
53 | Checks for common issues in a module, such as reused arrays and static arrays.
54 | Equinox modules (and therefore Haliax modules) should not have arrays that are stored
55 | in multiple places, and should not have arrays stored as static fields.
56 |
57 | We'll add more checks here as we find them.
58 |
59 | Args:
60 | module: The module to check for problems
61 |
62 | Returns:
63 | None
64 |
65 | Raises:
66 | ModuleProblems: if any problems are found
67 |
68 | """
69 |
70 | problems = ModuleProblems()
71 | _check_for_reused_arrays(problems, module)
72 | _check_for_static_arrays(problems, module)
73 |
74 | if problems:
75 | raise problems
76 |
77 | # just in case we missed anything, raise equinox's errors:
78 | eqx.tree_check(module)
79 |
80 |
81 | def _check_for_reused_arrays(problems, module):
82 | used_arrays = IdentityMap[ArrayLike, List[str]]()
83 |
84 | path_leaves, _ = jtu.tree_flatten_with_path(module, is_leaf=is_jax_or_hax_array_like)
85 |
86 | for path, leaf in path_leaves:
87 | if is_jax_or_hax_array_like(leaf):
88 | if leaf in used_arrays:
89 | used_arrays[leaf].append(jtu.keystr(path))
90 | else:
91 | used_arrays[leaf] = [jtu.keystr(path)]
92 |
93 | for arr, paths in used_arrays.items():
94 | if len(paths) > 1:
95 | problems.reused_arrays.append((arr, paths))
96 |
97 |
98 | def _check_for_static_arrays(problems, module):
99 | static_arrays = []
100 |
101 | def recurse(module, path):
102 | if isinstance(module, eqx.Module):
103 | for field in dataclasses.fields(module):
104 | value = getattr(module, field.name)
105 | if field.metadata.get("static", False) and is_jax_or_hax_array_like(value):
106 | static_arrays.append(f"{path}.{field.name}")
107 | else:
108 | recurse(value, f"{path}.{field.name}")
109 | else:
110 | leaves, _ = eqx.tree_flatten_one_level(module)
111 | if leaves != [module]:
112 | leaves_with_names = jtu.tree_leaves_with_path(module, is_leaf=lambda x: x in leaves)
113 | for name, leaf in leaves_with_names:
114 | recurse(leaf, f"{path}{name}")
115 |
116 | recurse(module, "")
117 |
118 | if static_arrays:
119 | problems.static_arrays.extend(static_arrays)
120 |
--------------------------------------------------------------------------------
/src/haliax/hof.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from functools import wraps
3 |
4 | import equinox as eqx
5 | import jax
6 | from jaxtyping import PyTree
7 |
8 | import haliax.tree_util as htu
9 |
10 | from ._src.scan import (
11 | UnnamedAxisSpec,
12 | _infer_axis_size_from_tree,
13 | _is_passive_array,
14 | _pacify_named_arrays,
15 | _PassiveNamedArray,
16 | _prepend_named_batch_axis,
17 | _zero_if_array_else_none,
18 | fold,
19 | map,
20 | scan,
21 | )
22 | from .axis import Axis, AxisSelector, selects_axis
23 | from .core import NamedArray
24 | from .jax_utils import Static, broadcast_prefix, is_jax_array_like
25 | from .partitioning import physical_axis_name
26 | from .util import is_named_array
27 |
28 |
29 | def vmap(
30 | fn,
31 | axis: AxisSelector,
32 | *,
33 | default: PyTree[UnnamedAxisSpec] = _zero_if_array_else_none,
34 | args: PyTree[UnnamedAxisSpec] = (),
35 | kwargs: PyTree[UnnamedAxisSpec] = None,
36 | ):
37 | """
38 | [haliax.NamedArray][]-aware version of [jax.vmap][]. Normal arrays are mapped according to the specs as in
39 | [equinox.filter_vmap][]
40 |
41 | Because of NamedArrays, vmap is typically less useful than in vanilla JAX, but it is sometimes
42 | useful for initializing modules that will be scanned over. See [haliax.nn.Stacked][] for an example.
43 |
44 | Args:
45 | fn (Callable): function to vmap over
46 | axis (Axis): axis to vmap over
47 | default: how to handle (unnamed) arrays by default. Should be either an integer or None, or a callable that takes a PyTree leaf
48 | and returns an integer or None, or a PyTree prefix of the same. If an integer, the array will be mapped over that axis. If None, the array will not be mapped over.
49 | args: optional per-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default.
50 | kwargs: optional per-keyword-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default.
51 | """
52 |
53 | if kwargs is None:
54 | kwargs = {}
55 |
56 | signature = inspect.signature(fn)
57 |
58 | # this mirrors equinox's filter_vmap, but it's not really documented there so:
59 | # we use inspect.signature to align args/kwargs specified in vmap to what actually gets passed in
60 | # axis_spec_bound_sig's job is to hold that mapping
61 | signature_default = signature.replace(
62 | parameters=[
63 | p
64 | if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
65 | else p.replace(default=default)
66 | for p in signature.parameters.values()
67 | ]
68 | )
69 | axis_spec_bound_sig = signature_default.bind_partial(*args, **kwargs)
70 | axis_spec_bound_sig.apply_defaults()
71 | del args, kwargs
72 |
73 | def _index_of_batch_axis(array, default):
74 | if isinstance(array, NamedArray):
75 | return array._lookup_indices(axis)
76 | elif callable(default):
77 | return default(array)
78 | else:
79 | return default
80 |
81 | # TODO: tests to exercise this more
82 | @wraps(fn)
83 | def wrapped_vmap_fn(*args, **kwargs):
84 | # TODO: this probably results in a lot of compilation misses. Need to think about it.
85 | actual_bound = signature.bind(*args, **kwargs)
86 | actual_bound.apply_defaults()
87 |
88 | # now that we have args, we can figure out what the axis spec is for each arg
89 | padded_spec_args = axis_spec_bound_sig.args + (default,) * (
90 | len(actual_bound.args) - len(axis_spec_bound_sig.args)
91 | )
92 |
93 | padded_spec_kwargs = {
94 | **axis_spec_bound_sig.kwargs,
95 | **{k: default for k in actual_bound.kwargs.keys() - axis_spec_bound_sig.kwargs.keys()},
96 | }
97 |
98 | # want to support padded_spec_args being a tree prefix of the actual args, which this enables
99 | padded_spec_args = broadcast_prefix(padded_spec_args, actual_bound.args, is_leaf=is_named_array)
100 | padded_spec_kwargs = broadcast_prefix(padded_spec_kwargs, actual_bound.kwargs)
101 |
102 | arg_axis_specs = htu.tree_map(_index_of_batch_axis, actual_bound.args, padded_spec_args)
103 |
104 | kwarg_axis_specs = htu.tree_map(_index_of_batch_axis, actual_bound.kwargs, padded_spec_kwargs)
105 |
106 |         # now we can actually vmap. We use "pacified" versions of NamedArrays that don't check
107 |         # invariants, because intermediates created during tracing won't have the right axes
108 | arg_axis_specs = htu.tree_map(_pacify_named_arrays, arg_axis_specs)
109 | kwarg_axis_specs = htu.tree_map(_pacify_named_arrays, kwarg_axis_specs)
110 |
111 | def wrapped_fn(args, kwargs):
112 | # the args that come in here are pacified. Their names will still have the batch axis even though the array
113 | # itself will already have that one removed. We need to turn them back into NamedArrays by removing the axis
114 | unchilled_args = jax.tree_util.tree_map(_to_unbatched_named_array(axis), args, is_leaf=_is_passive_array)
115 | unchilled_kwargs = jax.tree_util.tree_map(
116 | _to_unbatched_named_array(axis), kwargs, is_leaf=_is_passive_array
117 | )
118 |
119 | out = fn(*unchilled_args, **unchilled_kwargs)
120 |
121 | # now we need to pacify the output, which may include NamedArrays, and add the batch axis back at the end
122 | chilled = htu.tree_map(_pacify_named_arrays, out)
123 | arrays, nonarrays = eqx.partition(chilled, is_jax_array_like)
124 | return arrays, Static(nonarrays)
125 |
126 | spmd_axis_name = physical_axis_name(axis)
127 |
128 | args = htu.tree_map(_pacify_named_arrays, actual_bound.args)
129 | kwargs = htu.tree_map(_pacify_named_arrays, actual_bound.kwargs)
130 |
131 | result_dynamic, result_static = jax.vmap(
132 | wrapped_fn,
133 | in_axes=(arg_axis_specs, kwarg_axis_specs),
134 | out_axes=0,
135 | axis_size=axis.size if isinstance(axis, Axis) else None,
136 | spmd_axis_name=spmd_axis_name,
137 | )(args, kwargs)
138 |
139 | result = eqx.combine(result_dynamic, result_static.value)
140 |
141 | # if we were passed in a string arg, we need to get its axis size out from some result
142 | true_axis = _infer_axis_size_from_tree(result, axis)
143 | if true_axis is None:
144 | raise ValueError("vmap failed to infer axis size from result")
145 |
146 | result = jax.tree_util.tree_map(_prepend_named_batch_axis(true_axis), result, is_leaf=_is_passive_array)
147 | return result
148 |
149 | return wrapped_vmap_fn
150 |
151 |
152 | def _to_unbatched_named_array(axis_to_strip: AxisSelector):
153 | def to_unbatched_named_array(leaf):
154 | if isinstance(leaf, _PassiveNamedArray):
155 | if selects_axis(leaf.main_axes, axis_to_strip):
156 | return leaf.strip_axis(axis_to_strip)
157 | else:
158 | return leaf.to_named_array()
159 | else:
160 | return leaf
161 |
162 | return to_unbatched_named_array
163 |
164 |
165 | __all__ = ["scan", "fold", "vmap", "map"]
166 |
--------------------------------------------------------------------------------
/src/haliax/nn/__init__.py:
--------------------------------------------------------------------------------
1 | import jax.nn as jnn
2 | import jax.numpy as jnp
3 |
4 | import haliax
5 | import haliax as hax
6 | import haliax.nn.activations
7 | import haliax.nn.attention as attention
8 | import haliax.nn.normalization
9 |
10 | from ..axis import Axis
11 | from ..core import NamedArray
12 | from .activations import (
13 | celu,
14 | elu,
15 | gelu,
16 | glu,
17 | hard_sigmoid,
18 | hard_silu,
19 | hard_swish,
20 | hard_tanh,
21 | leaky_relu,
22 | log_sigmoid,
23 | quick_gelu,
24 | relu,
25 | relu6,
26 | selu,
27 | sigmoid,
28 | silu,
29 | soft_sign,
30 | softplus,
31 | swish,
32 | )
33 | from .conv import Conv, ConvTranspose
34 | from .dropout import Dropout, dropout
35 | from .embedding import Embedding
36 | from .linear import Linear, MoELinear
37 | from .loss import binary_cross_entropy_loss, cross_entropy_loss, cross_entropy_loss_and_log_normalizers, reduce_loss
38 | from .mlp import MLP
39 | from .normalization import LayerNorm, RmsNorm, log_softmax, logsumexp, softmax, standardize
40 | from .pool import max_pool, mean_pool, min_pool
41 | from .scan import BlockSeq, ScanCheckpointPolicy, Stacked
42 |
43 |
44 | def one_hot(x: NamedArray | int, class_axis: Axis, *, dtype=None) -> NamedArray:
45 | """
46 | Convert an integer to a one-hot vector. This is basically a generalization of [jax.nn.one_hot][]
47 | for NamedArrays.
48 |
49 | Args:
50 | x: the integer or NamedArray of integers to convert
51 | class_axis: the axis to convert to one-hot
52 | dtype: the dtype of the result. If None, it will default to jax's default (currently float_)
53 | Returns:
54 | a NamedArray with the same axes as `x` plus `class_axis`, with 1s in the appropriate places
55 | """
56 | if isinstance(x, NamedArray):
57 | array = jnn.one_hot(x.array, num_classes=class_axis.size, dtype=dtype)
58 | # Disabling this to prevent a crash in XLA on GPU
59 | # return hax.auto_sharded(hax.named(array, x.axes + (class_axis,)))
60 | return hax.named(array, x.axes + (class_axis,))
61 | else:
62 | assert isinstance(x, int)
63 | assert class_axis.size > x >= -class_axis.size
64 |
65 | one = 1
66 | if dtype is not None:
67 | one = dtype(one)
68 |
69 | array = jnp.zeros(class_axis.size, dtype=dtype).at[x].set(one)
70 | return hax.auto_sharded(haliax.named(array, class_axis))
71 |
72 |
73 | __all__ = [
74 | "attention",
75 | "one_hot",
76 | "binary_cross_entropy_loss",
77 | "reduce_loss",
78 | "cross_entropy_loss",
79 | "cross_entropy_loss_and_log_normalizers",
80 | "Conv",
81 | "ConvTranspose",
82 | "Dropout",
83 | "dropout",
84 | "LayerNorm",
85 | "Linear",
86 | "MoELinear",
87 | "Embedding",
88 | "RmsNorm",
89 | "Stacked",
90 | "BlockSeq",
91 | "MLP",
92 | "relu",
93 | "gelu",
94 | "quick_gelu",
95 | "glu",
96 | "relu6",
97 | "sigmoid",
98 | "soft_sign",
99 | "softplus",
100 | "swish",
101 | "silu",
102 | "log_sigmoid",
103 | "leaky_relu",
104 | "hard_sigmoid",
105 | "hard_silu",
106 | "hard_swish",
107 | "hard_tanh",
108 | "logsumexp",
109 | "softmax",
110 | "log_softmax",
111 | "standardize",
112 | "elu",
113 | "celu",
114 | "selu",
115 | "max_pool",
116 | "mean_pool",
117 | "min_pool",
118 | "ScanCheckpointPolicy",
119 | ]
120 |
--------------------------------------------------------------------------------
/src/haliax/nn/activations.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from jax import nn as jnn
4 | from jax import numpy as jnp
5 |
6 | from ..axis import Axis
7 | from ..core import NamedArray
8 | from ..types import Scalar
9 | from ..wrap import wrap_elemwise_unary
10 |
11 |
12 | A = typing.TypeVar("A", Scalar, NamedArray, jnp.ndarray)
13 |
14 |
15 | def relu(a: A) -> A:
16 | return wrap_elemwise_unary(jnn.relu, a)
17 |
18 |
19 | def relu6(a: A) -> A:
20 | return wrap_elemwise_unary(jnn.relu6, a)
21 |
22 |
23 | def sigmoid(a: A) -> A:
24 | return wrap_elemwise_unary(jnn.sigmoid, a)
25 |
26 |
27 | def softplus(a: A) -> A:
28 | return wrap_elemwise_unary(jnn.softplus, a)
29 |
30 |
31 | def soft_sign(a: A) -> A:
32 | return wrap_elemwise_unary(jnn.soft_sign, a)
33 |
34 |
35 | def silu(a: A) -> A:
36 | return wrap_elemwise_unary(jnn.silu, a)
37 |
38 |
39 | def swish(a: A) -> A:
40 | return wrap_elemwise_unary(jnn.swish, a)
41 |
42 |
43 | def log_sigmoid(a: A) -> A:
44 | return wrap_elemwise_unary(jnn.log_sigmoid, a)
45 |
46 |
47 | def leaky_relu(a: A) -> A:
48 | return wrap_elemwise_unary(jnn.leaky_relu, a)
49 |
50 |
51 | def hard_sigmoid(a: A) -> A:
52 | return wrap_elemwise_unary(jnn.hard_sigmoid, a)
53 |
54 |
55 | def hard_silu(a: A) -> A:
56 | return wrap_elemwise_unary(jnn.hard_silu, a)
57 |
58 |
59 | def hard_swish(a: A) -> A:
60 | return wrap_elemwise_unary(jnn.hard_swish, a)
61 |
62 |
63 | def hard_tanh(a: A) -> A:
64 | return wrap_elemwise_unary(jnn.hard_tanh, a)
65 |
66 |
67 | def elu(a: A) -> A:
68 | return wrap_elemwise_unary(jnn.elu, a)
69 |
70 |
71 | def celu(a: A) -> A:
72 | return wrap_elemwise_unary(jnn.celu, a)
73 |
74 |
75 | def selu(a: A) -> A:
76 | return wrap_elemwise_unary(jnn.selu, a)
77 |
78 |
79 | def gelu(a: A, approximate: bool = True) -> A:
80 | return wrap_elemwise_unary(jnn.gelu, a, approximate=approximate)
81 |
82 |
83 | def glu(x: NamedArray, axis: Axis) -> NamedArray:
84 |     axis_index = x.axes.index(axis)
85 |     raw = jnn.glu(x.array, axis_index)
86 |     # jax.nn.glu halves the gated axis, so rebuild the named axes to match before rewrapping
87 |     new_axes = tuple(Axis(ax.name, ax.size // 2) if i == axis_index else ax for i, ax in enumerate(x.axes))
88 |     return NamedArray(raw, new_axes)
86 |
87 |
88 | def quick_gelu(x):
89 | return x * sigmoid(1.702 * x)
90 |
--------------------------------------------------------------------------------
/src/haliax/nn/dropout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import equinox as eqx
4 | import jax
5 | from jaxtyping import PRNGKeyArray
6 |
7 | import haliax
8 | from haliax.axis import AxisSpec
9 | from haliax.core import NamedArray
10 | from haliax.util import ensure_tuple
11 |
12 |
13 | def dropout(x, pdrop, broadcast_axes=None, *, inference, key=None):
14 | """Applies dropout.
15 |
16 | **Arguments:**
17 |
18 | - `x`: An any-dimensional JAX array to dropout.
19 | - `pdrop`: The fraction of entries to set to zero.
20 | - `broadcast_axes`: The dimensions to broadcast the dropout mask over. If set, these axes will share the same mask
21 | - `key`: A `jax.random.PRNGKey` used to provide randomness for calculating
22 | which elements to dropout. (Keyword only argument.)
23 |     - `inference`: If `True`, dropout is a no-op and `x` is returned unchanged. If falsy,
24 |       dropout is applied and a `key` is required. Unlike the `Dropout` module below, this is a
25 |       required keyword argument: there is no module-level `inference` to fall back on.
26 | """
27 | if inference:
28 | return x
29 | elif isinstance(pdrop, (int, float)) and pdrop == 0:
30 | return x
31 | elif isinstance(pdrop, (int, float)) and pdrop == 1:
32 | return haliax.zeros_like(x)
33 | elif key is None:
34 | raise RuntimeError("Dropout requires a key when running in non-deterministic mode.")
35 | else:
36 | with jax.named_scope(name="dropout"):
37 | if broadcast_axes is None:
38 | if isinstance(x, NamedArray):
39 | shape_to_generate = x.axes
40 | else:
41 | shape_to_generate = x.shape
42 | else:
43 | axes = ensure_tuple(broadcast_axes)
44 | shape_to_generate = tuple(ax for ax in x.axes if ax not in axes)
45 |
46 | q = 1 - pdrop
47 | mask: NamedArray = haliax.random.bernoulli(key, shape_to_generate, q)
48 | q = x.dtype.type(q)
49 |
50 | out = haliax.where(mask, x / q, 0)
51 | assert out.dtype == x.dtype
52 | return out
53 |
54 |
55 | class Dropout(eqx.Module):
56 | """Applies dropout.
57 |
58 | Attributes:
59 | pdrop: The fraction of entries to set to zero.
60 | broadcast_axes: The axes to broadcast the dropout mask over. If set, these axes share the same mask.
61 | """
62 |
63 | # key difference from equinox: these are static fields
64 | pdrop: float = eqx.static_field()
65 | broadcast_axes: Optional[AxisSpec] = eqx.static_field()
66 | inference: bool = False # note: not static
67 |
68 | def __init__(
69 | self,
70 | pdrop: float = 0.5,
71 | broadcast_axes: Optional[AxisSpec] = None,
72 | inference: bool = False,
73 | ):
74 | self.pdrop = pdrop
75 | self.broadcast_axes = broadcast_axes
76 | self.inference = inference
77 |
78 | @property
79 | def is_active(self):
80 | """Returns `True` if dropout is active (and therefore needs a key), `False` otherwise."""
81 | return not self.inference and self.pdrop > 0
82 |
83 | def __call__(
84 | self,
85 | x: NamedArray,
86 | *,
87 | inference: Optional[bool] = None,
88 | key: Optional[PRNGKeyArray] = None,
89 | ) -> NamedArray:
90 | """**Arguments:**
91 |
92 | - `x`: The `NamedArray` to apply dropout to.
93 | - `key`: A `jax.random.PRNGKey` used to provide randomness for deciding
94 | which elements to drop. (Keyword-only argument.)
95 | - `inference`: As per [`equinox.nn.Dropout.__init__`][]. If `True` or
96 | `False` then it will take priority over `self.inference`. If `None`
97 | then the value from `self.inference` will be used.
98 | """
99 | if inference is None:
100 | inference = self.inference
101 |
102 | return dropout(
103 | x,
104 | self.pdrop,
105 | broadcast_axes=self.broadcast_axes,
106 | inference=inference,
107 | key=key,
108 | )
109 |
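A minimal usage sketch for the functional `dropout` and the `Dropout` module above (not from the repository; the axis names are illustrative). A PRNG key is only needed when dropout is actually active:

    import jax
    import haliax as hax
    from haliax.nn.dropout import Dropout, dropout

    Batch, Embed = hax.make_axes(Batch=4, Embed=8)  # illustrative axes
    x = hax.random.normal(jax.random.PRNGKey(0), (Batch, Embed))

    drop = Dropout(pdrop=0.1)
    y_train = drop(x, key=jax.random.PRNGKey(1))  # active: a key is required
    y_eval = drop(x, inference=True)              # no-op: returns x unchanged

    # functional form; broadcast_axes=Batch means every batch element shares one mask
    y_shared = dropout(x, 0.1, broadcast_axes=Batch, inference=False, key=jax.random.PRNGKey(2))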
--------------------------------------------------------------------------------
/src/haliax/nn/embedding.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import math
3 | import warnings
4 | from typing import Optional
5 |
6 | import equinox as eqx
7 | from jaxtyping import PRNGKeyArray
8 |
9 | import haliax as hax
10 |
11 | from ..axis import Axis, AxisSpec
12 | from ..core import NamedArray
13 | from ..jax_utils import named_call
14 | from ..tree_util import resize_axis
15 | from ..util import ensure_tuple
16 |
17 |
18 | class Embedding(eqx.Module):
19 | weight: NamedArray
20 |
21 | # axes
22 | Vocab: Axis = eqx.static_field()
23 | Embed: AxisSpec = eqx.static_field()
24 |
25 | @staticmethod
26 | def init(Vocab: Axis, Embed: AxisSpec, *, init_scale: float = 1, key, initializer_range: Optional[float] = None):
27 | """
28 | Initialize an Embedding module.
29 |
30 | An embedding module is a simple lookup table that maps integer indices to vectors or tensors.
31 | Weights are initialized with a truncated normal distribution with a standard deviation of
32 | `init_scale / output_size`.
33 |
34 | Args:
35 | Vocab: Size of the vocabulary
36 | Embed: Shape of the embedding vectors. May be a single axis or a full AxisSpec
37 | init_scale: Scale of the initialization
38 | key: PRNG key
39 | initializer_range: Deprecated. Use init_scale instead.
40 | """
41 | if initializer_range is not None:
42 | warnings.warn("initializer_range is deprecated. Use init_scale instead.", DeprecationWarning)
43 | init_scale = initializer_range
44 |
45 | all_axes = (Vocab,) + ensure_tuple(Embed)
46 | output_size = hax.axis_size(Embed)
47 | weight = hax.random.truncated_normal(key, all_axes, -3, 3) * (init_scale / output_size)
48 | return Embedding(weight=weight, Vocab=Vocab, Embed=Embed)
49 |
50 | def __call__(self, input_ids: NamedArray, *, key: Optional[PRNGKeyArray] = None):
51 | """Alias for `embed`. key is ignored."""
52 | return self.embed(input_ids)
53 |
54 | @named_call
55 | def embed(self, input_ids: NamedArray):
56 | """
57 | Args:
58 | input_ids: integer token IDs with any axes (not including Vocab); values must be in [0, Vocab.size)
59 | """
60 | input_embeds = self.weight.take(self.Vocab, input_ids)
61 | return input_embeds
62 |
63 | def unembed(self, input_embeds: NamedArray):
64 | """
65 | Unembed the input embeddings back to the vocabulary space.
66 |
67 | Equivalent to `input_embeds.dot(self.weight, axis=self.Embed)`.
68 | """
69 | return input_embeds.dot(self.weight, axis=self.Embed)
70 |
71 | def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None):
72 | """
73 | Resize the embedding layer to a new size.
74 | Args:
75 | new_size: New size of the vocabulary
76 | key: PRNG key for initialization of any new weights
77 |
78 | Returns:
79 | Embedding: Resized embedding layer
80 |
81 | """
82 | new_weights = resize_axis(self.weight, self.Vocab, new_size, key=key)
83 | return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), weight=new_weights) # type: ignore
84 |
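A minimal usage sketch for `Embedding` (not from the repository; `Vocab`, `Embed`, and `Pos` are illustrative axis names):

    import jax
    import haliax as hax
    from haliax.nn.embedding import Embedding

    Vocab, Embed, Pos = hax.make_axes(Vocab=100, Embed=16, Pos=8)  # illustrative axes
    emb = Embedding.init(Vocab, Embed, key=jax.random.PRNGKey(0))

    token_ids = hax.arange(Pos)     # integer ids 0..7, all < Vocab.size
    vectors = emb.embed(token_ids)  # axes: (Pos, Embed)
    logits = emb.unembed(vectors)   # axes: (Pos, Vocab)

    bigger = emb.resize_embeddings(128, key=jax.random.PRNGKey(1))  # grow Vocab to 128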
--------------------------------------------------------------------------------
/src/haliax/nn/loss.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import warnings
3 | from typing import Optional
4 |
5 | from jax import numpy as jnp
6 |
7 | import haliax as hax
8 | from haliax.axis import AxisSelection, AxisSelector
9 | from haliax.core import NamedArray
10 | from haliax.util import UNSPECIFIED, Unspecified
11 | from haliax.wrap import ReductionFunction
12 |
13 |
14 | @typing.overload
15 | def cross_entropy_loss(
16 | logits: NamedArray,
17 | Label: AxisSelector,
18 | targets: NamedArray,
19 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
20 | where: Optional[NamedArray] = None,
21 | reduction_axis: None = None,
22 | ) -> jnp.ndarray | NamedArray:
23 | ...
24 |
25 |
26 | @typing.overload
27 | def cross_entropy_loss(
28 | logits: NamedArray,
29 | Label: AxisSelector,
30 | targets: NamedArray,
31 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
32 | where: Optional[NamedArray] = None,
33 | reduction_axis: AxisSelection = ...,
34 | ) -> NamedArray:
35 | ...
36 |
37 |
38 | def cross_entropy_loss(
39 | logits: NamedArray,
40 | Label: AxisSelector,
41 | targets: NamedArray,
42 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
43 | where: Optional[NamedArray] = None,
44 | reduction_axis: Optional[AxisSelection] = None,
45 | ) -> jnp.ndarray | NamedArray:
46 | loss, _ = cross_entropy_loss_and_log_normalizers(logits, Label, targets)
47 |
48 | # if targets isn't some kind of floating point (e.g. class indices were passed instead of a distribution), warn
49 | if not jnp.issubdtype(targets.dtype, jnp.floating):
50 | warnings.warn(
51 | f"targets has dtype {targets.dtype}, which is not a floating point type. This is probably a mistake."
52 | )
53 |
54 | loss = maybe_reduce_loss(loss, reduction, reduction_axis, where)
55 |
56 | return loss
57 |
58 |
59 | @typing.overload
60 | def binary_cross_entropy_loss(
61 | logits: NamedArray,
62 | targets: NamedArray,
63 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
64 | where: Optional[NamedArray] = None,
65 | reduction_axis: None = None,
66 | ) -> jnp.ndarray | NamedArray:
67 | ...
68 |
69 |
70 | @typing.overload
71 | def binary_cross_entropy_loss(
72 | logits: NamedArray,
73 | targets: NamedArray,
74 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
75 | where: Optional[NamedArray] = None,
76 | reduction_axis: AxisSelection = ...,
77 | ) -> NamedArray:
78 | ...
79 |
80 |
81 | def binary_cross_entropy_loss(
82 | logits: NamedArray,
83 | targets: NamedArray,
84 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
85 | where: Optional[NamedArray] = None,
86 | reduction_axis: Optional[AxisSelection] = None,
87 | ) -> jnp.ndarray | NamedArray:
88 | log_p = hax.nn.log_sigmoid(logits)
89 | log_not_p = hax.nn.log_sigmoid(-logits) # == log(1-sigmoid(x))
90 | targets = targets.astype(logits.dtype)
91 | loss = -targets * log_p - (1.0 - targets) * log_not_p
92 |
93 | loss = maybe_reduce_loss(loss, reduction, reduction_axis, where)
94 | return loss
95 |
96 |
97 | def reduce_loss(
98 | arr,
99 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED,
100 | reduction_axis: Optional[AxisSelection] = None,
101 | where: Optional[NamedArray] = None,
102 | ):
103 | """
104 | Reduce a loss array according to the given reduction and reduction axis.
105 | If reduction is None, the loss is not reduced.
106 | If reduction is UNSPECIFIED, the default reduction is used (mean).
107 | If reduction_axis is None (default), the loss is reduced over all axes.
108 | """
109 | return maybe_reduce_loss(arr, reduction, reduction_axis, where)
110 |
111 |
112 | def maybe_reduce_loss(
113 | arr,
114 | reduction: Optional[ReductionFunction] | Unspecified,
115 | reduction_axis: Optional[AxisSelection],
116 | where: Optional[NamedArray],
117 | ):
118 | if reduction is not None and reduction_axis != ():
119 | if reduction is UNSPECIFIED:
120 | reduction = hax.mean
121 | arr = reduction(arr, where=where, axis=reduction_axis)
122 | elif where is not None:
123 | arr = hax.where(where, arr, 0)
124 | return arr
125 |
126 |
127 | def cross_entropy_loss_and_log_normalizers(
128 | pred_y: NamedArray,
129 | Label: AxisSelector,
130 | target_y: NamedArray,
131 | ) -> tuple[NamedArray, NamedArray]:
132 | """
133 | Compute the cross entropy loss and log normalizers for a batch of predictions and targets.
134 |
135 | :param pred_y: a NamedArray with the Label axis (and possibly others for e.g. batch and seq) containing the logits
136 | :param Label: the Label axis
137 | :param target_y: a NamedArray with the Label axis (and possibly others) containing the targets
138 |
139 | :return: tuple of two named arrays, with "per position" losses and log normalizers
140 | """
141 | log_normalizers = hax.nn.logsumexp(pred_y, Label)
142 | neg_log_normalized = log_normalizers - pred_y
143 |
144 | loss = hax.dot(target_y, neg_log_normalized, axis=Label)
145 |
146 | return loss, log_normalizers
147 |
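A minimal usage sketch for `cross_entropy_loss` (not from the repository; `Batch` and `Label` are illustrative). Targets must be a floating-point distribution over `Label`; by default the loss is mean-reduced over all remaining axes, while `reduction=None` keeps the per-example losses:

    import jax
    import haliax as hax
    from haliax.nn.loss import cross_entropy_loss

    Batch, Label = hax.make_axes(Batch=4, Label=10)  # illustrative axes
    logits = hax.random.normal(jax.random.PRNGKey(0), (Batch, Label))
    targets = hax.ones((Batch, Label)) / Label.size  # soft targets: uniform over Label

    mean_loss = cross_entropy_loss(logits, Label, targets)                    # mean over all axes
    per_example = cross_entropy_loss(logits, Label, targets, reduction=None)  # axes: (Batch,)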
--------------------------------------------------------------------------------
/src/haliax/nn/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence
2 |
3 | import equinox as eqx
4 | import jax
5 | from jaxtyping import PRNGKeyArray
6 |
7 | from ..axis import Axis, AxisSpec
8 | from ..core import NamedArray
9 | from ..jax_utils import maybe_rng_split
10 | from ..quantization import DotGeneralOp
11 | from .activations import relu
12 | from .linear import Linear
13 |
14 |
15 | DEFAULT_WIDTH_NAME = "mlp"
16 |
17 |
18 | class MLP(eqx.Module):
19 | """
20 | A multilayer perceptron (MLP) / feed-forward neural network (FFNN).
21 |
22 | MLPs, whose stacked linear layers typically use non-semantic axes for their hidden
23 | dimensions, are not a natural fit for Haliax's named-axis design philosophy. Nonetheless, they are a useful
24 | tool for many tasks, and so we provide this module.
25 |
26 | In Haliax, all axes must have names, and names must be unique within an array. We considered a few strategies
27 | for naming the axes of an MLP, and settled on the following: By default, we alternate hidden names between "mlp"
28 | and "mlp2". Input and output names must be specified, and are not repeated. This naming scheme is not perfect,
29 | but does mean that model parallelism works reasonably well.
30 |
31 | NB: unlike Equinox's MLP, this MLP uses a static field for activation. If you want a learnable activation, you
32 | likely want a unique activation per layer, which neither version provides. Instead, you should use a
33 | [haliax.nn.Stacked][] with a suitable block.
34 | """
35 |
36 | activation: Callable = eqx.field(static=True)
37 | layers: Sequence[Linear]
38 |
39 | @staticmethod
40 | def init(
41 | Input: AxisSpec,
42 | Output: AxisSpec,
43 | width: int | Axis,
44 | depth: int,
45 | activation: Callable = relu,
46 | *,
47 | out_first: bool = True,
48 | use_bias: bool = True,
49 | use_final_bias: bool = True,
50 | key: PRNGKeyArray,
51 | dot_general: Optional[DotGeneralOp] = None,
52 | init_scale: float = 1.0,
53 | ):
54 | Width = _get_width(width)
55 | Width2 = Width.alias(Width.name + "2")
56 |
57 | keys = jax.random.split(key, depth + 1)
58 |
59 | layers = []
60 |
61 | kwargs: dict = {
62 | "use_bias": use_bias,
63 | "dot_general": dot_general,
64 | "init_scale": init_scale,
65 | "out_first": out_first,
66 | }
67 |
68 | last_kwargs: dict = {
69 | "use_bias": use_final_bias,
70 | "dot_general": dot_general,
71 | "init_scale": init_scale,
72 | "out_first": out_first,
73 | }
74 |
75 | if depth == 0:
76 | # special case: no hidden layers
77 | layers.append(Linear.init(Input, Output, key=keys[0], **last_kwargs))
78 | else:
79 | # first hidden layer
80 | layers.append(Linear.init(Input, Width, key=keys[0], **kwargs))
81 | # middle hidden layers
82 | cur = Width
83 | next = Width2
84 | for i in range(1, depth):
85 | layers.append(Linear.init(cur, next, key=keys[i], **kwargs))
86 | cur, next = next, cur
87 | # final layer
88 | layers.append(Linear.init(cur, Output, key=keys[-1], **last_kwargs))
89 |
90 | return MLP(
91 | layers=tuple(layers),
92 | activation=activation,
93 | )
94 |
95 | @property
96 | def In(self) -> AxisSpec:
97 | return self.layers[0].In
98 |
99 | @property
100 | def Out(self) -> AxisSpec:
101 | return self.layers[-1].Out
102 |
103 | def __call__(self, x: NamedArray, *, key=None) -> NamedArray:
104 | keys = maybe_rng_split(key, len(self.layers))
105 | for layer, k in zip(self.layers[:-1], keys):
106 | x = self.activation(layer(x, key=k))
107 | return self.layers[-1](x, key=keys[-1])
108 |
109 |
110 | def _get_width(Width: int | Axis) -> Axis:
111 | if isinstance(Width, int):
112 | return Axis(DEFAULT_WIDTH_NAME, Width)
113 | else:
114 | return Width
115 |
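A minimal usage sketch for `MLP.init` (not from the repository; `Batch`, `In`, and `Out` are illustrative). With `depth=2`, the hidden axes alternate between the default names "mlp" and "mlp2":

    import jax
    import haliax as hax
    from haliax.nn.mlp import MLP

    Batch, In, Out = hax.make_axes(Batch=4, In=16, Out=2)  # illustrative axes
    mlp = MLP.init(In, Out, width=32, depth=2, key=jax.random.PRNGKey(0))

    x = hax.random.normal(jax.random.PRNGKey(1), (Batch, In))
    y = mlp(x)  # result carries the Batch and Out axes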
--------------------------------------------------------------------------------
/src/haliax/nn/normalization.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from abc import abstractmethod
3 | from typing import Optional, TypeVar
4 |
5 | import equinox as eqx
6 | from jax import nn as jnn
7 | from jax import numpy as jnp
8 |
9 | import haliax
10 | import haliax as hax
11 |
12 | from .._src.state_dict import Mod, ModuleWithStateDictSerialization
13 | from ..axis import AxisSelection, AxisSpec
14 | from ..core import NamedArray
15 | from ..types import Scalar
16 | from ..wrap import unwrap_namedarrays, wrap_axiswise_call, wrap_reduction_call
17 |
18 |
19 | A = TypeVar("A", Scalar, NamedArray, jnp.ndarray)
20 |
21 |
22 | class LayerNormBase(ModuleWithStateDictSerialization):
23 | axis: AxisSpec = eqx.static_field()
24 | weight: Optional[NamedArray]
25 | bias: Optional[NamedArray]
26 | eps: float = eqx.static_field(default=1e-5)
27 | dtype: Optional[jnp.dtype] = eqx.field(default=None, static=True)
28 |
29 | @abstractmethod
30 | def __call__(self, x: NamedArray) -> NamedArray:
31 | pass
32 |
33 | @classmethod
34 | def init(
35 | cls,
36 | axis: AxisSpec,
37 | eps: float = 1e-5,
38 | *,
39 | use_weight: bool = True,
40 | use_bias: bool = True,
41 | dtype: Optional[jnp.dtype] = None,
42 | ):
43 | if use_weight:
44 | weight = hax.ones(axis)
45 | else:
46 | weight = None
47 |
48 | if use_bias:
49 | bias = hax.zeros(axis)
50 | else:
51 | bias = None
52 |
53 | return cls(axis, weight, bias, eps, dtype)
54 |
55 | def flatten_for_export(self: Mod) -> Mod:
56 | if isinstance(self.axis, hax.Axis):
57 | return self
58 |
59 | if self.weight is not None:
60 | weight = self.weight.flatten("__OUT")
61 | else:
62 | weight = None
63 |
64 | if self.bias is not None:
65 | bias = self.bias.flatten("__OUT")
66 | else:
67 | bias = None
68 |
69 | return dataclasses.replace(self, weight=weight, bias=bias, axis=hax.flatten_axes(self.axis, "__OUT"))
70 |
71 | def unflatten_from_export(self: Mod, template: Mod) -> Mod:
72 | if template.axis == self.axis:
73 | return self
74 |
75 | if self.weight is not None:
76 | assert isinstance(self.axis, hax.Axis), "Cannot unflatten weight with non-axis axis"
77 | weight = hax.unflatten_axis(self.weight, self.axis, template.axis)
78 | else:
79 | weight = None
80 |
81 | if self.bias is not None:
82 | assert isinstance(self.axis, hax.Axis), "Cannot unflatten weight with non-axis axis"
83 | bias = hax.unflatten_axis(self.bias, self.axis, template.axis)
84 |
85 | else:
86 | bias = None
87 |
88 | return dataclasses.replace(self, weight=weight, bias=bias, axis=template.axis)
89 |
90 |
91 | class LayerNorm(LayerNormBase):
92 | r"""
93 | Normalizes the input along the specified axis (or axes), using the mean and variance of the
94 | input along those axes.
95 | """
96 | axis: AxisSpec = eqx.field(static=True)
97 | weight: Optional[NamedArray]
98 | bias: Optional[NamedArray]
99 |
100 | eps: float = eqx.field(default=1e-5, static=True)
101 | dtype: Optional[jnp.dtype] = eqx.field(default=None, static=True)
102 |
103 | def __call__(self, x: NamedArray) -> NamedArray:
104 | dtype = x.dtype
105 | mean = x.mean(self.axis)
106 | var = x.var(self.axis)
107 | inv = hax.rsqrt(var + self.eps)
108 | out = (x - mean) * inv
109 | out = out.astype(dtype)
110 |
111 | if self.weight is not None:
112 | out = self.weight * out
113 | if self.bias is not None:
114 | out = out + self.bias
115 | return out
116 |
117 |
118 | class RmsNorm(LayerNormBase):
119 | r"""
120 | Implements RMS normalization, which normalizes the input by dividing by the root mean square of the input.
121 | """
122 |
123 | def __call__(self, x: NamedArray) -> NamedArray:
124 | in_dtype = x.dtype
125 | x = x.astype(self.dtype)
126 | var = hax.mean(hax.square(x), axis=self.axis)
127 | inv = hax.rsqrt(var + self.eps)
128 | out = x * inv
129 | out = out.astype(in_dtype)
130 |
131 | if self.weight is not None:
132 | out = self.weight * out
133 | if self.bias is not None:
134 | out = out + self.bias
135 | return out
136 |
137 |
138 | def logsumexp(a: A, axis: Optional[AxisSelection] = None) -> A:
139 | # TODO: logsumexp indirectly supports where via `b`. we should support it directly
140 | return wrap_reduction_call(jnn.logsumexp, a, axis=axis, single_axis_only=False, supports_where=False)
141 |
142 |
143 | # TODO: support where in softmax, etc
144 |
145 |
146 | def softmax(a: A, axis: Optional[AxisSelection] = None) -> A:
147 | return wrap_axiswise_call(jnn.softmax, a, axis=axis, single_axis_only=False)
148 |
149 |
150 | def log_softmax(a: A, axis: Optional[AxisSelection] = None) -> A:
151 | return wrap_axiswise_call(jnn.log_softmax, a, axis=axis, single_axis_only=False)
152 |
153 |
154 | def standardize(
155 | x: NamedArray,
156 | axis: AxisSpec,
157 | *,
158 | mean: Optional[NamedArray] = None,
159 | variance: Optional[NamedArray] = None,
160 | epsilon: float = 1e-5,
161 | where: Optional[NamedArray] = None,
162 | ) -> NamedArray:
163 | """Analogous to [jax.nn.standardize][], but with support for NamedArrays."""
164 | x, mean, variance, where = haliax.broadcast_arrays(x, mean, variance, where) # type: ignore
165 | raw_x, mean, variance, where = unwrap_namedarrays(x, mean, variance, where)
166 | axis_indices = x._lookup_indices(axis)
167 |
168 | plain = jnn.standardize(raw_x, axis_indices, mean=mean, variance=variance, epsilon=epsilon, where=where)
169 | return NamedArray(plain, x.axes)
170 |
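A minimal usage sketch for `LayerNorm` and the axiswise `softmax` above (not from the repository; `Batch` and `Embed` are illustrative):

    import jax
    import haliax as hax
    from haliax.nn.normalization import LayerNorm, softmax

    Batch, Embed = hax.make_axes(Batch=4, Embed=16)  # illustrative axes
    x = hax.random.normal(jax.random.PRNGKey(0), (Batch, Embed))

    ln = LayerNorm.init(Embed)      # learnable weight and bias over Embed
    y = ln(x)                       # normalized over Embed, same axes as x

    probs = softmax(x, axis=Embed)  # softmax along the named Embed axis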
--------------------------------------------------------------------------------
/src/haliax/ops.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import Optional, Union
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from .axis import Axis, AxisSelector
8 | from .core import NamedArray, NamedOrNumeric, broadcast_arrays, broadcast_arrays_and_return_axes, named
9 | from .jax_utils import is_scalarish
10 |
11 |
12 | def trace(array: NamedArray, axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> NamedArray:
13 | """Compute the trace of an array along two named axes."""
14 | a1_index = array._lookup_indices(axis1)
15 | a2_index = array._lookup_indices(axis2)
16 |
17 | if a1_index is None:
18 | raise ValueError(f"Axis {axis1} not found in array. Available axes: {array.axes}")
19 | if a2_index is None:
20 | raise ValueError(f"Axis {axis2} not found in array. Available axes: {array.axes}")
21 |
22 | if a1_index == a2_index:
23 | raise ValueError(f"Cannot trace along the same axis. Got {axis1} and {axis2}")
24 |
25 | inner = jnp.trace(array.array, offset=offset, axis1=a1_index, axis2=a2_index, dtype=dtype)
26 | # remove the two indices
27 | axes = tuple(a for i, a in enumerate(array.axes) if i not in (a1_index, a2_index))
28 | return NamedArray(inner, axes)
29 |
30 |
31 | @typing.overload
32 | def where(
33 | condition: NamedOrNumeric | bool,
34 | x: NamedOrNumeric,
35 | y: NamedOrNumeric,
36 | ) -> NamedArray:
37 | ...
38 |
39 |
40 | @typing.overload
41 | def where(
42 | condition: NamedArray,
43 | *,
44 | fill_value: int,
45 | new_axis: Axis,
46 | ) -> tuple[NamedArray, ...]:
47 | ...
48 |
49 |
50 | def where(
51 | condition: Union[NamedOrNumeric, bool],
52 | x: Optional[NamedOrNumeric] = None,
53 | y: Optional[NamedOrNumeric] = None,
54 | fill_value: Optional[int] = None,
55 | new_axis: Optional[Axis] = None,
56 | ) -> NamedArray | tuple[NamedArray, ...]:
57 | """Like jnp.where, but with named axes."""
58 |
59 | if (x is None) != (y is None):
60 | raise ValueError("Must either specify both x and y, or neither")
61 |
62 | # one argument form
63 | if (x is None) and (y is None):
64 | if not isinstance(condition, NamedArray):
65 | raise ValueError(f"condition {condition} must be a NamedArray in single argument mode")
66 | if fill_value is None or new_axis is None:
67 | raise ValueError("Must specify both fill_value and new_axis")
68 | return tuple(
69 | NamedArray(idx, (new_axis,))
70 | for idx in jnp.where(condition.array, size=new_axis.size, fill_value=fill_value)
71 | )
72 |
73 | # if x or y is a NamedArray, the other must be as well. wrap as needed for scalars
74 |
75 | if is_scalarish(condition):
76 | if x is None or y is None:
77 | raise ValueError("Must specify x and y when condition is a scalar")
78 |
79 | if isinstance(x, NamedArray) and not isinstance(y, NamedArray):
80 | if not is_scalarish(y):
81 | raise ValueError("y must be a NamedArray or scalar if x is a NamedArray")
82 | y = named(y, ())
83 | elif isinstance(y, NamedArray) and not isinstance(x, NamedArray):
84 | if not is_scalarish(x):
85 | raise ValueError("x must be a NamedArray or scalar if y is a NamedArray")
86 | x = named(x, ())
87 | x, y = broadcast_arrays(x, y)
88 | return jax.lax.cond(condition, lambda _: x, lambda _: y, None)
89 |
90 | condition, x, y = broadcast_arrays(condition, x, y) # type: ignore
91 |
92 | assert isinstance(condition, NamedArray)
93 |
94 | def _array_if_named(x):
95 | if isinstance(x, NamedArray):
96 | return x.array
97 | return x
98 |
99 | raw = jnp.where(condition.array, _array_if_named(x), _array_if_named(y))
100 | return NamedArray(raw, condition.axes)
101 |
102 |
103 | def clip(array: NamedOrNumeric, a_min: NamedOrNumeric, a_max: NamedOrNumeric) -> NamedArray:
104 | """Like jnp.clip, but with named axes. This version currently only accepts the three argument form."""
105 | (array, a_min, a_max), axes = broadcast_arrays_and_return_axes(array, a_min, a_max)
106 | array = raw_array_or_scalar(array)
107 | a_min = raw_array_or_scalar(a_min)
108 | a_max = raw_array_or_scalar(a_max)
109 |
110 | return NamedArray(jnp.clip(array, a_min, a_max), axes)
111 |
112 |
113 | def tril(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray:
114 | """Compute the lower triangular part of an array along two named axes."""
115 | array = array.rearrange((..., axis1, axis2))
116 |
117 | inner = jnp.tril(array.array, k=k)
118 | return NamedArray(inner, array.axes)
119 |
120 |
121 | def triu(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray:
122 | """Compute the upper triangular part of an array along two named axes."""
123 | array = array.rearrange((..., axis1, axis2))
124 |
125 | inner = jnp.triu(array.array, k=k)
126 | return NamedArray(inner, array.axes)
127 |
128 |
129 | def isclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> NamedArray:
130 | """Returns a boolean array where two arrays are element-wise equal within a tolerance."""
131 | a, b = broadcast_arrays(a, b)
132 | # TODO: numpy supports an array atol and rtol, but we don't yet
133 | return NamedArray(jnp.isclose(a.array, b.array, rtol=rtol, atol=atol, equal_nan=equal_nan), a.axes)
134 |
135 |
136 | def pad_left(array: NamedArray, axis: Axis, new_axis: Axis, value=0) -> NamedArray:
137 | """Pad an array along named axes."""
138 | amount_to_pad_to = new_axis.size - axis.size
139 | if amount_to_pad_to < 0:
140 | raise ValueError(f"Cannot pad {axis} to {new_axis}")
141 |
142 | idx = array._lookup_indices(axis)
143 |
144 | padding = [(0, 0)] * array.ndim
145 | if idx is None:
146 | raise ValueError(f"Axis {axis} not found in array. Available axes: {array.axes}")
147 | padding[idx] = (amount_to_pad_to, 0)
148 |
149 | padded = jnp.pad(array.array, padding, constant_values=value)
150 | return NamedArray(padded, array.axes[:idx] + (new_axis,) + array.axes[idx + 1 :])
151 |
152 |
153 | def raw_array_or_scalar(x: NamedOrNumeric):
154 | if isinstance(x, NamedArray):
155 | return x.array
156 | return x
157 |
158 |
159 | __all__ = ["trace", "where", "tril", "triu", "isclose", "pad_left", "clip"]
160 |
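A minimal usage sketch for `tril` and both forms of `where` (not from the repository; `Q`, `K`, and the `NZ` axis are illustrative). The single-argument form needs a static `new_axis` size and a `fill_value` because JAX shapes must be known at trace time:

    import jax
    import jax.numpy as jnp
    import haliax as hax
    from haliax.ops import tril, where

    Q, K = hax.make_axes(Q=4, K=4)  # illustrative axes
    scores = hax.random.normal(jax.random.PRNGKey(0), (Q, K))

    # three-argument form: elementwise select with broadcasting
    lower = tril(hax.ones((Q, K)), Q, K)     # lower-triangular 0/1 mask
    masked = where(lower, scores, -jnp.inf)  # mask out the upper triangle

    # single-argument form: indices of nonzero entries, padded out to a static size
    NZ = hax.Axis("NZ", Q.size * K.size)     # worst-case count of nonzeros
    q_idx, k_idx = where(lower, fill_value=-1, new_axis=NZ)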
--------------------------------------------------------------------------------
/src/haliax/specialized_fns.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from .axis import Axis, AxisSelector
7 | from .core import NamedArray
8 |
9 |
10 | def top_k(
11 | arr: NamedArray, axis: AxisSelector, k: int, new_axis: Optional[AxisSelector] = None
12 | ) -> Tuple[NamedArray, NamedArray]:
13 | """
14 | Select the top k elements along the given axis.
15 | Args:
16 | arr (NamedArray): array to select from
17 | axis (AxisSelector): axis to select from
18 | k (int): number of elements to select
19 | new_axis (Optional[AxisSelector]): name (or Axis) for the result axis; if None, the original axis is resized to k
20 |
21 | Returns:
22 | NamedArray: array with the top k elements along the given axis
23 | NamedArray: array with the top k elements' indices along the given axis
24 | """
25 | pos = arr._lookup_indices(axis)
26 | if pos is None:
27 | raise ValueError(f"Axis {axis} not found in {arr}")
28 | new_array = jnp.moveaxis(arr.array, pos, -1) # move axis to the last position
29 | values, indices = jax.lax.top_k(new_array, k=k)
30 | values = jnp.moveaxis(values, -1, pos) # move axis back to its original position
31 | indices = jnp.moveaxis(indices, -1, pos)
32 |
33 | if new_axis is None:
34 | axis = arr.resolve_axis(axis)
35 | new_axis = axis.resize(k)
36 | elif isinstance(new_axis, str):
37 | new_axis = Axis(new_axis, k)
38 |
39 | updated_axes = arr.axes[:pos] + (new_axis,) + arr.axes[pos + 1 :]
40 | return NamedArray(values, updated_axes), NamedArray(indices, updated_axes)
41 |
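A minimal usage sketch for `top_k` (not from the repository; `Batch` and `Vocab` are illustrative):

    import jax
    import haliax as hax
    from haliax.specialized_fns import top_k

    Batch, Vocab = hax.make_axes(Batch=2, Vocab=10)  # illustrative axes
    logits = hax.random.normal(jax.random.PRNGKey(0), (Batch, Vocab))

    values, indices = top_k(logits, Vocab, k=3)                     # Vocab axis resized to 3
    values2, indices2 = top_k(logits, Vocab, k=3, new_axis="TopK")  # or name the result axis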
--------------------------------------------------------------------------------
/src/haliax/state_dict.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, TypeVar
2 |
3 | import equinox
4 |
5 | from haliax.jax_utils import is_jax_array_like
6 | from haliax.types import FilterSpec
7 |
8 | from ._src.state_dict import (
9 | ModuleWithStateDictSerialization,
10 | StateDict,
11 | flatten_modules_for_export,
12 | from_state_dict,
13 | from_torch_compatible_state_dict,
14 | load_state_dict,
15 | save_state_dict,
16 | to_numpy_state_dict,
17 | to_state_dict,
18 | unflatten_modules_from_export,
19 | with_prefix,
20 | )
21 |
22 |
23 | T = TypeVar("T")
24 |
25 |
26 | def to_torch_compatible_state_dict(
27 | t: T, *, flatten: bool = True, prefix: Optional[str] = None, filter: FilterSpec = is_jax_array_like
28 | ) -> StateDict:
29 | """
30 | Convert a tree to a state dict that is compatible with torch-style state dicts.
31 |
32 | This applies the same logic as [to_state_dict][] but also uses [haliax.state_dict.ModuleWithStateDictSerialization.flatten_for_export][] to flatten
33 |
34 | Args:
35 | t: The tree to convert
36 | flatten: Whether to flatten axes using flatten_for_export
37 | prefix: The prefix to use for the state dict keys
38 | filter: The filter to use for selecting which nodes to include in the state dict. By default, this includes only
39 | array-like objects (e.g. JAX and NumPy arrays).
40 | """
41 | t = equinox.filter(t, filter)
42 | if flatten:
43 | t = flatten_modules_for_export(t)
44 | return to_numpy_state_dict(t, prefix=prefix)
45 |
46 |
47 | __all__ = [
48 | "ModuleWithStateDictSerialization",
49 | "from_torch_compatible_state_dict",
50 | "load_state_dict",
51 | "save_state_dict",
52 | "from_state_dict",
53 | "with_prefix",
54 | "to_state_dict",
55 | "to_numpy_state_dict",
56 | "StateDict",
57 | "to_torch_compatible_state_dict",
58 | "flatten_modules_for_export",
59 | "unflatten_modules_from_export",
60 | ]
61 |
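A minimal usage sketch for `to_torch_compatible_state_dict` applied to a single `Linear` layer (not from the repository; the axis names are illustrative, and the exact state-dict keys depend on the module):

    import jax
    import haliax as hax
    from haliax.nn.linear import Linear
    from haliax.state_dict import to_torch_compatible_state_dict

    In, Out = hax.make_axes(In=16, Out=4)  # illustrative axes
    linear = Linear.init(In, Out, key=jax.random.PRNGKey(0))

    state = to_torch_compatible_state_dict(linear)  # flat, numpy-backed, torch-style dict
    print(sorted(state.keys()))                     # e.g. weight/bias entries for the layer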
--------------------------------------------------------------------------------
/src/haliax/tree_util.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import functools
3 | from typing import Optional
4 |
5 | import equinox as eqx
6 | import jax
7 | import jax.tree_util as jtu
8 | from jaxtyping import PRNGKeyArray, PyTree
9 |
10 | import haliax.nn
11 |
12 | from .axis import AxisSelector
13 | from .core import NamedArray
14 | from .jax_utils import maybe_rng_split
15 | from .util import is_named_array
16 |
17 |
18 | def tree_map(fn, tree, *rest, is_leaf=None):
19 | """
20 | Version of [jax.tree_util.tree_map][] that automatically treats NamedArrays as leaves.
21 | """
22 | old_is_leaf = is_leaf
23 | if is_leaf is None:
24 | is_leaf = lambda x: isinstance(x, NamedArray)
25 | else:
26 | is_leaf = lambda x: old_is_leaf(x) or is_named_array(x)
27 |
28 | return jax.tree.map(fn, tree, *rest, is_leaf=is_leaf)
29 |
30 |
31 | def scan_aware_tree_map(fn, tree, *rest, is_leaf=None):
32 | """
33 | Version of [haliax.tree_util.tree_map][] that is aware of the scan-layer pattern, specifically as implemented
34 | in hax.nn.Stacked. This function will (implicitly) apply the transform to each layer in each Stacked module
35 | (using vmap). If there are no Stacked modules in the tree, this function is equivalent to [haliax.tree_util.tree_map][].
36 |
37 | """
38 | old_is_leaf = is_leaf
39 | if is_leaf is None:
40 | is_leaf = lambda x: isinstance(x, haliax.nn.Stacked)
41 | else:
42 | is_leaf = lambda x: old_is_leaf(x) or isinstance(x, haliax.nn.Stacked)
43 |
44 | mapped_fn = functools.partial(scan_aware_tree_map, fn, is_leaf=is_leaf)
45 |
46 | def rec_fn(x, *rest):
47 | if isinstance(x, haliax.nn.Stacked):
48 | new_inner = haliax.vmap(mapped_fn, x.Block)(x.stacked, *[r.stacked for r in rest])
49 | return dataclasses.replace(x, stacked=new_inner) # type: ignore
50 | else:
51 | return fn(x, *rest)
52 |
53 | return tree_map(rec_fn, tree, *rest, is_leaf=is_leaf)
54 |
55 |
56 | def tree_flatten(tree, is_leaf=None):
57 | """
58 | Version of [jax.tree_util.tree_flatten][] that automatically treats NamedArrays as leaves.
59 | """
60 | if is_leaf is None:
61 | is_leaf = lambda x: isinstance(x, NamedArray)
62 | else:
63 | is_leaf = functools.partial(lambda f, x: f(x) or is_named_array(x), is_leaf)  # bind the user's is_leaf; a plain lambda here would call itself recursively
64 |
65 | return jax.tree_util.tree_flatten(tree, is_leaf=is_leaf)
66 |
67 |
68 | def tree_unflatten(treedef, leaves):
69 | """
70 | Provided for consistency with tree_flatten.
71 | """
72 | return jax.tree_util.tree_unflatten(treedef, leaves)
73 |
74 |
75 | def tree_leaves(tree, is_leaf=None):
76 | """
77 | Version of [jax.tree_util.tree_leaves][] that automatically treats NamedArrays as leaves.
78 | """
79 | if is_leaf is None:
80 | is_leaf = lambda x: isinstance(x, NamedArray)
81 | else:
82 | is_leaf = functools.partial(lambda f, x: f(x) or is_named_array(x), is_leaf)  # bind the user's is_leaf; a plain lambda here would call itself recursively
83 |
84 | return jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
85 |
86 |
87 | def tree_structure(tree, is_leaf=None):
88 | """
89 | Version of [jax.tree_util.tree_structure][] that automatically treats NamedArrays as leaves.
90 | """
91 | if is_leaf is None:
92 | is_leaf = lambda x: isinstance(x, NamedArray)
93 | else:
94 | is_leaf = functools.partial(lambda f, x: f(x) or is_named_array(x), is_leaf)  # bind the user's is_leaf; a plain lambda here would call itself recursively
95 |
96 | return jax.tree_util.tree_structure(tree, is_leaf=is_leaf)
97 |
98 |
99 | def resize_axis(tree: PyTree[NamedArray], old_axis: AxisSelector, new_size: int, key: Optional[PRNGKeyArray] = None):
100 | """Resizes the NamedArrays of a PyTree along a given axis. If the array needs to grow and key is not none, then the
101 | new elements are sampled from a truncated normal distribution with the same mean and standard deviation as the
102 | existing elements. If the key is none, they're just initialized to the mean. If the array needs to shrink, then it's
103 | truncated.
104 |
105 | Note: if you have a module that stores a reference to the old axis, then you'll need to update that reference
106 | manually.
107 |
108 | """
109 | import haliax.random
110 |
111 | def _resize_one(x, key):
112 | if not is_named_array(x):
113 | return x
114 |
115 | assert isinstance(x, NamedArray)
116 |
117 | try:
118 | current_axis = x.resolve_axis(old_axis)
119 | except ValueError:
120 | return x
121 |
122 | if new_size == current_axis.size:
123 | return x
124 | elif current_axis.size > new_size:
125 | return x.slice(current_axis, start=0, length=new_size)
126 | else:
127 | num_padding = new_size - current_axis.size
128 |
129 | mean = x.mean(current_axis)
130 | std = x.std(current_axis)
131 |
132 | # the shape of the padding is the same as the original array, except with the axis size changed
133 | padding_axes = list(x.axes)
134 | padding_axes[padding_axes.index(current_axis)] = current_axis.resize(num_padding)
135 |
136 | if key is None:
137 | padding = mean.broadcast_axis(padding_axes)
138 | else:
139 | padding = haliax.random.truncated_normal(key, padding_axes, lower=-2, upper=2) * std + mean
140 |
141 | return haliax.concatenate(current_axis.name, [x, padding])
142 |
143 | leaves, structure = jax.tree_util.tree_flatten(tree, is_leaf=is_named_array)
144 | keys = maybe_rng_split(key, len(leaves))
145 |
146 | new_leaves = [_resize_one(x, key) for x, key in zip(leaves, keys)]
147 |
148 | return jax.tree_util.tree_unflatten(structure, new_leaves)
149 |
150 |
151 | # old version of eqx's partition functions
152 | def hashable_partition(pytree, filter_spec):
153 | dynamic, static = eqx.partition(pytree, filter_spec)
154 | static_leaves, static_treedef = jtu.tree_flatten(static)
155 | static_leaves = tuple(static_leaves)
156 | return dynamic, (static_leaves, static_treedef)
157 |
158 |
159 | def hashable_combine(dynamic, static):
160 | static_leaves, static_treedef = static
161 | static = jtu.tree_unflatten(static_treedef, static_leaves)
162 | return eqx.combine(dynamic, static)
163 |
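A minimal usage sketch for `resize_axis` (not from the repository; `Vocab` and `Embed` are illustrative). With a key, new entries are sampled from a truncated normal matched to the existing mean and std; without one, they are filled with the mean, and shrinking simply truncates:

    import jax
    import haliax as hax
    from haliax.tree_util import resize_axis

    Vocab, Embed = hax.make_axes(Vocab=100, Embed=16)  # illustrative axes
    params = {"emb": hax.random.normal(jax.random.PRNGKey(0), (Vocab, Embed))}

    grown = resize_axis(params, Vocab, 128, key=jax.random.PRNGKey(1))  # Vocab: 100 -> 128
    shrunk = resize_axis(params, "Vocab", 50)                           # truncates; no key needed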
--------------------------------------------------------------------------------
/src/haliax/types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Literal, Protocol, Tuple, TypeAlias, Union
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | from jax.lax import Precision
6 | from jaxtyping import PyTree
7 |
8 |
9 | DType: TypeAlias = np.dtype
10 |
11 | try:
12 | from jax.typing import DTypeLike
13 | except ImportError:
14 | # Cribbed from jax.typing, for older versions of JAX
15 | class SupportsDType(Protocol):
16 | @property
17 | def dtype(self) -> DType:
18 | ...
19 |
20 | DTypeLike = Union[
21 | str, # like 'float32', 'int32'
22 | type, # like np.float32, np.int32, float, int
23 | np.dtype, # like np.dtype('float32'), np.dtype('int32')
24 | SupportsDType, # like jnp.float32, jnp.int32
25 | ]
26 |
27 |
28 | Scalar = Union[float, int, jnp.ndarray] # ndarray b/c array(1) is a scalar
29 | IntScalar = Union[int, jnp.ndarray]
30 |
31 | PrecisionLike = Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]
32 |
33 | GatherScatterModeStr = Literal["promise_in_bounds", "clip", "drop", "fill"]
34 |
35 |
36 | FilterSpec = Union[bool, Callable[[Any], bool]]
37 | """
38 | A filter specification. Typically used on a pytree to filter out certain subtrees. Boolean values are
39 | treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element
40 | is kept, otherwise it is filtered out.
41 | """
42 |
43 | FilterTree = FilterSpec | PyTree[FilterSpec]
44 |
--------------------------------------------------------------------------------
/src/haliax/util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Sequence, Tuple, TypeAlias, TypeVar, Union
3 |
4 | import equinox
5 |
6 | from haliax.jax_utils import is_jax_array_like
7 |
8 |
9 | T = TypeVar("T")
10 |
11 | Unspecified: TypeAlias = type("NotSpecified", (), {}) # type: ignore
12 | UNSPECIFIED = Unspecified()
13 |
14 |
15 | def is_named_array(leaf):
16 | """Returns True if leaf is a NamedArray. Typically used as the is_leaf predicate in tree_map."""
17 | from .core import NamedArray
18 |
19 | return isinstance(leaf, NamedArray)
20 |
21 |
22 | def ensure_tuple(x: Union[Sequence[T], T]) -> Tuple[T, ...]:
23 | if isinstance(x, str):
24 | return (x,) # type: ignore
25 | elif isinstance(x, Sequence):
26 | return tuple(x)
27 | return (x,)
28 |
29 |
30 | def maybe_untuple(x: Union[Sequence[T], T]) -> Union[T, Sequence[T]]:
31 | """
32 | If x is a tuple with one element, return that element. Otherwise return x.
33 | """
34 | if isinstance(x, tuple) and len(x) == 1:
35 | return x[0]
36 | return x
37 |
38 |
39 | class StringHolderEnum(type):
40 | """Like a python enum but just holds string constants, as opposed to wrapped string constants"""
41 |
42 | # https://stackoverflow.com/questions/62881486/a-group-of-constants-in-python
43 |
44 | def __new__(cls, name, bases, members):
45 | cls.members = [v for k, v in members.items() if not k.startswith("__") and not callable(v)]
46 | return super().__new__(cls, name, bases, members)
47 |
48 | # giving your class an __iter__ method gives you membership checking
49 | # and the ability to easily convert to another iterable
50 | @classmethod
51 | def __iter__(cls):
52 | yield from cls.members
53 |
54 |
55 | def is_jax_or_hax_array_like(x):
56 | return is_jax_array_like(x) or is_named_array(x)
57 |
58 |
59 | def safe_wraps(fn):
60 | """
61 | Equinox has a special [equinox.module_update_wrapper][] that works with [equinox.Module][]s, but
62 | doesn't work with regular functions. Likewise, functools.update_wrapper doesn't work with [equinox.Module][]s.
63 |
64 | This function is a wrapper around both of them that works with both [equinox.Module][]s and regular functions.
65 |
66 | Use this if you get this exception: `dataclasses.FrozenInstanceError: cannot assign to field '__module__'`
67 | """
68 | return functools.partial(safe_update_wrapper, wrapped=fn)
69 |
70 |
71 | def safe_update_wrapper(wrapper, wrapped):
72 | """
73 | As [safe_wraps][] but not a decorator.
74 | Args:
75 | wrapper:
76 | wrapped:
77 |
78 | Returns:
79 |
80 | """
81 | if isinstance(wrapper, equinox.Module):
82 | return equinox.module_update_wrapper(wrapper, wrapped)
83 | else:
84 | return functools.update_wrapper(wrapper, wrapped)
85 |
86 |
87 | __all__ = [
88 | "is_named_array",
89 | "ensure_tuple",
90 | "StringHolderEnum",
91 | "is_jax_or_hax_array_like",
92 | ]
93 |
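A quick sketch of the tuple helpers above (values are illustrative); note that strings are deliberately not treated as sequences:

    from haliax.util import ensure_tuple, maybe_untuple

    ensure_tuple("Batch")           # -> ("Batch",)
    ensure_tuple(("Batch", "Pos"))  # -> ("Batch", "Pos")
    ensure_tuple(3)                 # -> (3,)

    maybe_untuple((3,))    # -> 3
    maybe_untuple((3, 4))  # -> (3, 4)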
--------------------------------------------------------------------------------
/src/haliax/wrap.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import Optional, Protocol
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from haliax.core import NamedArray, _broadcast_order, broadcast_to
8 | from haliax.jax_utils import is_scalarish
9 | from haliax.util import ensure_tuple
10 |
11 | from .axis import AxisSelection, AxisSelector, selects_axis
12 |
13 |
14 | def wrap_elemwise_unary(f, a, *args, **kwargs):
15 | if isinstance(a, NamedArray):
16 | return NamedArray(f(a.array, *args, **kwargs), a.axes)
17 | else:
18 | return f(a, *args, **kwargs)
19 |
20 |
21 | def wrap_reduction_call(
22 | fn,
23 | a,
24 | axis: Optional[AxisSelection],
25 | where: Optional[NamedArray] = None,
26 | single_axis_only: bool = False,
27 | supports_where: bool = True,
28 | **kwargs,
29 | ):
30 | kwargs = dict(kwargs)
31 | if where is not None and not supports_where:
32 | raise ValueError(f"where is not supported by {fn.__name__}")
33 |
34 | if kwargs.get("out", None) is not None:
35 | raise ValueError("out is not supported yet for NamedArray")
36 | if kwargs.get("keepdims", False):
37 | raise ValueError("keepdims is not supported for NamedArray")
38 |
39 | def reduce_one_leaf(a):
40 | nonlocal axis, where
41 | if isinstance(a, NamedArray):
42 | if where is not None:
43 | if not isinstance(where, NamedArray):
44 | raise TypeError(f"where must be a NamedArray if a is a NamedArray, but is {where}")
45 | where = broadcast_to(where, a.axes)
46 | kwargs["where"] = where.array
47 |
48 | if axis is None:
49 | result = fn(a.array, axis=None, **kwargs)
50 | return NamedArray(result, ())
51 | else:
52 | axis = ensure_tuple(axis)
53 | if single_axis_only and len(axis) > 1:
54 | raise ValueError(f"{fn.__name__} only supports a single axis")
55 | indices = a._lookup_indices(axis)
56 | if indices is None or any(x is None for x in indices):
57 | raise ValueError(f"axis {axis} is not in {a.axes}")
58 | new_axes = [ax for ax in a.axes if not selects_axis(axis, ax)]
59 | if single_axis_only:
60 | result = fn(a.array, axis=indices[0], **kwargs)
61 | else:
62 | result = fn(a.array, axis=indices, **kwargs)
63 | return NamedArray(result, tuple(new_axes))
64 | else:
65 | if where is not None:
66 | kwargs["where"] = where
67 | return fn(a, axis=axis, **kwargs)
68 |
69 | return jax.tree_util.tree_map(reduce_one_leaf, a, is_leaf=lambda x: isinstance(x, NamedArray))
70 |
71 |
72 | def wrap_axiswise_call(fn, a, axis: Optional[AxisSelection], *, single_axis_only: bool, **kwargs):
73 | if isinstance(a, NamedArray):
74 | if axis is None:
75 | return fn(a.array, axis=None, **kwargs)
76 | else:
77 | indices = ensure_tuple(a._lookup_indices(axis))
78 | if any(x is None for x in indices):
79 | raise ValueError(f"axis {axis} is not in {a.axes}")
80 | if len(indices) == 1:
81 | return NamedArray(fn(a.array, axis=indices[0], **kwargs), a.axes)
82 | elif single_axis_only:
83 | raise ValueError(f"{fn.__name__} only supports a single axis")
84 | else:
85 | return NamedArray(fn(a.array, axis=indices, **kwargs), a.axes)
86 |
87 | else:
88 | return fn(a, axis=axis, **kwargs)
89 |
90 |
91 | def wrap_elemwise_binary(op):
92 | def binop(a, b):
93 | if isinstance(a, NamedArray) and isinstance(b, NamedArray):
94 | axes = _broadcast_order(a, b)
95 | a = broadcast_to(a, axes)
96 | b = broadcast_to(b, axes)
97 | return NamedArray(op(a.array, b.array), axes)
98 | elif isinstance(a, NamedArray):
99 | return NamedArray(op(a.array, b), a.axes)
100 | elif isinstance(b, NamedArray):
101 | return NamedArray(op(a, b.array), b.axes)
102 | else:
103 | return op(a, b)
104 |
105 | return binop
106 |
107 |
108 | def unwrap_namedarrays(*a):
109 | return tuple(x.array if isinstance(x, NamedArray) else x for x in a)
110 |
111 |
112 | class ReductionFunction(Protocol):
113 | def __call__(
114 | self,
115 | array: NamedArray,
116 | axis: Optional[AxisSelection] = None,
117 | where: Optional[NamedArray] = None,
118 | **kwargs,
119 | ) -> NamedArray:
120 | ...
121 |
122 |
123 | class SimpleReductionFunction(Protocol):
124 | def __call__(self, array: NamedArray, axis: Optional[AxisSelector] = None, **kwargs) -> NamedArray:
125 | ...
126 |
127 |
128 | __all__ = [
129 | "wrap_elemwise_unary",
130 | "wrap_reduction_call",
131 | "wrap_axiswise_call",
132 | "wrap_elemwise_binary",
133 | "unwrap_namedarrays",
134 | "ReductionFunction",
135 | "SimpleReductionFunction",
136 | ]
137 |
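A minimal sketch of how these wrappers are typically used to lift `jax.numpy` functions to `NamedArray`s, mirroring the pattern in `haliax/nn/activations.py` (the helper names `named_tanh` and `named_nanmax` are hypothetical, not part of the library):

    import jax
    import jax.numpy as jnp
    import haliax as hax
    from haliax.wrap import wrap_elemwise_unary, wrap_reduction_call

    def named_tanh(a):
        # hypothetical helper: lift an elementwise jnp function to NamedArrays
        return wrap_elemwise_unary(jnp.tanh, a)

    def named_nanmax(a, axis=None, where=None):
        # hypothetical reduction built on the same machinery as the built-in reductions
        return wrap_reduction_call(jnp.nanmax, a, axis=axis, where=where)

    Batch, Embed = hax.make_axes(Batch=2, Embed=4)  # illustrative axes
    x = hax.random.normal(jax.random.PRNGKey(0), (Batch, Embed))

    y = named_tanh(x)                # same axes as x
    m = named_nanmax(x, axis=Embed)  # reduces away Embed, leaving (Batch,)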
--------------------------------------------------------------------------------
/tests/test_attention.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | from jax.random import PRNGKey
4 |
5 | import haliax as hax
6 | from haliax.nn.attention import (
7 | alibi_attention_bias,
8 | causal_mask,
9 | dot_product_attention,
10 | dot_product_attention_weights,
11 | forgetful_causal_mask,
12 | self_attention,
13 | )
14 | from test_utils import skip_if_no_torch
15 |
16 |
17 | def test_dot_product_attention_requires_axis_to_be_present():
18 | Pos = hax.Axis("Pos", 20)
19 | KeyPos = hax.Axis("Pos_key", 20)
20 | NumHeads = hax.Axis("NumHeads", 1)
21 | Hid = hax.Axis("Hid", 8)
22 |
23 | query = hax.ones((NumHeads, KeyPos, Hid)) # NB: KeyPos not Pos
24 | key = hax.ones((KeyPos, NumHeads, Hid))
25 | value = hax.ones((KeyPos, NumHeads, Hid))
26 |
27 | try:
28 | dot_product_attention(Pos, Hid, query, key, value)
29 | except ValueError as e:
30 | assert "not found" in str(e)
31 | else:
32 | raise AssertionError("Should have raised an error")
33 |
34 |
35 | def test_attention_doesnt_allow_overlapping_axes():
36 | KeyPos = hax.Axis("Pos_key", 20)
37 | NumHeads = hax.Axis("NumHeads", 1)
38 | Hid = hax.Axis("Hid", 8)
39 |
40 | query = hax.ones((NumHeads, KeyPos, Hid))
41 | key = hax.ones((KeyPos, NumHeads, Hid))
42 | value = hax.ones((KeyPos, NumHeads, Hid))
43 |
44 | try:
45 | dot_product_attention(KeyPos, Hid, query, key, value)
46 | except ValueError as e:
47 | assert "must be distinct" in str(e)
48 | else:
49 | raise AssertionError("Should have raised an error")
50 |
51 |
52 | def test_self_attention_basically_works():
53 | Pos = hax.Axis("Pos", 20)
54 | KeyPos = hax.Axis("Pos_key", 20)
55 | NumHeads = hax.Axis("NumHeads", 1)
56 | Hid = hax.Axis("Hid", 8)
57 |
58 | query = hax.ones((NumHeads, Pos, Hid))
59 |
60 | result = self_attention(Pos, Hid, query, query, query, is_causal=True)
61 | assert result.axes == (NumHeads, Pos, Hid)
62 |
63 | k = query.rename({Pos: KeyPos})
64 | cmask = causal_mask(Pos, KeyPos)
65 | result2 = dot_product_attention(KeyPos, Hid, query, k, k, mask=cmask)
66 | assert result2.axes == (NumHeads, Pos, Hid)
67 |
68 | # tight tolerances because it should be exactly the same computation
69 | assert jnp.allclose(result.array, result2.array)
70 |
71 |
72 | def test_alibi_attention_bias():
73 | KeyPos = hax.Axis("KeyPos", 20)
74 | NumHeads = hax.Axis("NumHeads", 1)
75 | Hid = hax.Axis("Hid", 8)
76 |
77 | bias = alibi_attention_bias(NumHeads, KeyPos)
78 |
79 | query = hax.ones((NumHeads, Hid))
80 | key = hax.ones((KeyPos, NumHeads, Hid))
81 |
82 | weights_bias = dot_product_attention_weights(Hid, KeyPos, query, key, bias=bias)
83 | weights_no_bias = dot_product_attention_weights(Hid, KeyPos, query, key)
84 |
85 | assert weights_bias[KeyPos, -1] > weights_bias[KeyPos, -2]
86 | assert weights_bias[KeyPos, -1] > weights_no_bias[KeyPos, -1]
87 |
88 | assert weights_no_bias[KeyPos, -1] == weights_no_bias[KeyPos, -2]
89 |
90 |
91 | @skip_if_no_torch
92 | def test_alibi_attention_compared_to_hf():
93 | import torch
94 | from transformers.models.bloom.modeling_bloom import build_alibi_tensor
95 |
96 | L, H = hax.make_axes(L=1, H=16)
97 |
98 | # Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
99 | torch_tensor = (
100 | build_alibi_tensor(torch.ones(1, L.size), H.size, dtype=torch.float32).numpy().reshape(H.size, L.size)
101 | )
102 |
103 | hax_tensor = np.array(alibi_attention_bias(H, L).array)
104 |
105 | assert np.allclose(torch_tensor, hax_tensor)
106 |
107 |
108 | def test_fcm_attention_mask():
109 | KeyPos, QueryPos, Head = hax.make_axes(KeyPos=20, QueryPos=10, Head=8)
110 |
111 | mask = forgetful_causal_mask(KeyPos, mask_prob=0.6, sample_prob=False, key=PRNGKey(0))
112 |
113 | assert mask.axes == (KeyPos,)
114 | assert mask.array[0].item() == 1
115 |
116 | assert mask.astype(float).sum().item() <= KeyPos.size
117 |
118 | query = hax.arange(QueryPos).broadcast_axis(Head)
119 | key = hax.arange(KeyPos).broadcast_axis(Head)
120 |
121 | weights = dot_product_attention_weights(Head, KeyPos, query, key, mask=mask)
122 |
123 | # check that all masked out values are zero
124 | weights = weights.rearrange((KeyPos, QueryPos))
125 |
126 | assert (weights * (mask == 0)).sum() == 0
127 | assert (weights * (mask == 1)).sum() > 0
128 |
--------------------------------------------------------------------------------
/tests/test_axis.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from haliax.axis import Axis, eliminate_axes, make_axes, overlapping_axes, rearrange_for_partial_order
4 |
5 |
6 | def test_eliminate_axes():
7 | H, W, C = make_axes(H=3, W=4, C=5)
8 |
9 | assert eliminate_axes((H, W), (H,)) == (W,)
10 | assert eliminate_axes((H, W), (W,)) == (H,)
11 | assert eliminate_axes((H, W), (H, W)) == ()
12 |
13 | with pytest.raises(ValueError):
14 | eliminate_axes((H, W), (C,))
15 |
16 | with pytest.raises(ValueError):
17 | eliminate_axes((H, W), (H, C))
18 |
19 | # test string references
20 | assert eliminate_axes((H, W), ("H",)) == (W,)
21 | assert eliminate_axes(("H", W), (H,)) == (W,)
22 | assert eliminate_axes(("H", W), ("H",)) == (W,)
23 | assert eliminate_axes(("H", W), ("H", "W")) == ()
24 |
25 |
26 | def assert_partial_order_respected(partial_order, output):
27 | positions = {el: i for i, el in enumerate(output)}
28 |
29 | last_pos = -1
30 | for el in partial_order:
31 | if el is ...:
32 | # Reset last_pos for flexible positions
33 | last_pos = -1
34 | else:
35 | # Check if the element is in the correct order
36 | assert el in positions, f"{el} is missing in the output"
37 | assert positions[el] > last_pos, f"Partial order not respected for {el}"
38 | last_pos = positions[el]
39 |
40 |
41 | def test_basic_order():
42 | partial_order = ("apple", ..., "banana")
43 | candidates = ("banana", "apple", "cherry")
44 | expected_output = ("apple", "cherry", "banana")
45 | actual_output = rearrange_for_partial_order(partial_order, candidates)
46 | assert actual_output == expected_output
47 | assert_partial_order_respected(partial_order, actual_output)
48 |
49 |
50 | def test_start_with_ellipsis():
51 | partial_order = (..., "apple", "banana")
52 | candidates = ("banana", "apple", "cherry")
53 | actual_output = rearrange_for_partial_order(partial_order, candidates)
54 | assert_partial_order_respected(partial_order, actual_output)
55 | assert actual_output == ("cherry", "apple", "banana")
56 |
57 |
58 | def test_end_with_ellipsis():
59 | partial_order = ("apple", ..., "banana", ...)
60 | candidates = ("banana", "apple", "cherry")
61 | actual_output = rearrange_for_partial_order(partial_order, candidates)
62 | assert_partial_order_respected(partial_order, actual_output)
63 |
64 | # this one could be either but we'll assert the order so we catch changes
65 | assert actual_output == ("apple", "banana", "cherry")
66 |
67 |
68 | def test_full_flexibility():
69 | partial_order = (...,)
70 | candidates = ("banana", "apple", "cherry")
71 | actual_output = rearrange_for_partial_order(partial_order, candidates)
72 | assert_partial_order_respected(partial_order, actual_output)
73 |
74 |
75 | def test_no_flexibility():
76 | partial_order = ("apple", "banana")
77 | candidates = ("banana", "apple", "cherry")
78 | with pytest.raises(ValueError):
79 | rearrange_for_partial_order(partial_order, candidates)
80 |
81 |
82 | def test_final_ellipsis():
83 | partial_order = ("apple", "banana", ...)
84 | candidates = ("banana", "apple", "cherry")
85 | actual_output = rearrange_for_partial_order(partial_order, candidates)
86 | assert_partial_order_respected(partial_order, actual_output)
87 | assert actual_output == ("apple", "banana", "cherry")
88 |
89 |
90 | def test_lots_of_ellipsis():
91 | partial_order = ("apple", ..., "banana", ..., "cherry", ...)
92 | candidates = ("banana", "orange", "cherry", "apple", "grape")
93 | actual_output = rearrange_for_partial_order(partial_order, candidates)
94 | assert_partial_order_respected(partial_order, actual_output)
95 | assert actual_output == ("apple", "banana", "orange", "cherry", "grape")
96 |
97 |
98 | def test_no_ellipsis():
99 | partial_order = ("apple", "banana", "cherry")
100 | candidates = ("banana", "apple", "cherry")
101 | actual_output = rearrange_for_partial_order(partial_order, candidates)
102 | assert_partial_order_respected(partial_order, actual_output)
103 | assert actual_output == ("apple", "banana", "cherry")
104 |
105 |
106 | def test_no_elements():
107 | partial_order = (...,)
108 | candidates = ()
109 | actual_output = rearrange_for_partial_order(partial_order, candidates)
110 | assert_partial_order_respected(partial_order, actual_output)
111 | assert actual_output == ()
112 |
113 |
114 | def test_missing_elements_errors():
115 | partial_order = ("qux", ...)
116 | candidates = ("banana", "apple", "cherry")
117 | with pytest.raises(ValueError):
118 | rearrange_for_partial_order(partial_order, candidates)
119 |
120 |
121 | def test_duplicate_elements_errors():
122 | partial_order: tuple = ("apple", "apple", ...)
123 | candidates = ("banana", "apple", "cherry")
124 | with pytest.raises(ValueError):
125 | rearrange_for_partial_order(partial_order, candidates)
126 |
127 | candidates = ("banana", "apple", "apple")
128 |
129 | with pytest.raises(ValueError):
130 | rearrange_for_partial_order(partial_order, candidates)
131 |
132 | partial_order = ("apple", "banana", "grape", ...)
133 |
134 | with pytest.raises(ValueError):
135 | rearrange_for_partial_order(partial_order, candidates)
136 |
137 |
138 | def test_overlapping_axes_with_different_sizes():
139 | A1 = Axis("A", 10)
140 | A2 = Axis("A", 12)
141 | B = Axis("B", 14)
142 | C = Axis("C", 16)
143 | D = Axis("D", 18)
144 |
145 | ax1 = (A1, B, C)
146 | ax2 = (A2, C, D)
147 |
148 | overlapping_names = overlapping_axes(ax1, ax2) # Should not error
149 | assert overlapping_names == ("A", "C")
150 |
--------------------------------------------------------------------------------
/tests/test_conv.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | import haliax as hax
6 | from haliax.nn.conv import Conv, ConvTranspose
7 |
8 |
9 | def test_conv_basic_equiv_to_eqx():
10 | key = jax.random.PRNGKey(0)
11 | In = hax.Axis("In", 3)
12 | Out = hax.Axis("Out", 4)
13 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, key=key)
14 | eqx_conv = eqx.nn.Conv(2, 3, 4, kernel_size=3, key=key)
15 |
16 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape
17 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1]
18 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight)
19 |
20 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6)))
21 | hax_output = hax_conv(input)
22 | eqx_output = eqx_conv(input.array)
23 |
24 | assert hax_output.array.shape == eqx_output.shape
25 | assert jnp.all(hax_output.array == eqx_output)
26 |
27 | # test batched
28 | input = hax.random.normal(
29 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6))
30 | )
31 | hax_output = hax_conv(input)
32 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array)
33 |
34 | assert hax_output.array.shape == eqx_output.shape
35 | assert jnp.all(hax_output.array == eqx_output)
36 |
37 | # test multibatch
38 | input = hax.random.normal(
39 | jax.random.PRNGKey(1),
40 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)),
41 | )
42 | hax_output = hax_conv(input)
43 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array)
44 |
45 | assert hax_output.array.shape == eqx_output.shape
46 | assert jnp.allclose(hax_output.array, eqx_output)
47 |
48 | input = hax.random.normal(
49 | jax.random.PRNGKey(1),
50 | (
51 | hax.Axis("Batch", 2),
52 | In,
53 | hax.Axis("Height", 5),
54 | hax.Axis("Width", 6),
55 | hax.Axis("Batch2", 3),
56 | ),
57 | )
58 | hax_output = hax_conv(input).rearrange(("Batch", "Batch2", "Out", "Height", "Width"))
59 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(
60 | input.rearrange(("Batch", "Batch2", "In", "Height", "Width")).array
61 | )
62 |
63 | assert hax_output.array.shape == eqx_output.shape
64 | assert jnp.allclose(hax_output.array, eqx_output)
65 |
66 |
67 | def test_conv_grouped_equiv_to_eqx():
68 | key = jax.random.PRNGKey(0)
69 | In = hax.Axis("In", 4)
70 | Out = hax.Axis("Out", 6)
71 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, groups=2, key=key)
72 | eqx_conv = eqx.nn.Conv(2, 4, 6, kernel_size=3, groups=2, key=key)
73 |
74 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape
75 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1]
76 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight)
77 |
78 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6)))
79 | eqx_output = eqx_conv(input.array)
80 | hax_output = hax_conv(input)
81 |
82 | assert hax_output.array.shape == eqx_output.shape
83 | assert jnp.all(hax_output.array == eqx_output)
84 |
85 | # test batched
86 | input = hax.random.normal(
87 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6))
88 | )
89 | hax_output = hax_conv(input)
90 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array)
91 |
92 | assert hax_output.array.shape == eqx_output.shape
93 | assert jnp.all(hax_output.array == eqx_output)
94 |
95 | # test multibatch
96 | input = hax.random.normal(
97 | jax.random.PRNGKey(1),
98 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)),
99 | )
100 | hax_output = hax_conv(input)
101 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array)
102 |
103 | assert hax_output.array.shape == eqx_output.shape
104 | assert jnp.allclose(hax_output.array, eqx_output)
105 |
106 |
107 | def test_conv_weird_order():
108 | key = jax.random.PRNGKey(0)
109 | In = hax.Axis("In", 3)
110 | Out = hax.Axis("Out", 4)
111 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, key=key)
112 | eqx_conv = eqx.nn.Conv(2, 3, 4, kernel_size=3, key=key)
113 |
114 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape
115 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1]
116 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight)
117 |
118 | input = hax.random.normal(
119 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6))
120 | )
121 | hax_output = hax_conv(input)
122 |
123 | # test weird orders
124 | input = input.rearrange(("In", "Height", "Width", "Batch"))
125 | hax_output2 = hax_conv(input).rearrange(("Batch", "Out", "Height", "Width"))
126 |
127 | assert jnp.allclose(hax_output.array, hax_output2.array)
128 |
129 |
130 | def test_conv_transpose_basic_equiv_to_eqx():
131 | key = jax.random.PRNGKey(0)
132 | In = hax.Axis("In", 3)
133 | Out = hax.Axis("Out", 4)
134 | hax_conv = ConvTranspose.init(
135 | ("Height", "Width"), In, Out, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key
136 | )
137 | eqx_conv = eqx.nn.ConvTranspose(2, 3, 4, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key)
138 |
139 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape
140 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1]
141 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight)
142 |
143 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6)))
144 | hax_output = hax_conv(input)
145 | eqx_output = eqx_conv(input.array)
146 |
147 | assert hax_output.array.shape == eqx_output.shape
148 | assert jnp.all(hax_output.array == eqx_output)
149 |
150 | # test batched
151 | input = hax.random.normal(
152 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6))
153 | )
154 | hax_output = hax_conv(input)
155 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array)
156 |
157 | assert hax_output.array.shape == eqx_output.shape
158 | assert jnp.all(hax_output.array == eqx_output)
159 |
160 | # test multibatch
161 | input = hax.random.normal(
162 | jax.random.PRNGKey(1),
163 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)),
164 | )
165 | hax_output = hax_conv(input)
166 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array)
167 |
168 | assert hax_output.array.shape == eqx_output.shape
169 | assert jnp.allclose(hax_output.array, eqx_output)
170 |
171 |
172 | def test_weird_orders_conv_transpose():
173 | key = jax.random.PRNGKey(0)
174 | In = hax.Axis("In", 3)
175 | Out = hax.Axis("Out", 4)
176 | hax_conv = ConvTranspose.init(
177 | ("Height", "Width"), In, Out, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key
178 | )
179 |
180 | input = hax.random.normal(
181 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6))
182 | )
183 | hax_output = hax_conv(input)
184 |
185 | # test weird orders
186 | input = input.rearrange(("In", "Height", "Width", "Batch"))
187 | hax_output2 = hax_conv(input).rearrange(("Batch", "Out", "Height", "Width"))
188 |
189 | assert jnp.allclose(hax_output.array, hax_output2.array)
190 |
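191 | # Informal summary of the equivalences these tests check: Conv.init(("Height", "Width"),
192 | # In, Out, kernel_size=3) is compared against eqx.nn.Conv(2, In.size, Out.size,
193 | # kernel_size=3) (and likewise ConvTranspose against eqx.nn.ConvTranspose). Axes other
194 | # than the spatial and channel axes (e.g. "Batch", "Batch2") behave as batch axes, which
195 | # is why the eqx modules are wrapped in eqx.filter_vmap once per extra axis, and the
196 | # result is independent of the input's axis order (see the "weird order" tests).
197 | 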
--------------------------------------------------------------------------------
/tests/test_debug.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | import jax.numpy as jnp
3 | import pytest
4 |
5 | import haliax as hax
6 |
7 |
8 | def test_diagnose_common_issues_repeated():
9 | class M(eqx.Module):
10 | a: jnp.ndarray = eqx.field()
11 | b: jnp.ndarray = eqx.field()
12 |
13 | def __init__(self):
14 | super().__init__()
15 | self.a = jnp.zeros(1)
16 | self.b = self.a
17 |
18 | try:
19 | hax.debug.diagnose_common_issues(M())
20 | pytest.fail("Should have raised an exception")
21 | except hax.debug.ModuleProblems as e:
22 | assert len(e.reused_arrays) == 1
23 | assert len(e.static_arrays) == 0
24 |
25 |
26 | def test_diagnose_common_issues_repeated_nested():
27 | class M(eqx.Module):
28 | a: jnp.ndarray = eqx.field()
29 | b: jnp.ndarray = eqx.field()
30 |
31 | def __init__(self):
32 | super().__init__()
33 | self.a = jnp.zeros(1)
34 | self.b = self.a
35 |
36 | class N(eqx.Module):
37 | m: M = eqx.field()
38 | c: jnp.ndarray = eqx.field()
39 |
40 | def __init__(self):
41 | super().__init__()
42 | self.m = M()
43 | self.c = self.m.a
44 |
45 | try:
46 | hax.debug.diagnose_common_issues(N())
47 | pytest.fail("Should have raised an exception")
48 | except hax.debug.ModuleProblems as e:
49 | assert len(e.reused_arrays) == 1
50 | assert e.reused_arrays[0][1] == [".m.a", ".m.b", ".c"]
51 | assert len(e.static_arrays) == 0
52 |
53 |
54 | def test_diagnose_common_issues_static():
55 | class M(eqx.Module):
56 | a: jnp.ndarray = eqx.static_field()
57 | b: hax.NamedArray = eqx.static_field()
58 |
59 | def __init__(self):
60 | super().__init__()
61 | self.a = jnp.zeros(1)
62 | self.b = hax.named(jnp.zeros(3), "a")
63 |
64 | try:
65 | hax.debug.diagnose_common_issues(M())
66 | pytest.fail("Should have raised an exception")
67 | except hax.debug.ModuleProblems as e:
68 | assert len(e.reused_arrays) == 0
69 | assert len(e.static_arrays) == 2
70 |
71 |
72 | def test_diagnose_common_issues_static_nested():
73 | class M(eqx.Module):
74 | a: jnp.ndarray = eqx.static_field()
75 | b: hax.NamedArray = eqx.static_field()
76 |
77 | def __init__(self):
78 | super().__init__()
79 | self.a = jnp.zeros(1)
80 | self.b = hax.named(jnp.zeros(3), "a")
81 |
82 | class N(eqx.Module):
83 | m: M = eqx.field()
84 | c: jnp.ndarray = eqx.field()
85 |
86 | def __init__(self):
87 | super().__init__()
88 | self.m = M()
89 | self.c = self.m.a
90 |
91 | try:
92 | hax.debug.diagnose_common_issues(N())
93 | pytest.fail("Should have raised an exception")
94 | except hax.debug.ModuleProblems as e:
95 | assert len(e.reused_arrays) == 0
96 | assert len(e.static_arrays) == 2
97 |
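98 | # These tests rely on diagnose_common_issues raising ModuleProblems, whose
99 | # `reused_arrays` lists arrays reachable through more than one field (together with the
100 | # offending paths, e.g. [".m.a", ".m.b", ".c"]) and whose `static_arrays` lists arrays
101 | # stored in static fields.
102 | 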
--------------------------------------------------------------------------------
/tests/test_dot.py:
--------------------------------------------------------------------------------
1 | # these tests check whether the rearrange logic works for partial orders
2 | import pytest
3 | from jax import numpy as jnp
4 |
5 | import haliax as hax
6 | from haliax import Axis, NamedArray
7 |
8 |
9 | def test_dot():
10 | Height = Axis("Height", 2)
11 | Width = Axis("Width", 3)
12 | Depth = Axis("Depth", 4)
13 |
14 | m1 = hax.ones((Height, Width, Depth))
15 | m2 = hax.ones((Depth, Width, Height))
16 |
17 | assert jnp.all(jnp.equal(hax.dot(m1, m2, axis=Height).array, jnp.einsum("ijk,kji->jk", m1.array, m2.array)))
18 | assert jnp.all(
19 | jnp.equal(
20 | hax.dot(m1, m2, axis=(Height, Width)).array,
21 | jnp.einsum("ijk,kji->k", m1.array, m2.array),
22 | )
23 | )
24 | assert jnp.all(
25 | jnp.equal(
26 | hax.dot(m1, m2, axis=(Height, Width, Depth)).array,
27 | jnp.einsum("ijk,kji->", m1.array, m2.array),
28 | )
29 | )
30 |
31 | # reduce to scalar
32 | assert jnp.all(
33 | jnp.equal(
34 | hax.dot(m1, m2, axis=None),
35 | jnp.einsum("ijk,kji->", m1.array, m2.array),
36 | )
37 | )
38 |
39 |
40 | def test_dot_string_selection():
41 | Height = Axis("Height", 2)
42 | Width = Axis("Width", 3)
43 | Depth = Axis("Depth", 4)
44 |
45 | m1 = hax.ones((Height, Width, Depth))
46 | m2 = hax.ones((Depth, Width, Height))
47 |
48 | assert jnp.all(jnp.equal(hax.dot(m1, m2, axis="Height").array, jnp.einsum("ijk,kji->jk", m1.array, m2.array)))
49 | assert jnp.all(
50 | jnp.equal(
51 | hax.dot(m1, m2, axis=("Height", "Width")).array,
52 | jnp.einsum("ijk,kji->k", m1.array, m2.array),
53 | )
54 | )
55 | assert jnp.all(
56 | jnp.equal(
57 | hax.dot(m1, m2, axis=("Height", "Width", "Depth")).array,
58 | jnp.einsum("ijk,kji->", m1.array, m2.array),
59 | )
60 | )
61 |
62 |
63 | def test_dot_errors_if_different_sized_axes():
64 | Height = Axis("Height", 2)
65 | Width = Axis("Width", 3)
66 | Depth = Axis("Depth", 4)
67 |
68 | H2 = Axis("Height", 4)
69 |
70 | m1 = hax.ones((Height, Width, Depth))
71 | m2 = hax.ones((Depth, Width, H2))
72 |
73 | with pytest.raises(ValueError):
74 | hax.dot(m1, m2, axis="Height")
75 |
76 |
77 | def test_dot_with_output_axes():
78 | Height = Axis("Height", 2)
79 | Width = Axis("Width", 3)
80 | Depth = Axis("Depth", 4)
81 |
82 | m1 = hax.ones((Height, Width, Depth))
83 | m2 = hax.ones((Depth, Width, Height))
84 |
85 | assert jnp.all(
86 | jnp.equal(
87 | hax.dot(m1, m2, axis=Height, out_axes=(Width, ...)).array,
88 | jnp.einsum("ijk,kji->jk", m1.array, m2.array),
89 | )
90 | )
91 |
92 | assert jnp.all(
93 | jnp.equal(
94 | hax.dot(m1, m2, axis=Height, out_axes=(Depth, ...)).array,
95 | jnp.einsum("ijk,kji->kj", m1.array, m2.array),
96 | )
97 | )
98 |
99 | assert jnp.all(
100 | jnp.equal(
101 | hax.dot(m1, m2, axis=Height, out_axes=(Depth, Width)).array,
102 | jnp.einsum("ijk,kji->kj", m1.array, m2.array),
103 | )
104 | )
105 |
106 | assert jnp.all(
107 | jnp.equal(
108 | hax.dot(m1, m2, axis=(), out_axes=(Depth, Height, ...)).array,
109 | jnp.einsum("ijk,kji->kij", m1.array, m2.array),
110 | )
111 | )
112 |
113 | assert jnp.all(
114 | jnp.equal(
115 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, Height)).array,
116 | jnp.einsum("ijk,kji->jki", m1.array, m2.array),
117 | )
118 | )
119 |
120 | assert jnp.all(
121 | jnp.equal(
122 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, ...)).array,
123 | jnp.einsum("ijk,kji->ijk", m1.array, m2.array),
124 | )
125 | )
126 |
127 | assert jnp.all(
128 | jnp.equal(
129 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, Height, ...)).array,
130 | jnp.einsum("ijk,kji->kij", m1.array, m2.array),
131 | )
132 | )
133 |
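134 | # Informal reading of the out_axes cases above: out_axes is a partial order on the
135 | # output axes. Listed axes must appear in the given relative order, and `...` stands for
136 | # the remaining axes, so out_axes=(Depth, ...) pins Depth to the front ("->kj") while
137 | # (..., Depth, Height) pins Depth then Height to the end ("->jki"). axis=() contracts
138 | # nothing, keeping every axis in the output.
139 | 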
--------------------------------------------------------------------------------
/tests/test_fp8.py:
--------------------------------------------------------------------------------
1 | import chex
2 | import equinox as eqx
3 | import jax.numpy as jnp
4 | import jax.random as jrandom
5 | import jax.tree_util
6 | import numpy as np
7 | from chex import assert_trees_all_close
8 |
9 | import haliax as hax
10 | from haliax._src.fp8 import compute_scale
11 | from haliax.nn import Linear
12 | from haliax.quantization import (
13 | Fp8DotGeneralOp,
14 | QuantizationConfig,
15 | apply_updates,
16 | partition_for_grad_overwrite,
17 | quantize_linear_layers,
18 | )
19 |
20 |
21 | def test_fp8_is_reasonable():
22 | In = hax.Axis("In", 8)
23 | Out = hax.Axis("Out", 8)
24 | linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), init_scale=0.1)
25 |
26 | fp8_linear = Linear.init(
27 | In, Out, key=jrandom.PRNGKey(0), dot_general=hax.quantization.Fp8DotGeneralOp.init(), init_scale=0.1
28 | )
29 |
30 | input = hax.random.normal(jrandom.PRNGKey(3), In)
31 | output = linear(input)
32 | fp8_output = fp8_linear(input)
33 |
34 | assert output.shape == fp8_output.shape
35 | assert output.dtype == fp8_output.dtype
36 |
37 | assert_trees_all_close(output.array, fp8_output.array, atol=2e-2, rtol=5e-2)
38 |
39 |
40 | # https://github.com/google/flax/blob/6f2b08e024c2fd2f8cec42a6c82408cb35412319/tests/linen/linen_test.py#L1222
41 | def test_fp_loop():
42 | key, init_key, random_key = jrandom.split(jrandom.PRNGKey(seed=123), 3)
43 | Batch = hax.Axis("Batch", 16)
44 | In = hax.Axis("In", 16)
45 | Out = hax.Axis("Out", 32)
46 | linear = Linear.init(In, Out, key=init_key, dot_general=Fp8DotGeneralOp.init())
47 |
48 | def _roll_and_update(amax_h, update):
49 | return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update)
50 |
51 | lr = 1e-3
52 |
53 | def apply_gradients(model, grads):
54 | overwrites, grads = partition_for_grad_overwrite(grads)
55 | updates = jax.tree_util.tree_map(lambda g: -lr * g, grads)
56 | model = apply_updates(model, updates, overwrites)
57 | return model
58 |
59 | def _train_step(model, x, dy):
60 | def loss_fn(lin):
61 | y = lin(x)
62 | loss = y * dy.astype(y.dtype)
63 | return hax.sum(loss).scalar()
64 |
65 | grad_fn = eqx.filter_grad(loss_fn)
66 | grads = grad_fn(model)
67 | return apply_gradients(model, grads)
68 |
69 | train_fn = eqx.filter_jit(_train_step)
70 |
71 | scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,))
72 | scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,))
73 | scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,))
74 | e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
75 | e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32)
76 |
77 | for _ in range(5):
78 | key, random_key = jrandom.split(key, 2)
79 | # x = jrandom.normal(random_key, (16, 16), dtype=jnp.float32)
80 | # g = jrandom.normal(random_key, (16, 32), dtype=jnp.float32)
81 | x = hax.random.normal(
82 | random_key,
83 | (
84 | Batch,
85 | In,
86 | ),
87 | )
88 | g = hax.random.normal(
89 | random_key,
90 | (
91 | Batch,
92 | Out,
93 | ),
94 | )
95 |
96 | # Manually compute the expected amax history and scaling factors.
97 | amax_from_history_x = jnp.max(amax_history_x, axis=0)
98 | amax_from_history_k = jnp.max(amax_history_k, axis=0)
99 | amax_from_history_g = jnp.max(amax_history_g, axis=0)
100 | scale_x = compute_scale(amax_from_history_x, scale_x, e4m3_max)
101 | scale_k = compute_scale(amax_from_history_k, scale_k, e4m3_max)
102 | scale_g = compute_scale(amax_from_history_g, scale_g, e5m2_max)
103 | amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x.array)))
104 | amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(linear.weight.array)))
105 | amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g.array)))
106 |
107 | linear = train_fn(linear, x, g)
108 |
109 | rtol, atol = 0.001, 0.001
110 | np.testing.assert_allclose(
111 | linear.dot_general.input_amax_history, # type: ignore
112 | amax_history_x,
113 | rtol=rtol,
114 | atol=atol,
115 | )
116 | np.testing.assert_allclose(
117 | linear.dot_general.kernel_amax_history, # type: ignore
118 | amax_history_k,
119 | rtol=rtol,
120 | atol=atol,
121 | )
122 | np.testing.assert_allclose(
123 | linear.dot_general.output_grad_amax_history, # type: ignore
124 | amax_history_g,
125 | rtol=rtol,
126 | atol=atol,
127 | )
128 |
129 | np.testing.assert_allclose(linear.dot_general.input_scale, scale_x, rtol=rtol, atol=atol) # type: ignore
130 | np.testing.assert_allclose(linear.dot_general.kernel_scale, scale_k, rtol=rtol, atol=atol) # type: ignore
131 | np.testing.assert_allclose(linear.dot_general.output_grad_scale, scale_g, rtol=rtol, atol=atol) # type: ignore
132 |
133 |
134 | def test_layer_splicing():
135 | key, init_key, random_key = jrandom.split(jrandom.PRNGKey(seed=123), 3)
136 | Input = hax.Axis("Input", 16)
137 | Hidden = hax.Axis("Hidden", 64)
138 | Output = hax.Axis("Output", 32)
139 | mlp = hax.nn.MLP.init(Input, Output, Hidden, 3, key=init_key, init_scale=0.1)
140 |
141 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(fp8=True))
142 | for layer in mlp_q.layers:
143 | assert isinstance(layer.dot_general, Fp8DotGeneralOp)
144 |
145 | input = hax.random.normal(jrandom.PRNGKey(0), Input) * 10 # 10 so we don't underflow
146 | output = mlp(input)
147 | output_q = mlp_q(input)
148 | chex.assert_trees_all_close(output.array, output_q.array, atol=1e-3, rtol=1e-3)
149 | assert not jnp.allclose(output_q.array, 0) # don't want them to all underflow
150 |
151 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(targets="layers.0", fp8=True))
152 | for i, layer in enumerate(mlp_q.layers):
153 | if i == 0:
154 | assert isinstance(layer.dot_general, Fp8DotGeneralOp)
155 | else:
156 | assert not isinstance(layer.dot_general, Fp8DotGeneralOp)
157 |
158 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(targets=["0", "1"], fp8=True))
159 | for i, layer in enumerate(mlp_q.layers):
160 | if i < 2:
161 | assert isinstance(layer.dot_general, Fp8DotGeneralOp)
162 | else:
163 | assert not isinstance(layer.dot_general, Fp8DotGeneralOp)
164 |
165 |
166 | def test_fp8ize_stacking():
167 | class Block(eqx.Module):
168 | up_proj: hax.nn.Linear
169 | down_proj: hax.nn.Linear
170 |
171 | @staticmethod
172 | def init(In, Out, key):
173 | up_proj = hax.nn.Linear.init(In, Out, key=key)
174 | down_proj = hax.nn.Linear.init(Out, In, key=key)
175 | return Block(up_proj, down_proj)
176 |
177 | def __call__(self, x):
178 | return self.down_proj(self.up_proj(x))
179 |
180 | Layer = hax.Axis("Layer", 3)
181 |
182 | class Tformer(eqx.Module):
183 | blocks: hax.nn.Stacked[Block]
184 |
185 | @staticmethod
186 | def init(In, Out, key):
187 | blocks = hax.nn.Stacked.init(Layer, Block)(In, Out, key=jax.random.split(key, Layer.size))
188 | return Tformer(blocks)
189 |
190 | def __call__(self, x):
191 | return self.blocks.fold(x)
192 |
193 | In = hax.Axis("In", 16)
194 | Out = hax.Axis("Out", 32)
195 | tformer = Tformer.init(In, Out, key=jrandom.PRNGKey(0))
196 | tformer_q = quantize_linear_layers(tformer, QuantizationConfig(fp8=True))
197 |
198 | # want to be sure this vmaps the dot_general to the right places
199 | dg = tformer_q.blocks.stacked.up_proj.dot_general
200 | assert isinstance(dg, Fp8DotGeneralOp)
201 | assert dg.input_scale.shape == (Layer.size, 1)
202 | assert dg.input_amax_history.shape == (Layer.size, 1024)
203 | dg = tformer_q.blocks.stacked.down_proj.dot_general
204 | assert isinstance(dg, Fp8DotGeneralOp)
205 |
206 |     # quantize just the up_proj
207 | tformer_q = quantize_linear_layers(tformer, QuantizationConfig(targets=["up_proj"], fp8=True))
208 | dg = tformer_q.blocks.stacked.up_proj.dot_general
209 | assert isinstance(dg, Fp8DotGeneralOp)
210 | dg = tformer_q.blocks.stacked.down_proj.dot_general
211 | assert not isinstance(dg, Fp8DotGeneralOp)
212 |
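213 | # For reference, the delayed-scaling bookkeeping that test_fp_loop reproduces by hand
214 | # is, per tensor (input, kernel, and output gradient):
215 | #
216 | #     amax  = max(amax_history)                    # max over the rolling window
217 | #     scale = compute_scale(amax, scale, fp_max)   # e4m3 max for x/kernel, e5m2 max for grads
218 | #     amax_history = roll(amax_history, -1).at[0].set(max(abs(tensor)))
219 | #
220 | # with a 1024-step history, matching the Fp8DotGeneralOp state asserted against above.
221 | 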
--------------------------------------------------------------------------------
/tests/test_int8.py:
--------------------------------------------------------------------------------
1 | import jax.random as jrandom
2 | from chex import assert_trees_all_close
3 |
4 | import haliax as hax
5 | from haliax.nn import Linear
6 | from haliax.quantization import Int8DotGeneralOp
7 |
8 |
9 | def test_int8_is_reasonable():
10 | In = hax.Axis("In", 8)
11 | Out = hax.Axis("Out", 8)
12 | linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), init_scale=0.1)
13 |
14 | int8_linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), dot_general=Int8DotGeneralOp.init(), init_scale=0.1)
15 |
16 | input = hax.random.normal(jrandom.PRNGKey(3), In)
17 | output = linear(input)
18 | int8_output = int8_linear(input)
19 |
20 | assert output.shape == int8_output.shape
21 | assert output.dtype == int8_output.dtype
22 |
23 | assert_trees_all_close(output.array, int8_output.array, atol=1e-2, rtol=5e-2)
24 |
--------------------------------------------------------------------------------
/tests/test_nn.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import equinox as eqx
4 | import jax.nn
5 | import jax.random as jrandom
6 | import pytest
7 | from jax import numpy as jnp
8 |
9 | import haliax as hax
10 | from haliax import Axis, NamedArray
11 |
12 |
13 | def _compare_eqx_and_haliax(hax_mod: eqx.Module, eqx_mod: eqx.Module):
14 | def f(x: NamedArray, *args, **kwargs):
15 | unnamed_x = x.array
16 | hax_out = hax_mod(x, *args, **kwargs) # type: ignore
17 | eqx_out = eqx_mod(unnamed_x, *args, **kwargs) # type: ignore
18 |
19 | assert jnp.allclose(hax_out.array, eqx_out)
20 | return hax_out
21 |
22 | return f
23 |
24 |
25 | def test_layer_norm():
26 | H = Axis("H", 10)
27 | hax_ln = hax.nn.LayerNorm.init(H)
28 | eqx_ln = eqx.nn.LayerNorm(shape=(H.size,))
29 |
30 | f = _compare_eqx_and_haliax(hax_ln, eqx_ln)
31 | out = f(hax.random.uniform(jrandom.PRNGKey(0), (H,)))
32 |
33 | assert out.axes == (H,)
34 |
35 |
36 | def test_dropout():
37 | H = Axis("H", 10)
38 | key = jrandom.PRNGKey(0)
39 | hax_dropout = hax.nn.Dropout(0.5)
40 | eqx_dropout = eqx.nn.Dropout(0.5)
41 |
42 | f = _compare_eqx_and_haliax(hax_dropout, eqx_dropout)
43 | out = f(hax.random.uniform(jrandom.PRNGKey(0), (H,)), key=key, inference=False)
44 |
45 | assert out.axes == (H,)
46 |
47 |
48 | def test_one_hot():
49 | i, c = hax.make_axes(i=3, c=3)
50 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([0, 1, 2]), (i,)), c)
51 | expected = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
52 |
53 | assert actual.axes == (i, c)
54 | assert jnp.all(jnp.isclose(actual.array, expected))
55 |
56 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([1, 2, 0]), (i,)), c)
57 | expected = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
58 | assert actual.axes == (i, c)
59 | assert jnp.all(jnp.isclose(actual.array, expected))
60 |
61 |
62 | def test_one_hot_out_of_bound():
63 | i, c = hax.make_axes(i=2, c=3)
64 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([-1, 3]), (i,)), c)
65 | expected = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
66 | assert jnp.all(jnp.isclose(actual.array, expected))
67 |
68 |
69 | def test_standardize():
70 | b, c = hax.make_axes(b=2, c=3)
71 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([0, 1, 2]), (c,)), c)
72 | expected = jax.nn.standardize(jnp.array([0, 1, 2]), axis=0)
73 |
74 | assert actual.axes == (c,)
75 | assert jnp.all(jnp.isclose(actual.array, expected))
76 |
77 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), c)
78 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=1)
79 |
80 | assert actual.axes == (b, c)
81 | assert jnp.all(jnp.isclose(actual.array, expected))
82 |
83 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), b)
84 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=0)
85 |
86 | assert actual.axes == (b, c)
87 | assert jnp.all(jnp.isclose(actual.array, expected))
88 |
89 | # test passing in where
90 | mask = hax.NamedArray(jnp.array([True, False, True]), (c,))
91 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), b, where=mask)
92 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=0, where=mask.array)
93 |
94 | assert actual.axes == (b, c)
95 | assert jnp.all(jnp.isclose(actual.array, expected) | jnp.isnan(expected))
96 |
97 | # now mean/variance
98 | input = hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c))
99 | mean = hax.mean(input, c)
100 | variance = hax.var(input, c)
101 | actual = hax.nn.standardize(input, c, mean=mean, variance=variance)
102 | expected = jax.nn.standardize(
103 | input.array, axis=1, mean=mean.array.reshape(-1, 1), variance=variance.array.reshape(-1, 1)
104 | )
105 |
106 | assert actual.axes == (b, c)
107 | assert jnp.all(jnp.isclose(actual.array, expected))
108 |
109 |
110 | @pytest.mark.parametrize("depth", [0, 1, 2, 3, 4, 5])
111 | def test_mlp(depth):
112 | key = jrandom.PRNGKey(0)
113 | H, C, W, E = hax.make_axes(H=10, C=12, W=14, E=16)
114 |
115 | hax_mlp = hax.nn.MLP.init((H, C, W), E, width=8, depth=depth, key=key)
116 | x = hax.random.uniform(key, (H, C, W))
117 | assert hax_mlp(x).axes == (E,)
118 |
119 | hax_mlp = hax.nn.MLP.init((H, W), E, width=8, depth=depth, key=key)
120 | assert hax_mlp(x).axes == (C, E)
121 |
122 | # with a named width
123 | M = Axis("M", 18)
124 | hax_mlp = hax.nn.MLP.init((H, W), E, width=M, depth=depth, key=key)
125 | assert hax_mlp(x).axes == (C, E)
126 |
127 | # ensure that we actually use the name for the named width
128 | if depth > 0:
129 | assert hax_mlp.layers[0].Out == M
130 |
131 | if depth % 2 == 0:
132 | assert hax_mlp.layers[-1].In == M.alias("M2")
133 | else:
134 | assert hax_mlp.layers[-1].In == M
135 |
136 | for i in range(1, depth):
137 | if i % 2 == 0:
138 | assert hax_mlp.layers[i].In == M.alias("M2")
139 | assert hax_mlp.layers[i].Out == M
140 | else:
141 | assert hax_mlp.layers[i].In == M
142 | assert hax_mlp.layers[i].Out == M.alias("M2")
143 |
144 |
145 | def test_linear_has_no_function_leaves_by_default():
146 | H, C, W, E = hax.make_axes(H=10, C=12, W=14, E=16)
147 |
148 | hax_linear = hax.nn.Linear.init((H, C, W), E, key=jrandom.PRNGKey(0))
149 | assert all(not isinstance(v, Callable) for v in jax.tree_util.tree_leaves(hax_linear)) # type: ignore
150 |
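151 | # Note on the MLP parity checks above: when the width is a named Axis (M), MLP.init
152 | # alternates the hidden layers between M and its alias M.alias("M2") so that no Linear
153 | # ends up with the same axis name for both In and Out; the even/odd asserts pin down
154 | # that alternation.
155 | 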
--------------------------------------------------------------------------------
/tests/test_parsing.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from haliax._src.parsing import parse_einsum, parse_rearrangement
4 |
5 |
6 | def _simplify_captures(expr):
7 | def simplify_capture(capture):
8 | if capture == Ellipsis:
9 | return Ellipsis
10 | elif (capture.binding == capture.axes[0] or capture.binding is None) and len(capture.axes) == 1:
11 | return capture.axes[0]
12 | elif capture.binding is None:
13 | return capture.axes
14 | else:
15 | return {capture.binding: capture.axes}
16 |
17 | return [simplify_capture(capture) for capture in expr.captures]
18 |
19 |
20 | def test_parse_rearrangement_simple():
21 | lhs, rhs = parse_rearrangement("a b c d -> b c a d")
22 | assert lhs.is_ordered
23 | assert _simplify_captures(lhs) == ["a", "b", "c", "d"]
24 | assert rhs.is_ordered
25 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"]
26 |
27 | lhs, rhs = parse_rearrangement("a ... c d -> b c a d")
28 | assert lhs.is_ordered
29 | assert _simplify_captures(lhs) == ["a", Ellipsis, "c", "d"]
30 | assert rhs.is_ordered
31 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"]
32 |
33 | # longer identifiers
34 | lhs, rhs = parse_rearrangement("a_longer b123 c d -> b123 c a_longer d")
35 | assert lhs.is_ordered
36 | assert _simplify_captures(lhs) == ["a_longer", "b123", "c", "d"]
37 | assert rhs.is_ordered
38 | assert _simplify_captures(rhs) == ["b123", "c", "a_longer", "d"]
39 |
40 |
41 | def test_parse_einsum_ordered():
42 | # We could support this syntax for dot with something like:
43 | #
44 | # support normal einops syntax, including short name-capture: hax.dot("... c h w, h w d -> ... c d", a, b)
45 | # hax.dot("{h, w} -> ", a, b) means "contract h and w", analogous to hax.dot(a, b, axis=("h", "w"))
46 | # hax.dot("{h, w} -> ... channel embed", a, b) means "contract h and w and ensure that the result ends with [channel, embed]" (by transposing/einsum)
47 | # hax.dot(" -> batch channel embed", a, b) could mean "contract all but the named dims". Not entirely sure how I feel about that one, but used situationally it's probably ok
48 |
49 | lhses, rhs = parse_einsum("a b c d, b c e f -> a d e f")
50 | assert lhses is not None
51 | assert len(lhses) == 2
52 | assert all(lhs.is_ordered for lhs in lhses)
53 | lhs0_captures = _simplify_captures(lhses[0])
54 | lhs1_captures = _simplify_captures(lhses[1])
55 | assert lhs0_captures == ["a", "b", "c", "d"]
56 | assert lhs1_captures == ["b", "c", "e", "f"]
57 | assert rhs.is_ordered
58 | assert _simplify_captures(rhs) == ["a", "d", "e", "f"]
59 |
60 | lhses, rhs = parse_einsum("... c h w, h w d -> ... c d")
61 | assert lhses is not None
62 | assert len(lhses) == 2
63 | assert all(lhs.is_ordered for lhs in lhses)
64 | lhs0_captures = _simplify_captures(lhses[0])
65 | lhs1_captures = _simplify_captures(lhses[1])
66 | assert lhs0_captures == [..., "c", "h", "w"]
67 | assert lhs1_captures == ["h", "w", "d"]
68 | assert rhs.is_ordered
69 | assert _simplify_captures(rhs) == [..., "c", "d"]
70 |
71 | lhses, rhs = parse_einsum("{...} -> batch channel embed")
72 | assert lhses is not None
73 | assert len(lhses) == 1
74 | assert not lhses[0].is_ordered
75 | assert _simplify_captures(lhses[0]) == [...]
76 | assert rhs.is_ordered
77 | assert _simplify_captures(rhs) == ["batch", "channel", "embed"]
78 |
79 | # just lhs
80 | lhses, rhs = parse_einsum("batch channel embed -> ")
81 | assert lhses is not None
82 | assert len(lhses) == 1
83 | assert lhses[0].is_ordered
84 | assert _simplify_captures(lhses[0]) == ["batch", "channel", "embed"]
85 | assert rhs.is_ordered
86 | assert _simplify_captures(rhs) == []
87 |
88 | # lhs x 2, empty rhs
89 | lhses, rhs = parse_einsum("batch channel embed, batch channel embed ->")
90 | assert lhses is not None
91 | assert len(lhses) == 2
92 | assert all(lhs.is_ordered for lhs in lhses)
93 | assert _simplify_captures(lhses[0]) == ["batch", "channel", "embed"]
94 | assert _simplify_captures(lhses[1]) == ["batch", "channel", "embed"]
95 | assert rhs.is_ordered
96 | assert _simplify_captures(rhs) == []
97 |
98 |
99 | def test_parse_einsum_unordered():
100 | lhses, rhs = parse_einsum("{a, b} -> ")
101 | assert lhses is not None
102 | assert len(lhses) == 1
103 | assert not lhses[0].is_ordered
104 | assert _simplify_captures(lhses[0]) == ["a", "b"]
105 | assert rhs.is_ordered
106 | assert _simplify_captures(rhs) == []
107 |
108 | lhses, rhs = parse_einsum("{...} -> ")
109 | assert lhses is not None
110 | assert len(lhses) == 1
111 | assert not lhses[0].is_ordered
112 | assert _simplify_captures(lhses[0]) == [...]
113 | assert rhs.is_ordered
114 | assert _simplify_captures(rhs) == []
115 |
116 | lhses, rhs = parse_einsum("{h, w} -> ... channel embed")
117 | assert lhses is not None
118 | assert len(lhses) == 1
119 | assert not lhses[0].is_ordered
120 | assert _simplify_captures(lhses[0]) == ["h", "w"]
121 | assert rhs.is_ordered
122 | assert _simplify_captures(rhs) == [..., "channel", "embed"]
123 |
124 |
125 | def test_parse_paren_groups():
126 | lhs, rhs = parse_rearrangement("a (b c) d -> b c a d")
127 | assert lhs.is_ordered
128 | assert _simplify_captures(lhs) == ["a", ("b", "c"), "d"]
129 | assert rhs.is_ordered
130 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"]
131 |
132 | lhs, rhs = parse_rearrangement("a (b: c) (d: e f) -> b c a d")
133 | assert lhs.is_ordered
134 | assert _simplify_captures(lhs) == ["a", {"b": ("c",)}, {"d": ("e", "f")}]
135 |
136 |
137 | def test_parse_unordered():
138 | lhs, rhs = parse_rearrangement("{a b c d} -> {b c a d}")
139 | assert not lhs.is_ordered
140 | assert _simplify_captures(lhs) == ["a", "b", "c", "d"]
141 | assert not rhs.is_ordered
142 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"]
143 |
144 | lhs, rhs = parse_rearrangement("{(c: a b) d e} -> (q: a d e) b")
145 | assert not lhs.is_ordered
146 | assert _simplify_captures(lhs) == [{"c": ("a", "b")}, "d", "e"]
147 | assert rhs.is_ordered
148 | assert _simplify_captures(rhs) == [{"q": ("a", "d", "e")}, "b"]
149 |
150 |
151 | def test_parse_quoted_identifiers():
152 | lhs, rhs = parse_rearrangement("a \"b c\" d -> 'b c' a d")
153 | assert lhs.is_ordered
154 | assert _simplify_captures(lhs) == ["a", "b c", "d"]
155 | assert rhs.is_ordered
156 | assert _simplify_captures(rhs) == ["b c", "a", "d"]
157 |
158 | lhs, rhs = parse_rearrangement("{a \"b c\" (d: 'hello')} -> b c a d")
159 | assert not lhs.is_ordered
160 | assert _simplify_captures(lhs) == ["a", "b c", {"d": ("hello",)}]
161 | assert rhs.is_ordered
162 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"]
163 |
164 |
165 | def test_parse_errors():
166 | with pytest.raises(ValueError, match="Unexpected end of string"):
167 | parse_rearrangement("a b")
168 |
169 | with pytest.raises(ValueError, match="Expected }"):
170 | parse_rearrangement("{ a -> c")
171 |
172 | with pytest.raises(ValueError, match="Unexpected }"):
173 | parse_rearrangement("a } -> c")
174 |
175 | with pytest.raises(ValueError, match="Unexpected }"):
176 | parse_rearrangement("(a: b } -> c")
177 |
178 | with pytest.raises(ValueError, match="Unexpected character"):
179 | parse_rearrangement("(a b ! -> c d e")
180 |
181 | with pytest.raises(ValueError, match="Identifier expected"):
182 | parse_rearrangement("a b ! -> c d e")
183 |
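184 | # Grammar exercised by these tests, informally: an expression is "lhs -> rhs" (einsum
185 | # allows several comma-separated lhs terms); bare space-separated names are ordered,
186 | # "{...}" makes a side unordered, "(b c)" groups axes, "(name: a b)" binds a group to a
187 | # name, quoted identifiers may contain spaces, and "..." captures the remaining axes.
188 | 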
--------------------------------------------------------------------------------
/tests/test_pool.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | import haliax
6 | import haliax as hax
7 | from haliax.nn.pool import max_pool, mean_pool
8 |
9 |
10 | # Tests largely cribbed from equinox
11 |
12 |
13 | def test_maxpool1d():
14 | D = hax.Axis("D", 14)
15 | x = hax.arange(D)
16 | output = max_pool((D.resize(2),), x, stride=(3,))
17 | answer = jnp.array([1, 4, 7, 10, 13], dtype=jnp.int32)
18 |
19 | assert jnp.all(output.array == answer)
20 |
21 | answer = jnp.array([2, 5, 8, 11])
22 | output = max_pool(D.resize(3), x, stride=(3,), padding=0)
23 | assert jnp.all(output.array == answer)
24 |
25 | # max_pool = eqx.nn.MaxPool1d(kernel_size=3, stride=3, padding=0, use_ceil=True)
26 | answer = jnp.array([2, 5, 8, 11, 13])
27 | output = max_pool(D.resize(3), x, stride=(3,), padding=0, use_ceil=True)
28 | assert jnp.all(output.array == answer)
29 |
30 | # test batch axes
31 | B = hax.Axis("B", 2)
32 | x = x.rearrange("(B D) -> B D", B=B)
33 | output = max_pool(D.resize(2), x, stride=(3,), padding="VALID")
34 | answer = jnp.array([[1, 4], [8, 11]])
35 | assert jnp.all(output.array == answer)
36 |
37 | output = max_pool(D.resize(2), x, stride=(3,), use_ceil=True)
38 | answer = jnp.array([[1, 4, 6], [8, 11, 13]])
39 | assert jnp.all(output.array == answer)
40 |
41 | output = max_pool(D.resize(3), x, stride=(3,), padding=0)
42 | answer = jnp.array([[2, 5], [9, 12]])
43 | assert jnp.all(output.array == answer)
44 |
45 | output = max_pool(D.resize(3), x, stride=(3,), padding=0, use_ceil=True)
46 | answer = jnp.array([[2, 5, 6], [9, 12, 13]])
47 | assert jnp.all(output.array == answer)
48 |
49 | output = max_pool(D.resize(2), x, stride=(3,), padding="SAME")
50 | answer = jnp.array([[1, 4, 6], [8, 11, 13]])
51 | assert jnp.all(output.array == answer)
52 |
53 |
54 | def test_maxpool2d():
55 | _x = jnp.arange(36).reshape(6, 6)
56 | x = hax.named(_x, ("H", "W"))
57 |
58 | # max_pool = eqx.nn.MaxPool2d(2, (3, 2))
59 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2)), x, stride=(3, 2))
60 | answer = jnp.array([[7, 9, 11], [25, 27, 29]])
61 |
62 | assert jnp.all(output.array == answer)
63 |
64 | output = max_pool((hax.Axis("H", 3), hax.Axis("W", 3)), x, stride=2, padding=1)
65 | answer = jnp.array([[7, 9, 11], [19, 21, 23], [31, 33, 35]])
66 |
67 | assert jnp.all(output.array == answer)
68 |
69 | # test batch axes
70 | B = hax.Axis("B", 2)
71 | x = haliax.stack(B, [x, x])
72 |
73 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2)), x, stride=(3, 2))
74 | answer = jnp.array([[[7, 9, 11], [25, 27, 29]], [[7, 9, 11], [25, 27, 29]]])
75 |
76 | assert jnp.all(output.array == answer)
77 |
78 |
79 | def test_maxpool3d():
80 | _x = jnp.arange(64).reshape(4, 4, 4)
81 | x = hax.named(_x, ("H", "W", "D"))
82 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2), hax.Axis("D", 2)), x, stride=(3, 2, 1))
83 |
84 | answer = jnp.array([[[21, 22, 23], [29, 30, 31]]])
85 |
86 | assert jnp.all(output.array == answer)
87 |
88 | answer = jnp.asarray(
89 | [
90 | [[37, 39, 39], [45, 47, 47], [45, 47, 47]],
91 | [[53, 55, 55], [61, 63, 63], [61, 63, 63]],
92 | ]
93 | )
94 | output = max_pool(
95 | (hax.Axis("H", 3), hax.Axis("W", 3), hax.Axis("D", 3)),
96 | x,
97 | stride=2,
98 | padding=((0, 1), (1, 1), (1, 1)),
99 | use_ceil=True,
100 | )
101 | assert jnp.all(output.array == answer)
102 |
103 |
104 | def test_mean_pool_1d():
105 | D = hax.Axis("D", 14)
106 | x = hax.arange(D)
107 | output = mean_pool((D.resize(2),), x, stride=(3,))
108 | answer = jnp.array([0.5, 3.5, 6.5, 9.5, 12.5])
109 |
110 | assert jnp.all(output.array == answer)
111 |
112 | # no pad
113 | output = mean_pool(D.resize(3), x, stride=(3,), padding=0)
114 | answer = jnp.array([1, 4, 7, 10])
115 | assert jnp.all(output.array == answer)
116 |
117 |     # pad, but don't include the padding in the average
118 | output = mean_pool(D.resize(3), x, stride=(3,), padding="SAME", count_include_pad=False)
119 | answer = jnp.array([1, 4, 7, 10, 12.5])
120 |
121 | assert jnp.all(output.array == answer)
122 |
123 | output = mean_pool(D.resize(3), x, stride=(3,), padding="SAME", count_include_pad=True)
124 | answer = jnp.array([1, 4, 7, 10, (12 + 13) / 3.0])
125 |
126 | assert jnp.all(output.array == answer)
127 |
128 |
129 | def test_mean_pool_2d():
130 | _x = jnp.arange(36).reshape(6, 6)
131 | x = hax.named(_x, ("H", "W"))
132 |
133 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3)), x, stride=2)
134 | answer = jnp.array([[1, 3], [13, 15], [25, 27]])
135 |
136 | assert jnp.all(output.array == answer)
137 |
138 | # test batch axes
139 | B = hax.Axis("B", 2)
140 | x = haliax.stack(B, [x, x])
141 |
142 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3)), x, stride=2)
143 | answer = jnp.array([[[1, 3], [13, 15], [25, 27]], [[1, 3], [13, 15], [25, 27]]])
144 |
145 | assert jnp.all(output.array == answer)
146 |
147 |
148 | def test_mean_pool3d():
149 | _x = jnp.arange(64).reshape(4, 4, 4)
150 | x = hax.named(_x, ("H", "W", "D"))
151 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3), hax.Axis("D", 1)), x, stride=2)
152 |
153 | answer = jnp.array([[[4, 6]], [[36, 38]]])
154 |
155 | assert jnp.all(output.array == answer)
156 |
157 |
158 | def test_pool_backprop():
159 | def max_pool_mean(x):
160 | pooled = max_pool(
161 | (hax.Axis("H", 2), hax.Axis("W", 2), hax.Axis("D", 2)), x, stride=1, padding=((0, 1), (0, 1), (0, 1))
162 | )
163 | return hax.mean(pooled).scalar()
164 |
165 | _x = jnp.arange(64, dtype=jnp.float32).reshape(1, 4, 4, 4)
166 | x = hax.named(_x, ("B", "H", "W", "D"))
167 | grad_fn = jax.value_and_grad(max_pool_mean)
168 |
169 | hax_loss, hax_grad = grad_fn(x)
170 |
171 | # compare it to eqx
172 |
173 | eqx_max_pool = eqx.nn.MaxPool3d(2, (1, 1, 1), padding=((0, 1), (0, 1), (0, 1)))
174 |
175 | def eqx_max_pool_mean(x):
176 | pooled = eqx_max_pool(x)
177 | return pooled.mean()
178 |
179 | eqx_grad_fn = jax.value_and_grad(eqx_max_pool_mean)
180 | eqx_loss, eqx_grad = eqx_grad_fn(_x)
181 |
182 | assert jnp.allclose(hax_loss, eqx_loss)
183 | assert jnp.allclose(hax_grad.array, eqx_grad)
184 |
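185 | # The pooling API mirrored here: the first argument gives the pooling window as axes
186 | # resized to the kernel size (e.g. (Axis("H", 2), Axis("W", 2))), with stride, padding
187 | # (an int, "SAME"/"VALID", or per-axis (low, high) pairs), and use_ceil behaving like the
188 | # commented-out eqx.nn.MaxPoolNd equivalents; mean_pool additionally takes
189 | # count_include_pad to control whether padded positions count toward the divisor, and
190 | # axes not named in the window (e.g. "B") pass through as batch axes.
191 | 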
--------------------------------------------------------------------------------
/tests/test_specialized_fns.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | import haliax
5 | import haliax.specialized_fns as hfns
6 | from haliax import NamedArray
7 |
8 |
9 | def test_top_k():
10 | H, W, D = haliax.make_axes(H=3, W=4, D=5)
11 |
12 | rand = jax.random.uniform(jax.random.PRNGKey(0), (H.size, W.size, D.size))
13 | n_rand = NamedArray(rand, (H, W, D))
14 |
15 | values, indices = hfns.top_k(n_rand, D, 2)
16 |
17 | assert values.array.shape == (H.size, W.size, 2)
18 | assert indices.array.shape == (H.size, W.size, 2)
19 | assert jnp.all(
20 | jnp.equal(jax.lax.top_k(rand, 2)[0], values.array)
21 | ) # test that selecting last axis is same as default
22 | assert jnp.all(
23 | jnp.equal(jnp.moveaxis(n_rand.take(D, indices).array, 0, -1), values.array)
24 | ) # test that indexing using indices is same as selected values
25 |
26 | for idx, i in enumerate(n_rand.axes): # then test selecting all axes
27 | t = jnp.transpose(rand, (*range(idx), *range(idx + 1, len(n_rand.axes)), idx))
28 | t = jax.lax.top_k(t, 2)[0]
29 | t = jnp.moveaxis(t, -1, idx)
30 | values, indices = hfns.top_k(n_rand, i, 2)
31 | assert jnp.all(jnp.equal(t, values.array))
32 | assert jnp.all(jnp.equal(jnp.moveaxis(n_rand.take(i, indices).array, 0, idx), values.array))
33 |
--------------------------------------------------------------------------------
/tests/test_state_dict.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Any
3 |
4 | import equinox as eqx
5 | import jax
6 | import jax.numpy as jnp
7 | import pytest
8 |
9 | import haliax as hax
10 | from haliax._src.state_dict import flatten_modules_for_export, unflatten_modules_from_export
11 | from haliax.nn import Linear
12 | from haliax.nn.scan import Stacked, _stack_state_dict, _unstack_state_dict
13 | from haliax.state_dict import from_state_dict, to_state_dict
14 |
15 |
16 | @pytest.mark.parametrize("out_dims_first", [True, False])
17 | def test_flatten_linear_layers(out_dims_first: bool):
18 | H = hax.Axis("H", 10)
19 | W = hax.Axis("W", 20)
20 | D = hax.Axis("D", 30)
21 | B = hax.Axis("B", 40)
22 | linear = hax.nn.Linear.init((H, W), (D, B), key=jax.random.PRNGKey(0), use_bias=True, out_first=out_dims_first)
23 |
24 | if out_dims_first:
25 | assert linear.weight.axes == (D, B, H, W)
26 | else:
27 | assert linear.weight.axes == (H, W, D, B)
28 |
29 | flat_linear = linear.flatten_for_export()
30 |
31 | flat_state_dict = to_state_dict(flat_linear)
32 | if out_dims_first:
33 | assert flat_state_dict["weight"].shape == (D.size * B.size, H.size * W.size)
34 | else:
35 | assert flat_state_dict["weight"].shape == (H.size * W.size, D.size * B.size)
36 | assert flat_state_dict["bias"].shape == (D.size * B.size,)
37 | assert flat_state_dict["weight"].dtype == flat_state_dict["bias"].dtype == linear.weight.dtype
38 |
39 | # now unflatten it
40 | linear2 = Linear.init((H, W), (D, B), key=jax.random.PRNGKey(1), use_bias=True, out_first=out_dims_first)
41 | new_linear = flat_linear.unflatten_from_export(linear2)
42 |
43 | if out_dims_first:
44 | assert new_linear.weight.axes == (D, B, H, W)
45 | else:
46 | assert new_linear.weight.axes == (H, W, D, B)
47 | assert new_linear.bias.axes == (D, B) # type: ignore
48 |
49 | assert linear == new_linear
50 |
51 |
52 | # Test cases for stack_state_dict
53 | @pytest.mark.parametrize(
54 | "input_dict, prefix, expected_output",
55 | [
56 | # Single block stacking
57 | (
58 | {
59 | "block.0.weight": jnp.array([1, 2]),
60 | "block.0.bias": jnp.array([3]),
61 | "block.1.weight": jnp.array([4, 5]),
62 | "block.1.bias": jnp.array([6]),
63 | },
64 | "block",
65 | {
66 | "block.weight": jnp.array([[1, 2], [4, 5]]),
67 | "block.bias": jnp.array([[3], [6]]),
68 | },
69 | ),
70 | # Mixed data types and unmatched items remain unchanged
71 | (
72 | {
73 | "block.0.weight": jnp.array([1, 2]),
74 | "block.0.bias": jnp.array([3]),
75 | "block.1.weight": jnp.array([4, 5]),
76 | "block.1.bias": jnp.array([6.0]),
77 | "unrelated.item": jnp.array([7]),
78 | },
79 | "block",
80 | {
81 | "block.weight": jnp.array([[1, 2], [4, 5]]),
82 | "block.bias": jnp.array([[3.0], [6.0]]),
83 | "unrelated.item": jnp.array([7]),
84 | },
85 | ),
86 | # No items match prefix, all items should remain unchanged
87 | (
88 | {
89 | "module.0.param": jnp.array([1]),
90 | "module.1.param": jnp.array([2]),
91 | },
92 | "block",
93 | {
94 | "module.0.param": jnp.array([1]),
95 | "module.1.param": jnp.array([2]),
96 | },
97 | ),
98 | ],
99 | )
100 | def test_stack_state_dict(input_dict, prefix, expected_output):
101 | result = _stack_state_dict(input_dict, prefix)
102 | for key in expected_output:
103 | assert jnp.all(jnp.array_equal(result[key], expected_output[key])), f"Failed on key: {key}"
104 |
105 | # now unstack it
106 | unstacked = _unstack_state_dict(result, prefix)
107 | for key in input_dict:
108 | assert jnp.all(jnp.array_equal(unstacked[key], input_dict[key])), f"Failed on key: {key}"
109 |
110 |
111 | class M(eqx.Module):
112 | a: Any
113 | b: Any
114 |
115 | def __init__(self, a, b):
116 | self.a = a
117 | self.b = b
118 |
119 |
120 | def test_to_from_state_dict():
121 | a = jnp.array([1, 2])
122 | b = jnp.array([3, 4])
123 | m = M(a, b)
124 |
125 | state_dict = to_state_dict(m)
126 | assert state_dict == {"a": a, "b": b}
127 |
128 | m2 = M(jnp.array([0, 0]), jnp.array([0, 0]))
129 | m2 = from_state_dict(m2, state_dict)
130 | assert jnp.all(m2.a == a)
131 | assert jnp.all(m2.b == b)
132 |
133 |
134 | def test_export_layer_norm():
135 | D = hax.Axis("D", 10)
136 | E = hax.Axis("E", 20)
137 | layer_norm = hax.nn.LayerNorm.init((D, E), eps=1e-5, use_weight=True, use_bias=True)
138 |
139 | flat_layer_norm = layer_norm.flatten_for_export()
140 |
141 | flat_state_dict = to_state_dict(flat_layer_norm)
142 |
143 | assert flat_state_dict["weight"].shape == (D.size * E.size,)
144 | assert flat_state_dict["bias"].shape == (D.size * E.size,)
145 | assert flat_state_dict["weight"].dtype == flat_state_dict["bias"].dtype == layer_norm.weight.dtype
146 |
147 | # now unflatten it
148 | layer_norm2 = hax.nn.LayerNorm.init((D, E), eps=1e-5, use_weight=True, use_bias=True)
149 | # ensure we have different weights
150 | layer_norm2 = dataclasses.replace(layer_norm2, weight=layer_norm2.weight + 1, bias=layer_norm2.bias + 1)
151 |
152 | new_layer_norm = flat_layer_norm.unflatten_from_export(layer_norm2)
153 |
154 | assert layer_norm == new_layer_norm
155 |
156 |
157 | def test_stacked_layer_norm():
158 | L = hax.Axis("L", 4)
159 | D = hax.Axis("D", 10)
160 | E = hax.Axis("E", 20)
161 |
162 | norms = Stacked.init(L, hax.nn.LayerNorm)((D, E), eps=1e-5, use_weight=True, use_bias=True)
163 |
164 | norms_flat = flatten_modules_for_export(norms)
165 |
166 | flat_state_dict = to_state_dict(norms_flat)
167 |
168 | assert flat_state_dict["0.weight"].shape == (D.size * E.size,)
169 | assert flat_state_dict["0.bias"].shape == (D.size * E.size,)
170 | assert flat_state_dict["1.weight"].shape == (D.size * E.size,)
171 |
172 | # now unflatten it
173 | norms2 = Stacked.init(L, hax.nn.LayerNorm)((D, E), eps=1e-5, use_weight=True, use_bias=True)
174 |
175 | new_norms = unflatten_modules_from_export(norms_flat, norms2)
176 |
177 | assert norms == new_norms
178 |
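179 | # Pattern exercised above: flatten_for_export() / flatten_modules_for_export() produce a
180 | # module whose to_state_dict() uses flat string keys and fused dimensions, and
181 | # unflatten_from_export(template) / unflatten_modules_from_export() invert it, so the
182 | # round trip reproduces the original module exactly. _stack_state_dict merges
183 | # "block.0.weight", "block.1.weight", ... into a single stacked "block.weight" (and
184 | # _unstack_state_dict undoes it).
185 | 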
--------------------------------------------------------------------------------
/tests/test_tree_util.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import equinox as eqx
4 | import jax
5 | import jax.numpy as jnp
6 | from chex import assert_trees_all_close
7 |
8 | import haliax as hax
9 | import haliax.tree_util as htu
10 | from haliax import Axis
11 |
12 |
13 | def test_resize_axis():
14 | A, B, C = hax.make_axes(A=10, B=20, C=30)
15 |
16 | class Module(eqx.Module):
17 | name1: hax.NamedArray
18 | name2: hax.NamedArray
19 | name3: hax.NamedArray
20 |
21 | module = Module(
22 | name1=hax.random.normal(jax.random.PRNGKey(0), (B, A, C)),
23 | name2=hax.zeros((B, C)),
24 | name3=hax.zeros((Axis("A", 20),)),
25 | )
26 |
27 | NewA = A.resize(15)
28 |
29 | module2 = htu.resize_axis(module, "A", 15, key=jax.random.PRNGKey(1))
30 |
31 | assert module2.name1.axes == (B, NewA, C)
32 | assert module2.name2.axes == (B, C)
33 | assert module2.name3.axes == (NewA,)
34 |
35 | # we don't mess with the mean or std of the original array too much
36 | assert jnp.allclose(module2.name1.mean(), module.name1.mean(), rtol=1e-1, atol=1e-2)
37 |
38 |
39 | def test_scan_aware_tree_map():
40 | Embed = hax.Axis("embed", 10)
41 | Up = hax.Axis("up", 20)
42 | Block = hax.Axis("block", 4)
43 |
44 | class Module(eqx.Module):
45 | up: hax.nn.Linear
46 | down: hax.nn.Linear
47 |
48 | def __call__(self, x, *, key):
49 | return self.down(self.up(x), key=key)
50 |
51 | @staticmethod
52 | def init(layer_idx, *, key):
53 | k1, k2 = jax.random.split(key)
54 | up = hax.nn.Linear.init(Embed, Up, key=k1)
55 | down = hax.nn.Linear.init(Up, Embed, key=k2)
56 |
57 | up = dataclasses.replace(up, weight=up.weight + layer_idx) # type: ignore
58 | down = dataclasses.replace(down, weight=down.weight + layer_idx) # type: ignore
59 |
60 | return Module(up=up, down=down)
61 |
62 | class Model(eqx.Module):
63 | layers: hax.nn.Stacked[eqx.Module]
64 |
65 | def __call__(self, x, *, key):
66 | return self.layers.fold(x, key=jax.random.split(key, self.layers.Block.size))
67 |
68 | @staticmethod
69 | def init(Layers, *, key):
70 | stack = hax.nn.Stacked.init(Layers, Module)(
71 | layer_idx=hax.arange(Layers), key=jax.random.split(key, Layers.size)
72 | )
73 | return Model(layers=stack)
74 |
75 | model = Model.init(Block, key=jax.random.PRNGKey(0))
76 |
77 | def transform_linear(x):
78 | if not isinstance(x, hax.nn.Linear):
79 | return x
80 |
81 |         # do something whose result differs depending on whether the weights are transformed jointly or per layer
82 | new_weight = x.weight - hax.mean(x.weight)
83 | return dataclasses.replace(x, weight=new_weight) # type: ignore
84 |
85 | model2 = htu.scan_aware_tree_map(transform_linear, model, is_leaf=lambda x: isinstance(x, hax.nn.Linear))
86 | model3 = htu.tree_map(transform_linear, model, is_leaf=lambda x: isinstance(x, hax.nn.Linear))
87 |
88 | assert hax.all(model2.layers.stacked.up.weight != model3.layers.stacked.up.weight)
89 |
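90 | # The last assert works because scan_aware_tree_map applies transform_linear to each
91 | # layer of the Stacked module separately (so the mean is per layer), whereas plain
92 | # tree_map sees one Linear whose weight carries the stacked "block" axis (so the mean is
93 | # taken jointly), giving different centered weights.
94 | 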
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | import jax
4 | import pytest
5 |
6 |
7 | T = TypeVar("T")
8 |
9 |
10 | def skip_if_not_enough_devices(count: int):
11 | return pytest.mark.skipif(len(jax.devices()) < count, reason=f"Not enough devices ({len(jax.devices())})")
12 |
13 |
14 | def has_torch():
15 | try:
16 | import torch # noqa F401
17 |
18 | return True
19 | except ImportError:
20 | return False
21 |
22 |
23 | def skip_if_no_torch(f):
24 | return pytest.mark.skipif(not has_torch(), reason="torch not installed")(f)
25 |
--------------------------------------------------------------------------------