├── .coveragerc ├── .flake8 ├── .github └── workflows │ ├── publish_dev.yaml │ ├── run_pre_commit.yaml │ ├── run_quick_levanter_tests.yaml │ └── run_tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── api.md ├── broadcasting.md ├── cheatsheet.md ├── css │ ├── material.css │ └── mkdocstrings.css ├── faq.md ├── figures │ ├── data_parallel_mesh.png │ ├── data_parallel_mesh_replicated.png │ ├── device_mesh_1d.png │ ├── device_mesh_1d_zero.png │ ├── device_mesh_2d.png │ ├── device_mesh_2d_batch_partitioned.png │ ├── device_mesh_2d_data_replicated.png │ ├── device_mesh_2d_data_replicated_mlp_partitioned.png │ ├── device_mesh_2d_intermediate_fully_partitioned.png │ └── device_mesh_2d_zero.png ├── fp8.md ├── index.md ├── indexing.md ├── matmul.md ├── nn.md ├── partitioning.md ├── rearrange.ipynb ├── rearrange.md ├── requirements.txt ├── scan.md ├── state-dict.md ├── tutorial.md └── vmap.md ├── mkdocs.yml ├── pyproject.toml ├── src └── haliax │ ├── __about__.py │ ├── __init__.py │ ├── _src │ ├── __init__.py │ ├── compile_utils.py │ ├── dot.py │ ├── einsum.py │ ├── fp8.py │ ├── parsing.py │ ├── rearrange.py │ ├── scan.py │ ├── state_dict.py │ └── util.py │ ├── axis.py │ ├── core.py │ ├── debug.py │ ├── hof.py │ ├── jax_utils.py │ ├── nn │ ├── __init__.py │ ├── activations.py │ ├── attention.py │ ├── conv.py │ ├── dropout.py │ ├── embedding.py │ ├── linear.py │ ├── loss.py │ ├── mlp.py │ ├── normalization.py │ ├── pool.py │ └── scan.py │ ├── ops.py │ ├── partitioning.py │ ├── quantization.py │ ├── random.py │ ├── specialized_fns.py │ ├── state_dict.py │ ├── tree_util.py │ ├── types.py │ ├── util.py │ └── wrap.py └── tests ├── core_test.py ├── test_attention.py ├── test_axis.py ├── test_conv.py ├── test_debug.py ├── test_dot.py ├── test_einsum.py ├── test_fp8.py ├── test_hof.py ├── test_int8.py ├── test_nn.py ├── test_ops.py ├── test_parsing.py ├── test_partitioning.py ├── test_pool.py ├── test_random.py ├── test_rearrange.py ├── test_scan.py ├── test_specialized_fns.py ├── test_state_dict.py ├── test_tree_util.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: not covered 4 | pragma: no cover 5 | @overload 6 | @typing.overload 7 | @abc.abstractmethod 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git 3 | max-line-length = 120 4 | ignore = E203, E501, W503, W605, F821, E266, E402, E731, F401, F403, F40 5 | per-file-ignores = 6 | */__init__.py: F401 7 | examples/*.py: E402 8 | -------------------------------------------------------------------------------- /.github/workflows/publish_dev.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Dev Build 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Run Tests"] 6 | types: 7 | - completed 8 | branches: [main] 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build-package: 13 | runs-on: ubuntu-latest 14 | if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success'}} 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | - name: Set up Python 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: '3.x' 24 | 25 | - name: Calculate Version 
and Build Number 26 | run: | 27 | PROJECT_VERSION=$(sed -n 's/__version__ = "\(.*\)"/\1/p' src/haliax/__about__.py) 28 | BUILD_NUMBER=$(git rev-list --count HEAD) 29 | FULL_VERSION="${PROJECT_VERSION}.dev${BUILD_NUMBER}" 30 | echo "FULL_VERSION=${FULL_VERSION}" >> $GITHUB_ENV 31 | echo "Calculated version with build number: $FULL_VERSION" 32 | - name: Update pyproject.toml version 33 | run: | 34 | # replace the version in __about__.py 35 | echo "Updating version in __about__.py to $FULL_VERSION" 36 | sed -i "s/__version__ = \".*\"/__version__ = \"$FULL_VERSION\"/g" src/haliax/__about__.py 37 | - name: Build package 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install build 41 | python -m build 42 | 43 | - name: Upload package 44 | uses: actions/upload-artifact@v4 45 | with: 46 | name: package 47 | path: dist/ 48 | 49 | 50 | # cf https://test.pypi.org/manage/project/haliax/settings/publishing/ 51 | publish-dev: 52 | runs-on: ubuntu-latest 53 | needs: 54 | - build-package 55 | permissions: 56 | id-token: write 57 | steps: 58 | - name: Retrieve release distributions 59 | uses: actions/download-artifact@v4 60 | with: 61 | name: package 62 | path: dist/ 63 | 64 | - name: Publish release distributions to PyPI Test 65 | uses: pypa/gh-action-pypi-publish@release/v1 66 | 67 | 68 | -------------------------------------------------------------------------------- /.github/workflows/run_pre_commit.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.10.11"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install flake8 pytest pre-commit 23 | pip install . 24 | # - name: Lint with flake8 25 | # run: | 26 | # # stop the build if there are Python syntax errors or undefined names 27 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 28 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 29 | # flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics 30 | - name: "Run Pre-commit" 31 | run: | 32 | pre-commit run --all-files --show-diff-on-failure 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/run_quick_levanter_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Levanter Tests 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 3.10.11 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.10.11 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install flake8 pytest 20 | pip install "jax[cpu]==0.5.3" "jaxlib[cpu]==0.5.3" .[dev] 21 | 22 | - name: Install Levanter from source 23 | run: | 24 | cd .. 
25 | git clone https://github.com/stanford-crfm/levanter.git 26 | cd levanter 27 | pip install -e .[tests] 28 | pip install -r tests/requirements.txt 29 | # i don't know why this is necessary 30 | pip install tensorboardX 31 | - name: Install Haliax on top 32 | run: | 33 | # install second since levanter will install a built version of haliax 34 | cd ../haliax 35 | - name: Test levanter with pytest 36 | run: | 37 | cd ../levanter 38 | XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:../src pytest tests -m "not entry and not slow" 39 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 3.10.11 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.10.11 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install flake8 pytest 20 | pip install jax==0.4.35 jaxlib==0.4.35 .[dev] 21 | - name: Test with pytest 22 | run: | 23 | XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /scratch 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # JetBrains 134 | .idea/ 135 | 136 | 137 | # Wandb stuff 138 | /wandb 139 | 140 | # dataset cache files 141 | *.parquet 142 | ledger.json 143 | 144 | /checkpoints 145 | *.jaxpr 146 | 147 | # local execution commands 148 | local_*.sh 149 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | default_stages: 5 | - commit 6 | fail_fast: true 7 | 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.0.1 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | - id: check-yaml 15 | - id: check-toml 16 | - id: check-merge-conflict 17 | - id: check-added-large-files 18 | 19 | - repo: https://github.com/psf/black 20 | rev: 22.3.0 21 | hooks: 22 | - id: black 23 | 24 | - repo: https://github.com/timothycrosley/isort 25 | rev: 5.11.5 26 | hooks: 27 | - id: isort 28 | 29 | - repo: https://github.com/PyCQA/flake8 30 | rev: 3.9.2 31 | hooks: 32 | - id: flake8 33 | additional_dependencies: [flake8-isort] 34 | 35 | - repo: https://github.com/pre-commit/mirrors-mypy 36 | rev: 'v1.5.1' 37 | hooks: 38 | - id: mypy 39 | args: [--ignore-missing-imports, --check-untyped-defs] 40 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for MkDocs projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | # Required 4 | version: 2 5 | # Set the version of Python and other tools you might need 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | 11 | mkdocs: 12 | configuration: mkdocs.yml 13 | 14 | # Optionally declare the Python requirements required to build your docs 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | Levanter is a growing code base, and we are excited for other folks to get involved. The instructions below walk you through our dev setup and how to submit a PR. 5 | 6 | Dev Installation 7 | ---------------- 8 | 9 | First follow the same instructions as provided for the [Levanter README](README.md) to install Levanter. 10 | 11 | The main addition for a dev environment is to install [`pre-commit`](https://pre-commit.com/): 12 | 13 | pre-commit install 14 | 15 | This will set up git hook scripts that ensure your code is formatted in a manner consistent with 16 | the repo. 
If any problems are found, the appropriate files will be updated with fixes. You will 17 | need to review and commit the fixed files. 18 | 19 | Forking The Repo 20 | ---------------- 21 | 22 | To submit changes, you will need to work off of a fork of the repo and issue a pull request. 23 | 24 | There are two easy ways to fork the repo. 25 | 26 | If you have installed the [GitHub CLI](https://cli.github.com/) you can issue this command: 27 | 28 | gh repo fork stanford-crfm/levanter --clone=true 29 | 30 | This will create the fork and clone the repo into your current directory. 31 | 32 | Alternatively you can fork the repo in your browser. While logged in to your GitHub account, 33 | go to the [Levanter repo](https://github.com/stanford-crfm/levanter) and click on the Fork 34 | button in the upper left hand corner. 35 | 36 | You can then clone your forked version of the Levanter repo like any other GitHub repo. 37 | 38 | Create A Branch For Your Submission 39 | ----------------------------------- 40 | 41 | You will generally need to create a branch of `main` for your code changes. In general every submission 42 | should be focused on a specific set of bug fixes or new features that are coherently 43 | related. Changes that are not related belong in different submissions. So you should 44 | be able to give your branch an informative name such as `checkpointer-time-bugfix` . 45 | 46 | You can create a branch off of `main` with this command: 47 | 48 | git checkout -b checkpointer-time-bugfix main 49 | 50 | Implement Your Changes 51 | ---------------------- 52 | 53 | As you implement your changes in your feature branch, the git hook scripts will check your 54 | code for proper formatting as you make commits. Make sure you have run `pre-commit install` 55 | before you start making commits. 56 | 57 | You can also check all files in the current branch with this command: 58 | 59 | pre-commit run --all-files 60 | 61 | When your changes are operational you should verify that the current tests are passing. 62 | 63 | Set up your environment for running the tests: 64 | 65 | export PYTHONPATH=/path/to/levanter/src:path/to/levanter/tests:$PYTHONPATH 66 | wandb offline 67 | 68 | You can run the tests with this command: 69 | 70 | pytest tests 71 | 72 | You should add tests for any functionality you have added consistent with the [pytest](https://docs.pytest.org/en/6.2.x/) format 73 | of the existing tests. 74 | 75 | Submit Pull Request 76 | ------------------- 77 | 78 | When your feature branch is ready you should submit a pull request. 79 | 80 | Detailed instructions for submtting a pull request from a fork can be found on [Github Docs](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). 81 | 82 | The steps basically are: 83 | 84 | 1. While logged in to your GitHub, go to the original [Levanter repo pull request page](https://github.com/stanford-crfm/levanter/pulls) 85 | 2. Click on the highlighted text stating "compare across forks". 86 | 3. Set the base repository to `stanford-crfm/levanter` and the base branch to `main`. 87 | 4. Set the head repository to `your-org/levanter` and the compare branch to `your-feature-branch`. 88 | 5. Click on the "Create pull request" button and complete the pull request form. 89 | 90 | When submitting your pull request, you should provide a detailed description of what you've done. 
91 | 92 | The following is a useful template: 93 | 94 | ## Description 95 | A brief and concise description of what your pull request is trying to accomplish. 96 | 97 | ## Fixes Issues 98 | A list of issues/bugs with # references. (e.g., #123) 99 | 100 | ## Unit test coverage 101 | Are there unit tests in place to make sure your code is functioning correctly? 102 | 103 | ## Known breaking changes/behaviors 104 | Does this break anything in Levanter's existing user interface? If so, what is it and how is it addressed? 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Haliax 3 | 4 | 5 | Build Status 6 | 7 | 8 | Documentation Status 9 | 10 | 11 | License 12 | 13 | 14 | PyPI 15 | 16 | 17 | > *Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.*
18 | > — Patrick Rothfuss, *The Name of the Wind* 19 | 20 | Haliax is a [JAX](https:://github.com/google/jax) library for building neural networks with named tensors, in the tradition of Alexander Rush's [Tensor Considered Harmful](https://nlp.seas.harvard.edu/NamedTensor). 21 | Named tensors improve the **legibility** and **compositionality** of tensor programs by using named axes instead of positional indices 22 | as typically used in NumPy, PyTorch, etc. 23 | 24 | Despite the focus on legibility, Haliax 25 | is also **fast**, typically about as fast as "pure" JAX code. 26 | Haliax is also built to be **scalable**: it 27 | can support [Fully-Sharded Data Parallelism (FSDP)](https://engineering.fb.com/2021/07/15/open-source/fsdp/) and Tensor Parallelism with [just a few lines of code](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). Haliax powers [Levanter](https://github.com/stanford-crfm/levanter), 28 | our companion library for training large language models and other foundation models, with scale proven up to 70B parameters 29 | and up to TPU v4-2048. 30 | 31 | ## Example: Attention 32 | 33 | Here's a minimal attention module implementation in Haliax. For a more detailed introduction, 34 | please see the [Haliax tutorial](https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC). 35 | (We use the excellent [Equinox](https://github.com/patrick-kidger/equinox) library for its module system and tree transformations.) 36 | 37 | ```python 38 | import equinox as eqx 39 | import jax 40 | import jax.numpy as jnp 41 | import haliax as hax 42 | import haliax.nn as hnn 43 | 44 | Pos = hax.Axis("position", 1024) # sequence length 45 | KPos = Pos.alias("key_position") 46 | Head = hax.Axis("head", 8) # number of attention heads 47 | Key = hax.Axis("key", 64) # key size 48 | Embed = hax.Axis("embed", 512) # embedding size 49 | 50 | # alternatively: 51 | #Pos, KPos, Head, Key, Embed = hax.make_axes(pos=1024, key_pos=1024, head=8, key=64, embed=512) 52 | 53 | 54 | def attention_scores(Key, KPos, query, key, mask): 55 | # how similar is each query to each key 56 | scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size) 57 | 58 | if mask is not None: 59 | scores -= 1E9 * (1.0 - mask) 60 | 61 | # convert to probabilities 62 | scores = haliax.nn.softmax(scores, KPos) 63 | return scores 64 | 65 | 66 | def attention(Key, KPos, query, key, value, mask): 67 | scores = attention_scores(Key, KPos, query, key, mask) 68 | answers = hax.dot(scores, value, axis=KPos) 69 | 70 | return answers 71 | 72 | 73 | # Causal Mask means that if pos >= key_pos, then pos can attend to key_pos 74 | causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos) 75 | 76 | 77 | class Attention(eqx.Module): 78 | proj_q: hnn.Linear # [Embed] -> [Head, Key] 79 | proj_k: hnn.Linear # [Embed] -> [Head, Key] 80 | proj_v: hnn.Linear # [Embed] -> [Head, Key] 81 | proj_answer: hnn.Linear # output projection from [Head, Key] -> [Embed] 82 | 83 | @staticmethod 84 | def init(Embed, Head, Key, *, key): 85 | k_q, k_k, k_v, k_ans = jax.random.split(key, 4) 86 | proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q) 87 | proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k) 88 | proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v) 89 | proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans) 90 | return Attention(proj_q, proj_k, proj_v, proj_answer) 91 | 92 | def __call__(self, x, mask=None): 93 | q = self.proj_q(x) 94 | # Rename "position" to "key_position" 
for self attention 95 | k = self.proj_k(x).rename({"position": "key_position"}) 96 | v = self.proj_v(x).rename({"position": "key_position"}) 97 | 98 | answers = attention(Key, KPos, q, k, v, causal_mask) 99 | 100 | x = self.proj_answer(answers) 101 | return x 102 | ``` 103 | 104 | Haliax was created by [Stanford's Center for Research on Foundation Models (CRFM)](https://crfm.stanford.edu/)'s research engineering team. 105 | You can find us in the #levanter channel on the unofficial [Jax LLM Discord](https://discord.gg/FkRGNX3ND). 106 | 107 | 108 | 109 | ## Documentation 110 | 111 | ### Tutorials 112 | 113 | These are some tutorials to get you started with Haliax. They are available as Colab notebooks: 114 | 115 | 116 | 117 | * [Introduction to Haliax with Transformers](https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC) 118 | * [Distributed Training in Haliax](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz) (including FSDP) 119 | * [Tensor Parallelism in Haliax](https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi) 120 | * [Mixed Precision with `jmp`](https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing) (This one is really a tutorial for [jmp](https://github.com/deepmind/jmp) but it's how to use it with Haliax...) 121 | 122 | 123 | ### API Reference 124 | 125 | Haliax's API documentation is available at [haliax.readthedocs.io](https://haliax.readthedocs.io/en/latest/). 126 | 127 | ## Contributing 128 | 129 | We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information. 130 | We also have a list of [good first issues](https://github.com/stanford-crfm/haliax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) 131 | to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!) 132 | 133 | ## License 134 | 135 | Haliax is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text. 136 | -------------------------------------------------------------------------------- /docs/broadcasting.md: -------------------------------------------------------------------------------- 1 | # Broadcasting 2 | 3 | One area where Haliax's treatment of named axes differs substantially from Numpy-esque positional code is in broadcasting. In traditional positional code, [broadcasting works like this](https://numpy.org/doc/stable/user/basics.broadcasting.html): 4 | 5 | ```python 6 | import numpy as np 7 | 8 | # compute the outer product of two arrays 9 | a = np.arange(5) 10 | b = np.arange(4) 11 | 12 | c = a.reshape((-1, 1)) * b.reshape((1, -1)) 13 | print(c.shape) 14 | print(c) 15 | 16 | # alternatively 17 | c2 = a[:, np.newaxis] * b 18 | ``` 19 | 20 | This prints: 21 | ``` 22 | (5, 4) 23 | [[ 0 0 0 0] 24 | [ 0 1 2 3] 25 | [ 0 2 4 6] 26 | [ 0 3 6 9] 27 | [ 0 4 8 12]] 28 | ``` 29 | 30 | To quote the NumPy documentation, for positional arrays, "in order to broadcast, the size of the trailing axes for both arrays in an operation must either be the same size or one of them must be one." 31 | 32 | I have found this to be a source of bugs: it is easy to accidentally have an array of size [batch_size, 1] and combine it with an array of size [batch_size], yielding an array of [batch_size, batch_size]. 33 | 34 | In Haliax, broadcasting is done by matching names. 
The same operation in Haliax might look like this: 35 | 36 | ```python 37 | M = hax.Axis("M", 5) 38 | N = hax.Axis("N", 4) 39 | 40 | a = hax.arange(M) 41 | b = hax.arange(N) 42 | 43 | c = a.broadcast_axis(N) * b 44 | print(c.axes) 45 | print(c.array) 46 | ``` 47 | 48 | ``` 49 | (Axis(name='N', size=4), Axis(name='M', size=5)) 50 | [[ 0 0 0 0 0] 51 | [ 0 1 2 3 4] 52 | [ 0 2 4 6 8] 53 | [ 0 3 6 9 12]] 54 | ``` 55 | 56 | Haliax aims to be "order-independent" as much as possible (while still letting you choose the order for performance or compatibility with positional code). 57 | Its semantics are: "in order to broadcast, identically named Axes of the arrays must have the same size. 58 | In addition, they must share at least one named axis in common, unless one is a scalar." 59 | The second condition is there to avoid bugs: we want to be sure that the arrays have something in common. 60 | To satisfy the second condition, it is not uncommon to use [haliax.broadcast_axis][], like we did above. 61 | This method takes one or more axes and adds them to the array. 62 | 63 | ## Explicit Broadcasting Functions 64 | 65 | * [haliax.broadcast_axis][] 66 | * [haliax.broadcast_to][] 67 | 68 | 69 | 78 | 79 | ## A note on performance 80 | 81 | Under the hood, Haliax will automatically broadcast and permute axes so that the underlying positional code produces the correct result. 82 | This is usually not a substantial performance hit because XLA is pretty good about picking optimal shapes, 83 | but if you're doing repeated operations you may want to use [haliax.rearrange][] to change the order of axes. 84 | As an example, in Levanter's GPT-2 implementation, we found using `rearrange` led to a 5% speedup for small models. This 85 | isn't huge, but it's not nothing either. 86 | -------------------------------------------------------------------------------- /docs/css/material.css: -------------------------------------------------------------------------------- 1 | .md-main__inner { 2 | margin-bottom: 1.5rem; 3 | } 4 | 5 | /* Custom admonition: preview */ 6 | :root { 7 | --md-admonition-icon--preview: url('data:image/svg+xml;charset=utf-8,'); 8 | } 9 | 10 | .md-typeset .admonition.preview, 11 | .md-typeset details.preview { 12 | border-color: rgb(220, 139, 240); 13 | } 14 | 15 | .md-typeset .preview>.admonition-title, 16 | .md-typeset .preview>summary { 17 | background-color: rgba(142, 43, 155, 0.1); 18 | } 19 | 20 | .md-typeset .preview>.admonition-title::before, 21 | .md-typeset .preview>summary::before { 22 | background-color: rgb(220, 139, 240); 23 | -webkit-mask-image: var(--md-admonition-icon--preview); 24 | mask-image: var(--md-admonition-icon--preview); 25 | } 26 | -------------------------------------------------------------------------------- /docs/css/mkdocstrings.css: -------------------------------------------------------------------------------- 1 | /* Indentation. */ 2 | div.doc-contents { 3 | padding-left: 25px; 4 | border-left: .05rem solid var(--md-typeset-table-color); 5 | } 6 | 7 | 8 | div.doc-class:not(.doc-contents .doc-contents)::after { 9 | content: ""; 10 | display: block; 11 | width: 100%; 12 | height: 1px; /* Adjust thickness */ 13 | background-color: black; /* Adjust color */ 14 | margin: 10px 0; /* Adjust spacing */ 15 | } 16 | 17 | /* Mark external links as such. 
*/ 18 | a.external::after, 19 | a.autorefs-external::after { 20 | /* https://primer.style/octicons/arrow-up-right-24 */ 21 | mask-image: url('data:image/svg+xml,'); 22 | content: ' '; 23 | 24 | display: inline-block; 25 | vertical-align: middle; 26 | position: relative; 27 | 28 | height: 1em; 29 | width: 1em; 30 | background-color: var(--md-typeset-a-color); 31 | } 32 | 33 | a.external:hover::after, 34 | a.autorefs-external:hover::after { 35 | background-color: var(--md-accent-fg-color); 36 | } 37 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Tips and FAQ 2 | 3 | See also the [Equinox FAQ](https://docs.kidger.site/equinox/faq/) 4 | 5 | ## Tip 1: `hax.debug.diagnose_common_issues` 6 | 7 | `hax.debug.diagnose_common_issues` is a function that will raise an exception if it detects problems with your module. 8 | Currently, we diagnose: 9 | 10 | * Reuse of arrays or NamedArrays in a field. [Equinox modules must be trees.](https://docs.kidger.site/equinox/faq/#a-module-saved-in-two-places-has-become-two-independent-copies) 11 | * Use of arrays or NamedArrays in a static field. Static data in JAX/Equinox must be hashable, and arrays are not hashable. 12 | -------------------------------------------------------------------------------- /docs/figures/data_parallel_mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/data_parallel_mesh.png -------------------------------------------------------------------------------- /docs/figures/data_parallel_mesh_replicated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/data_parallel_mesh_replicated.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_1d.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_1d_zero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_1d_zero.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_batch_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_batch_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_data_replicated.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_data_replicated.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_intermediate_fully_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_intermediate_fully_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_zero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/docs/figures/device_mesh_2d_zero.png -------------------------------------------------------------------------------- /docs/fp8.md: -------------------------------------------------------------------------------- 1 | # Quantized Training 2 | 3 | !!! warning 4 | 5 | FP8 and Int8 training in Haliax is currently experimental and may change in the future. 6 | 7 | Haliax supports training with FP8 and int8. This is useful for training on hardware that is optimized for FP8 or Int8, 8 | such as the H100 (fp8) or A100s (int8) and TPU v5 and newer (int8). 9 | 10 | ## TL;DR 11 | 12 | Using FP8 with Haliax is actually pretty straightforward. To enable FP8, do this: 13 | 14 | ```python 15 | import haliax.quantization as haxq 16 | # setup 17 | module = haxq.quantize_linear_layers(module, haxq.QuantizationConfig(fp8=True)) 18 | 19 | # if using optax. This saves a tiny amount of memory so you can skip it if you want 20 | _, trainable_module = haxq.partition_for_grad_overwrite(module) 21 | opt_state = opt.initial_state(trainable_module) 22 | 23 | # train step 24 | grads = eqx.filter_grad(loss_fn)(module, data) 25 | overwrite, grads = haxq.partition_for_grad_overwrite(grads) 26 | updates, opt_state = opt.update(grads, opt_state, params=module) # or however you update your optimizer 27 | module = haxq.apply_updates(module, updates, overwrite) 28 | ``` 29 | 30 | And train your model like normal. 31 | 32 | Similarly, you can use `Int8` by setting `Int8=True` in the `QuantizationConfig` object. 33 | 34 | 35 | 36 | ## What is FP8? 37 | 38 | FP8 refers to 8-bit floating point numbers. FP8 is a massively reduced precision compared to the 32-bit floating point numbers 39 | or 16-bit floating point numbers that are typically used in deep learning: there are only 256 possible values in FP8, compared to 40 | the (almost) 2^32 in 32-bit and 2^16 in 16-bit. However, FP8 is still useful for training deep learning models, especially on 41 | hardware that is optimized for FP8. In particular, it can massively accelerate training on hardware that is optimized for FP8: 42 | H100 has 2x FP8 FLOPS compared to FP16 FLOPS and almost 60x(!) compared to F32 FLOPS. 43 | 44 | The FP8 in Haliax is currently designed to optimize throughput on FP8-enabled devices (currently H100) rather 45 | than to save memory. 
In particular, Haliax's FP8 support is not designed to quantize a model to FP8 for deployment, 46 | though this shouldn't be that hard to add for models that were trained using this functionality. 47 | We would be happy to accept contributions to add this functionality, 48 | and are happy to work with you to do so. In particular, adding this for models trained using Haliax's FP8 should be easy. 49 | 50 | See this [FP8 Primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for more information on FP8. 51 | 52 | ## What is Int8? 53 | 54 | Int8 refers to 8-bit integers. Int8 has the same number of bits as FP8, but the interpretation is different: instead of 55 | exponentially spaced numbers, Int8 has linearly spaced numbers. 56 | 57 | In Haliax, we support Int8 training through Google's [AQT](https://github.com/google/aqt) library. AQT (for 58 | "Accurate Quantization Training") is a library that allows you to train models with quantization-aware training (QAT). 59 | 60 | ## How to use FP8 or Int8 in Haliax 61 | 62 | To use quantized training in Haliax, you need to do three things: 63 | 64 | * Enable FP8 (or int8) for the layers you want 65 | * Modify your training step to be compatible 66 | 67 | Each of these is just a couple of lines of code. 68 | 69 | ```python 70 | import haliax as hax 71 | import equinox as eqx 72 | import jax 73 | 74 | In = hax.Axis("In", 32) 75 | Mid = hax.Axis("Mid", 128) 76 | Out = hax.Axis("Out", 16) 77 | Hidden = hax.Axis("Hidden", 64) 78 | 79 | 80 | class MyModule(eqx.Module): 81 | up_proj: hax.nn.Linear 82 | down_proj: hax.nn.Linear 83 | 84 | @staticmethod 85 | def init(*, key): 86 | super().__init__() 87 | k_up, k_down = jax.random.split(key) 88 | return MyModule( 89 | up_proj=hax.nn.Linear.init(In, Mid, key=k_up), 90 | down_proj=hax.nn.Linear.init(Mid, Out, key=k_down), 91 | ) 92 | 93 | def __call__(self, x): 94 | x = self.up_proj(x) 95 | x = hax.nn.relu(x) 96 | x = self.down_proj(x) 97 | return x 98 | 99 | module = MyModule.init(key=jax.random.PRNGKey(0)) 100 | 101 | # Enable FP8 102 | module = hax.quantization.quantize_linear_layers(module, QuantizationConfig(fp8=True)) 103 | 104 | # Enable FP8 for a specific layer 105 | from haliax.quantization import QuantizationConfig 106 | 107 | config = QuantizationConfig(targets=["up_proj"], fp8=True) 108 | module = hax.quantization.quantize_linear_layers(module, config) 109 | 110 | # Train step 111 | grads = eqx.filter_grad(loss_fn)(module, data) 112 | overwrite, grads = haxq.partition_for_grad_overwrite(grads) 113 | updates, opt_state = opt.update(grads, opt_state, params=module) # or however you update your optimizer 114 | module = hax.quantization.apply_updates(module, updates, grads) 115 | ``` 116 | 117 | That's it! Just a few lines of code to enable FP8. The `quantize_linear_layers` function will transform your module to use 118 | quantization-aware training for linear layers (or a subset if you want), and the combo of `partition_for_grad_overwrite` and `apply_updates` function will apply the updates to the module 119 | in a way that is compatible with FP8. 120 | 121 | ## How FP8 works 122 | 123 | For an overview of the FP8, see the [FP8 Primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html). 124 | You don't need to understand it though. Haliax's FP8 integration is more or less plug and play, as shown above. 
125 | The implementation of FP8 in Haliax is more or less a straightforward port (including some copy and paste) of the 126 | [FP8 implementation in Flax](https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py). 127 | 128 | FP8 in JAX (as well as INT8) is typically implemented using "`dot_general` injection", where you pass 129 | a custom implementation of `dot_general` to functions and modules like [haliax.dot][] and [haliax.nn.Linear][]. 130 | The `dot_general` for FP8 is implemented by scaling 131 | the inputs, projecting the inputs to FP8, performing the computation in FP8, and then 132 | scaling the result back to the original precision. 133 | The subtle part of FP8 is that the scaling is a parameter that is trained based on a history of the inputs to the layer 134 | (as well as gradients coming in from backward). This means that the FP8 `dot_general` needs to maintain state. 135 | In Equinox, this means that the `dot_general` is actually a `Module` that packages together the state and the 136 | computation. (Unlike [equinox.nn.StatefulLayer][] which returns a state object you pass back into the module, the FP8 `dot_general` 137 | module hijacks the gradient computation to update its state. This is necessary because the FP8 scaling factors 138 | depend on the gradients.) 139 | 140 | The way this happens is by "hijacking" the gradient computation. When you call `eqx.filter_grad(loss_fn)(module, data)`, 141 | you will get the gradient computation as normal, but you'll also get the updated state of the FP8 `dot_general` module. 142 | This updated state needs to directly replace the state in the module (rather than be used for a gradient step), which is 143 | why you need to use the `partition_for_grad_overwrite` 144 | 145 | The FP8 `dot_general` module is implemented in [haliax.quantization.Fp8DotGeneralOp][]. It's actually not that complicated: 146 | 147 | 1) It holds a scaling factor and history of maximum values for each of (lhs, rhs, output) and updates them based on the 148 | gradients. 149 | 2) When invoked, it scales the inputs, projects them to FP8, performs the computation, and scales the result back to the 150 | original precision. It remembers the maximum absolute value for each of the inputs. 151 | 3) For the gradients, it scales the gradients, projects them to FP8, does the backward computation, 152 | and scales the gradients back to the original precision. It remembers the maximum absolute value for the incoming 153 | gradient and stores it in the gradient. 154 | 155 | ## How Int8 works 156 | 157 | Int8 is in principle the same, though the details differ. AQT is a much more flexible library than the FP8 implementation, 158 | because it can be a bit more finicky. We use AQT directly, and we recommend you look at the 159 | [AQT documentation](https://github.com/google/aqt?tab=readme-ov-file#how-aqt-works-internally) for more 160 | information on how it works. 
161 | 162 | # API Reference 163 | 164 | ## Functions 165 | 166 | ::: haliax.quantization.quantize_linear_layers 167 | ::: haliax.quantization.partition_for_grad_overwrite 168 | ::: haliax.quantization.apply_updates 169 | 170 | 171 | ## Interfaces 172 | ::: haliax.quantization.DotGeneralOp 173 | ::: haliax.quantization.OverwriteWithGradient 174 | 175 | ## Modules 176 | 177 | 178 | ::: haliax.quantization.DefaultDotGeneralOp 179 | ::: haliax.quantization.Fp8DotGeneralOp 180 | ::: haliax.quantization.Int8DotGeneralOp 181 | 182 | ## Configuration 183 | 184 | ::: haliax.quantization.QuantizationConfig 185 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | {% 2 | include-markdown "../README.md" 3 | start="" 4 | end="" 5 | %} 6 | 7 | 8 | The code is released on GitHub: [Haliax Repository](https://github.com/stanford-crfm/haliax/). 9 | 10 | 11 | To contribute, please refer to the [Contributing Guide](https://github.com/stanford-crfm/haliax/blob/main/CONTRIBUTING.md). 12 | -------------------------------------------------------------------------------- /docs/matmul.md: -------------------------------------------------------------------------------- 1 | ## Matrix Multiplication 2 | 3 | Haliax has two ways to do matrix multiplication (and tensor contractions more generally): 4 | [haliax.dot][] and [haliax.einsum][]. [haliax.dot][] and [haliax.einsum][] 5 | can both express any tensor contraction, though in different situations one or the other may be 6 | more suitable for expressing a particular contraction In general: 7 | 8 | - Use [haliax.dot][] when you want to express a simple matrix multiplication over one or a few axes. 9 | - Use [haliax.einsum][] when you want to express a more complex tensor contraction. 10 | 11 | See also the API reference for [haliax.dot][] and [haliax.einsum][] and the 12 | [cheat sheet section](cheatsheet.md#matrix-multiplication). 13 | 14 | ### `haliax.dot` 15 | 16 | With [haliax.dot][], you specify the axes to contract over, without needing to write out the 17 | axes you want to keep (though you can if you want): 18 | 19 | ```python 20 | import haliax as hax 21 | 22 | H = hax.Axis("H", 3) 23 | W = hax.Axis("W", 4) 24 | D = hax.Axis("D", 5) 25 | 26 | x = hax.ones((H, W, D)) 27 | w = hax.ones((D,)) 28 | y = hax.dot(x, w, axis=D) # shape is (H, W), equivalent to np.einsum("hwd,d->hw", x, w) 29 | ``` 30 | 31 | [haliax.dot][] is at its best when you want to express a simple matrix multiplication over one or a few axes. 32 | Syntactically, [haliax.dot][] is similar to reduction operations like [haliax.sum][] and [haliax.mean][]. 33 | 34 | The [cheat sheet](cheatsheet.md) has a section on [matrix multiplication](cheatsheet.md#matrix-multiplication) 35 | that gives a few examples. 
Here are several more: 36 | 37 | ```python 38 | import haliax as hax 39 | 40 | H = hax.Axis("H", 3) 41 | W = hax.Axis("W", 4) 42 | D = hax.Axis("D", 5) 43 | C = hax.Axis("C", 6) 44 | 45 | x = hax.arange((H, W, D, C)) 46 | w = hax.arange((D, C)) 47 | c = hax.arange((C,)) 48 | 49 | y = hax.dot(x, c, axis=C) # shape is (H, W, D), equivalent to jnp.dot(x, c) 50 | 51 | y = hax.dot(x, w, axis=(D, C)) # shape is (H, W), equivalent to np.einsum("...dc,dc->...", x, w) 52 | y = hax.dot(x, w, axis=(D, C), out_axes=(W, H)) # shape is (W, H) instead of (H, W) 53 | y = hax.dot(x, w, c, axis=(D, C)) # shape is (H, W), equivalent to np.einsum("...dc,dc,c->...", x, w, c) 54 | y = hax.dot(x, c, axis=(H, D, C)) # shape is (W,), equivalent to np.einsum("hwdc,c->w", x, c) 55 | s = hax.dot(x, w, axis=None) # scalar output, equivalent to np.einsum("hwdc,dc->", x, w) 56 | y = hax.dot(x, w, c, axis=()) # shape is (H, W, D, C), equivalent to np.einsum("hwdc,dc,c->hwdc", x, w, c) 57 | y = hax.dot(x, w, c, axis=(), out_axes=(D, ..., H)) # shape is (D, W, C, H), equivalent to np.einsum("hwdc,dc,c->dwch", x, w, c) 58 | ``` 59 | 60 | ### `haliax.einsum` 61 | 62 | [haliax.einsum][] is at its best when you want to express a more complex tensor contraction. 63 | It is similar to [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) 64 | or [einops.einsum](https://einops.rocks/api/einsum/) in terms of syntax and behavior, 65 | but extended to work with named axes, including the added flexibility that named axes provide. 66 | Our "flavor" of `einsum` is most similar to `einops.einsum`'s flavor, in that 67 | it supports long names for axes (like `"batch h w, h w channel -> batch channel"`) 68 | rather than the compact notation of `numpy.einsum` (like `"bhwc,hwc->bc"`). 69 | 70 | Haliax's version of `einsum` comes in three modes: "ordered", "unordered", and "output axes". 71 | These modes are all accessible through the same function without any flags: the syntax 72 | of the `einsum` string determines which mode is used. 73 | 74 | The syntax for Haliax's `einsum` is similar to [`haliax.rearrange`](rearrange.md), which 75 | is in turn similar to [einops.rearrange](https://einops.rocks/api/rearrange/). 76 | 77 | #### Ordered Mode 78 | 79 | Haliax's `einsum` has an "ordered" mode that is similar to `einops.einsum`'s behavior. 80 | In this mode, the axes in the input arrays are matched to the axes in the `einsum` string in order. 81 | It supports ellipses in the same way as `einops.einsum`. The names in the einsum string need not 82 | match the names of the axes in the input arrays, but the order of the axes must match. 83 | 84 | ```python 85 | import haliax as hax 86 | 87 | H = hax.Axis("H", 3) 88 | W = hax.Axis("W", 4) 89 | D = hax.Axis("D", 5) 90 | 91 | x = hax.ones((H, W, D)) 92 | w = hax.ones((D,)) 93 | y = hax.einsum("h w d, d -> h w", x, w) # shape is (H, W), equivalent to jnp.einsum("hwd,d->hw", x, w) 94 | y = hax.einsum("... d, d -> ...", x, w) # same as above 95 | ``` 96 | 97 | The `...` syntax is used to indicate that the axes in the input arrays that are not mentioned in the `einsum` string 98 | should be preserved in the output. This should be the same as `einops.einsum`'s behavior, with the exception 99 | that the names of axes with the same label in the string must have the same names in the input arrays. 100 | 101 | (If you notice any differences between Haliax's `einsum`'s ordered syntax and `einops.einsum`, please let us know!) 
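To make the shared-label rule concrete, here is a small sketch (the axis names and sizes are illustrative): the label `d` appears in both operands, so the corresponding axes must both be named `"D"`, while `...` carries the unmentioned batch axis through to the output.

```python
import haliax as hax

Batch = hax.Axis("Batch", 2)
H = hax.Axis("H", 3)
D = hax.Axis("D", 5)
Out = hax.Axis("Out", 7)

x = hax.ones((Batch, H, D))
w = hax.ones((D, Out))

# label "d" lines up with the axis named "D" in both x and w;
# "..." preserves Batch, which is never mentioned in the string
y = hax.einsum("... h d, d o -> ... h o", x, w)  # result axes: (Batch, H, Out)
```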
102 | 103 | #### Unordered Mode 104 | 105 | In "unordered mode," the axes in the input arrays are matched to the axes in the `einsum` string by name, 106 | using similar rules to [haliax.rearrange][]. Names involved in the operation are specified inside `{}` 107 | on the left hand side of the `->` in the `einsum` string. Axes not specified are implicitly preserved. 108 | 109 | ```python 110 | import haliax as hax 111 | 112 | H = hax.Axis("H", 3) 113 | W = hax.Axis("W", 4) 114 | D = hax.Axis("D", 5) 115 | 116 | x = hax.ones((H, W, D)) 117 | w = hax.ones((D,)) 118 | 119 | y = hax.einsum("{H W D} -> H W", x) # shape is (H, W) 120 | y = hax.einsum("{D} -> ", w) # shape is (H, W) 121 | y = hax.einsum("{...} -> ", x) # shape is () 122 | y = hax.einsum("{H ...} -> H", x) # shape is (H,) 123 | y = hax.einsum("{H ...} -> ...", x) # shape is (W, D) 124 | ``` 125 | 126 | This mode is most similar to [haliax.dot][]'s behavior, though it's a bit more expressive. 127 | 128 | You can also use axis aliases in the `einsum` string, which can be useful for expressing contractions 129 | in library code or just for shortening the string: 130 | 131 | ```python 132 | Height = hax.Axis("Height", 3) 133 | Width = hax.Axis("Width", 4) 134 | Depth = hax.Axis("Depth", 5) 135 | 136 | x = hax.ones((Height, Width, Depth)) 137 | w = hax.ones((Depth,)) 138 | 139 | y = hax.einsum("{H W D} -> H W", x, H=Height, W=Width, D=Depth) # shape is (Height, Width) 140 | y = hax.einsum("{D} -> ", w, D=Depth) # shape is (Height, Width) 141 | ``` 142 | 143 | 144 | #### Output Axes Mode 145 | 146 | In "output axes" mode, you only specify the axes that should be in the output. All other 147 | axes are implicitly contracted over. This mode is a bit "dangerous" in that it's easy to 148 | accidentally contract over axes you didn't mean to, but it can be useful for expressing 149 | certain contractions concisely. 150 | 151 | ```python 152 | import haliax as hax 153 | 154 | H = hax.Axis("H", 3) 155 | W = hax.Axis("W", 4) 156 | D = hax.Axis("D", 5) 157 | 158 | x = hax.ones((H, W, D)) 159 | w = hax.ones((D,)) 160 | 161 | y = hax.einsum("-> H W", x) # shape is (H, W) 162 | y = hax.einsum("-> D", w) # shape is (D,) 163 | ``` 164 | 165 | We don't recommend using this mode except in cases when you're sure of the full shape of the input arrays 166 | or you are sure you don't want to let users implicitly batch over any axes. 167 | 168 | Output axes mode also supports axis aliases: 169 | 170 | ```python 171 | Height = hax.Axis("Height", 3) 172 | Width = hax.Axis("Width", 4) 173 | Depth = hax.Axis("Depth", 5) 174 | 175 | x = hax.ones((Height, Width, Depth)) 176 | w = hax.ones((Depth,)) 177 | y = hax.einsum("-> Height Width", x, Height=Height, Width=Width, Depth=Depth) # shape is (Height, Width) 178 | y = hax.einsum("-> Depth", w, Depth=Depth) # shape is (Depth,) 179 | ``` 180 | -------------------------------------------------------------------------------- /docs/nn.md: -------------------------------------------------------------------------------- 1 | # Neural Networks 2 | 3 | 4 | ## Modules 5 | 6 | Haliax provides a small number of neural network modules that are compatible with Equinox, though 7 | they naturally all use [haliax.NamedArray][]. (We welcome PRs for more modules! Nothing too exotic though.) 8 | 9 | The most interesting of these modules is [haliax.nn.Stacked][], which allows you to create homogenous "stacks" 10 | of the same module (e.g. transformer blocks), which is a common pattern in deep learning. 
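For a flavor of how these modules consume and produce [haliax.NamedArray][]s, here is a minimal sketch using [haliax.nn.Linear][] (the axis names and sizes are illustrative; see the README tutorial for a fuller example):

```python
import jax
import haliax as hax
import haliax.nn as hnn

Batch = hax.Axis("batch", 8)
Embed = hax.Axis("embed", 16)
Hidden = hax.Axis("hidden", 32)

linear = hnn.Linear.init(In=Embed, Out=Hidden, key=jax.random.PRNGKey(0))

x = hax.ones((Batch, Embed))
y = linear(x)  # the "embed" axis is replaced by "hidden": result axes are (batch, hidden)
```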
11 | 12 | ### Linear 13 | 14 | ::: haliax.nn.Embedding 15 | ::: haliax.nn.Linear 16 | 17 | ### Dropout 18 | ::: haliax.nn.Dropout 19 | 20 | ### Normalization 21 | 22 | ::: haliax.nn.normalization.LayerNormBase 23 | ::: haliax.nn.LayerNorm 24 | ::: haliax.nn.RmsNorm 25 | 26 | ### Meta 27 | 28 | ::: haliax.nn.MLP 29 | 30 | ### Stacked 31 | 32 | See the full documentation of [Stacked](scan.md#stacked). 33 | 34 | ### Convolution 35 | 36 | Unlike other frameworks, Haliax doesn't distinguish between 1D, 2D, and 3D, and general convolutions. Instead, we have 37 | a single [haliax.nn.Conv][] module that can be used for all of these, depending on the number of axes 38 | provided. Similarly, for transposed convolutions, we have [haliax.nn.ConvTranspose][]. 39 | 40 | ::: haliax.nn.Conv 41 | ::: haliax.nn.ConvTranspose 42 | 43 | ### Pooling 44 | 45 | As with convolutions, we don't distinguish between 1D, 2D, and 3D pooling, and instead have a single 46 | pooling operation for each of the kinds of reductions: 47 | 48 | ::: haliax.nn.max_pool 49 | ::: haliax.nn.mean_pool 50 | ::: haliax.nn.min_pool 51 | 52 | ## Attention 53 | 54 | We don't provide an explicit attention module, but we do provide an attention function and several related functions: 55 | 56 | :::haliax.nn.attention.dot_product_attention 57 | :::haliax.nn.attention.dot_product_attention_weights 58 | :::haliax.nn.attention.self_attention 59 | 60 | ### Masks 61 | ::: haliax.nn.attention.causal_mask 62 | ::: haliax.nn.attention.prefix_lm_mask 63 | ::: haliax.nn.attention.combine_masks_and 64 | ::: haliax.nn.attention.combine_masks_or 65 | ::: haliax.nn.attention.forgetful_causal_mask 66 | 67 | ### Biases 68 | 69 | ::: haliax.nn.attention.mask_to_bias 70 | ::: haliax.nn.attention.alibi_attention_bias 71 | 72 | ## Functions 73 | 74 | These functions wrap the equivalent in [jax.nn][]: 75 | 76 | ::: haliax.nn.relu 77 | ::: haliax.nn.relu6 78 | ::: haliax.nn.sigmoid 79 | ::: haliax.nn.softplus 80 | ::: haliax.nn.soft_sign 81 | ::: haliax.nn.silu 82 | ::: haliax.nn.swish 83 | ::: haliax.nn.log_sigmoid 84 | ::: haliax.nn.leaky_relu 85 | ::: haliax.nn.hard_sigmoid 86 | ::: haliax.nn.hard_silu 87 | ::: haliax.nn.hard_swish 88 | ::: haliax.nn.hard_tanh 89 | ::: haliax.nn.elu 90 | ::: haliax.nn.celu 91 | ::: haliax.nn.selu 92 | ::: haliax.nn.gelu 93 | ::: haliax.nn.quick_gelu 94 | ::: haliax.nn.glu 95 | ::: haliax.nn.logsumexp 96 | ::: haliax.nn.log_softmax 97 | ::: haliax.nn.softmax 98 | ::: haliax.nn.standardize 99 | ::: haliax.nn.one_hot 100 | 101 | ### Loss Functions 102 | 103 | ::: haliax.nn.cross_entropy_loss 104 | ::: haliax.nn.cross_entropy_loss_and_log_normalizers 105 | -------------------------------------------------------------------------------- /docs/partitioning.md: -------------------------------------------------------------------------------- 1 | # Partitioning 2 | 3 | Partitioning refers to the process of splitting arrays and computation across multiple devices. Haliax provides a number 4 | of functions for partitioning arrays and computation across multiple devices. 5 | 6 | 7 | ## Tutorial 8 | An introduction to using Haliax's partitioning functions to scale a transformer can be found here: [Distributed Training in Haliax](https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). 9 | 10 | This page is designed to be more of a reference than a tutorial, and we assume you've read the tutorial before reading this page. 
11 | 12 | 13 | ## Device Meshes in JAX 14 | 15 | See also JAX's tutorial [Distributed Arrays and Automatic Parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) 16 | for more details. 17 | 18 | One of the main ways JAX provides distributed parallelism is via the [jax.sharding.Mesh][]. 19 | A mesh is a logical n-dimensional array of devices. Meshes in JAX are represented as a Numpy ndarray (note: not `jax.numpy`) 20 | of devices and a tuple of axis names. For example, a 2D mesh of 16 devices might look like this: 21 | 22 | ```python 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | from jax.sharding import Mesh 27 | 28 | devices = jax.devices() 29 | mesh = Mesh(jnp.array(devices).reshape((-1, 2)), ("data", "model")) 30 | ``` 31 | 32 | ![2d Device Mesh showing 16 devices](figures/device_mesh_2d.png) 33 | 34 | The mesh above has two axes, `data` and `model`. In JAX's mesh parallelism, arrays are distributed by overlaying axes of 35 | the array on top of the axes of the mesh. For example, if we have a batch of 32 sequences we might do something like this: 36 | 37 | ```python 38 | from jax.sharding import NamedSharding, PartitionSpec 39 | 40 | batch_size = 32 41 | seqlen = 512 42 | 43 | batch = jnp.zeros((batch_size, seqlen), dtype=jnp.float32) 44 | batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data", None))) 45 | ``` 46 | 47 | This specifies that the first axis of `batch` should be distributed across the `data` axis of the mesh. The `None` in the 48 | `PartitionSpec` indicates that the second axis of `batch` is not distributed, which means that the data is replicated 49 | so that one copy of the data is partitioned across each row of the mesh. 50 | 51 | ![Device Mesh showing 16 devices with data partitioned across data axis](figures/device_mesh_2d_batch_partitioned.png) 52 | 53 | What's nice about this approach is that jax will automatically schedule computations so that operations are distributed 54 | in the way you would expect: you don't have to explicitly manage communication between devices. 55 | 56 | However, JAX sometimes gets confused, and it's not sure how you want your arrays partitioned. In Jax, there's a function 57 | called [jax.lax.with_sharding_constraint][] that lets you explicitly specify the sharding for the outputs of arrays. 58 | You use this function only inside `jit`. 59 | 60 | ## Haliax Partitioning in a nutshell 61 | 62 | As you might imagine, it gets tedious and error-prone to have to specify the partitioning of every array you create. Haliax provides 63 | routines to handle mapping of [haliax.NamedArray][]s automatically. 64 | 65 | ```python 66 | import haliax as hax 67 | 68 | Batch = hax.Axis("batch", 32) 69 | SeqLen = hax.Axis("seqlen", 512) 70 | 71 | axis_mapping = {"batch": "data", } 72 | 73 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32) 74 | batch = hax.shard(batch, axis_mapping) 75 | 76 | # we also have "auto_sharded" and support context mappings for axis mappings: 77 | with hax.axis_mapping({"batch": "data"}): 78 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32) 79 | batch = hax.shard(batch) 80 | ``` 81 | 82 | Unlike in JAX, which has separate APIs for partitioning arrays inside and outside of `jit`, Haliax has a single API: 83 | `hax.shard` work inside and outside of `jit`. Haliax automatically 84 | chooses which JAX function to use based on context. 
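For instance, here is a sketch of calling `hax.shard` inside a jitted function (the function itself is hypothetical, and it assumes a device mesh with a `data` axis is active, as in the mesh example above):

```python
import jax
import haliax as hax

@jax.jit
def scale(batch):
    # inside jit, this acts as a sharding constraint on the intermediate value
    batch = hax.shard(batch, {"batch": "data"})
    return batch * 2

scaled = scale(batch)  # `batch` as constructed in the snippet above
```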
85 | 86 | 87 | ## Axis Mappings 88 | 89 | The core data structure we use to represent partitioning is the [haliax.partitioning.ResourceMapping][] which 90 | is just an alias for a `Dict[str, str|Sequence[str]]`. The keys in this dictionary are the names of "logical" Axes in NamedArrays 91 | and the values are the names of axes in the mesh. (In theory you can partition a single Axis across multiple axes in the mesh, 92 | but we don't use this functionality.) 93 | 94 | ::: haliax.partitioning.ResourceMapping 95 | 96 | A context manager can be used to specify an axis mapping for the current thread for the duration of the context: 97 | 98 | ```python 99 | with hax.axis_mapping({"batch": "data"}): 100 | batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32) 101 | batch = hax.auto_sharded(batch) 102 | ``` 103 | 104 | ::: haliax.partitioning.axis_mapping 105 | 106 | ## Partitioning Functions 107 | 108 | ### Sharding Arrays and PyTrees 109 | 110 | These functions are used to shard arrays and PyTrees of arrays, e.g. Modules. 111 | This is the main function you will use to shard arrays: 112 | 113 | ::: haliax.shard 114 | 115 | This function is like `shard` but does not issue a warning if there is no context axis mapping. 116 | It's useful for library code where there may or may not be a context mapping: 117 | 118 | ::: haliax.auto_sharded 119 | 120 | This is an older function that is being deprecated in favor of `shard`. It is functionally equivalent to `shard`: 121 | 122 | ::: haliax.shard_with_axis_mapping 123 | 124 | ### `named_jit` and friends 125 | 126 | ::: haliax.named_jit 127 | ::: haliax.fsdp 128 | 129 | 130 | ### Querying the Mesh and Axis Mapping 131 | 132 | 133 | ::: haliax.partitioning.round_axis_for_partitioning 134 | ::: haliax.partitioning.physical_axis_name 135 | ::: haliax.partitioning.physical_axis_size 136 | ::: haliax.partitioning.sharding_for_axis 137 | -------------------------------------------------------------------------------- /docs/rearrange.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "initial_id", 7 | "metadata": { 8 | "collapsed": true, 9 | "ExecuteTime": { 10 | "end_time": "2023-10-24T19:27:47.968939Z", 11 | "start_time": "2023-10-24T19:27:47.349182Z" 12 | } 13 | }, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", 20 | "I0000 00:00:1698175667.814623 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import haliax as hax\n", 26 | "from jax.random import PRNGKey\n", 27 | "\n", 28 | "N = hax.Axis(\"N\", 8)\n", 29 | "C = hax.Axis(\"C\", 3)\n", 30 | "H = hax.Axis(\"H\", 64)\n", 31 | "W = hax.Axis(\"W\", 64)\n", 32 | "\n", 33 | "x = hax.random.uniform(PRNGKey(0), (N, C, H, W))\n", 34 | "\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 9, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": "{'N': 8, 'C': 3, 'P': 64, 'H': 8, 'W': 8}" 44 | }, 45 | "execution_count": 9, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "# split into patches\n", 52 | "hax.rearrange(x, \"N C (ph H) (pw W) -> N C (P: ph pw) H W\", ph=8, pw=8)\n", 53 | "# order agnostic\n", 54 | "hax.rearrange(x, \"{(H: ph H) (W: pw W)} -> ... 
(P: ph pw) H W\", ph=8, pw=8)" 55 | ], 56 | "metadata": { 57 | "collapsed": false, 58 | "ExecuteTime": { 59 | "end_time": "2023-10-24T19:30:14.757668Z", 60 | "start_time": "2023-10-24T19:30:14.753994Z" 61 | } 62 | }, 63 | "id": "f0178b7a92a41783" 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "Python 3", 69 | "language": "python", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 2 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython2", 82 | "version": "2.7.6" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 5 87 | } 88 | -------------------------------------------------------------------------------- /docs/rearrange.md: -------------------------------------------------------------------------------- 1 | # Rearrange 2 | 3 | ## Introduction 4 | 5 | Haliax strives to be "order independent": the order of axes should not impact the correctness of the program. However, 6 | when interfacing with non-named APIs (e.g. the JAX Numpy API or Equinox), you have to be able to know exactly what the 7 | order of axes is. In addition, the order of axes can impact performance. To that end, Haliax provides a `rearrange` 8 | function that allows you to specify the order of axes in a tensor. 9 | 10 | In addition, it is sometimes necessary to split and merge axes: turning images into patches, 11 | or turning a batch of images into a single image. Without `rearrange`, this is a bit clunky. 12 | 13 | `rearrange` comes in two flavors: sequence syntax and einops-style syntax. Sequence 14 | syntax is just for transposing axes, while [einops-style](https://einops.rocks/) syntax is more 15 | powerful and can also split and merge axes. 16 | 17 | ## Sequence Syntax 18 | 19 | The sequence syntax is very simple: you just provide a sequence of axis names, and the tensor 20 | will be rearranged to match that sequence. For example: 21 | 22 | ```python 23 | import haliax as hax 24 | import jax.random as jrandom 25 | 26 | N = hax.Axis("N", 32) 27 | C = hax.Axis("C", 3) 28 | H = hax.Axis("H", 64) 29 | W = hax.Axis("W", 64) 30 | 31 | x = hax.random.normal(jrandom.PRNGKey(0), (N, C, H, W)) 32 | 33 | y = hax.rearrange(x, (N, H, W, C)) 34 | 35 | # at most one ellipsis is allowed 36 | z = hax.rearrange(x, (N, ..., C)) 37 | 38 | # you can use strings instead of axis objects 39 | z = hax.rearrange(x, ("N", ..., "C")) 40 | ``` 41 | 42 | As we said before, almost all Haliax operations are order-agnostic, so (this version of) `rearrange` only impacts 43 | performance and allows you to interface with other libraries that need you to specify the order of axes 44 | for an unnamed tensor. 45 | 46 | ## Einops-style Syntax 47 | 48 | [einops](https://einops.rocks/) is a powerful library for manipulating tensor shapes, generalizing 49 | `reshape`, `transpose`, and other shape-manipulation operations. Haliax provides a subset of its functionality 50 | (specifically `rearrange` and not `repeat` or `reduce`, which are less useful in named code). The syntax has been generalized to named 51 | tensors. 52 | 53 | If you're used to einops, the syntax should be familiar, with the main differences being specifying names 54 | and the additional "unordered" syntax for selecting dimensions by name. 55 | 56 | !!! warning 57 | 58 | This syntax is fairly new. It is pretty well-tested, but it is possible that there are bugs. 
59 | 60 | ### Examples 61 | 62 | Examples are probably the best way to get a feel for the syntax: 63 | 64 | ```python 65 | import haliax as hax 66 | import jax.random as jrandom 67 | 68 | N = hax.Axis("N", 32) 69 | C = hax.Axis("C", 3) 70 | H = hax.Axis("H", 64) 71 | W = hax.Axis("W", 64) 72 | 73 | x = hax.random.normal(jrandom.PRNGKey(0), (N, C, H, W)) 74 | 75 | # transpose/permute axes 76 | y = hax.rearrange(x, "N C H W -> N H W C") 77 | # names don't have to match with positional syntax 78 | z = hax.rearrange(x, "num ch h w -> num h w ch") 79 | # ellipsis can be used to specify the rest of the dimensions 80 | z = hax.rearrange(x, "N C ... -> N ... C") 81 | 82 | # unordered patterns allow you to match a subset of dimensions by name, rather than using positional matching 83 | # transpose last two dimensions using the unordered syntax 84 | y = hax.rearrange(x, "{H W} -> ... W H") 85 | 86 | # don't know the order? use an unordered pattern 87 | y = hax.rearrange(x, "{W C H N} -> N H W C") 88 | 89 | # split dims as in einops 90 | y = hax.rearrange(x, "(step microbatch) ... -> step microbatch ...", step=4) 91 | # splitting dims can be done using unordered syntax, similar to positional syntax 92 | y = hax.rearrange(x, "{(N: step microbatch) ...} -> step microbatch ...", step=4) 93 | 94 | # merging dims requires a name 95 | x = hax.rearrange(y, "step microbatch ... -> (N: step microbatch) ...") 96 | 97 | # you can partially specify the order by using two or more ellipses on the rhs 98 | y = hax.rearrange(x, "{H W} -> ... (F: H W) ...") 99 | y = hax.rearrange(x, "{H W C} -> ... (F: H W) ... C") # ensures C is the last dimension 100 | 101 | 102 | # some fancier examples 103 | 104 | # split into patches 105 | y = hax.rearrange(x, "N C (patch_h H) (patch_w W) -> N (P: patch_h patch_w) C H W", H=4, W=4) 106 | # unordered version 107 | y = hax.rearrange(x, "{(H: patch_h H) (W: patch_w W) C } -> ... (P: patch_h patch_w) C H W", H=4, W=4) 108 | 109 | # split into patches, then merge patches and channels 110 | y = hax.rearrange(x, "N C (patch_h H) (patch_w W) -> N (P: patch_h patch_w) (C: C H W)", H=4, W=4) 111 | # unordered version 112 | y = hax.rearrange(x, "{(H: patch_h H) (W: patch_w W) C } -> ... (P: patch_h patch_w) (C: C H W)", H=4, W=4) 113 | ``` 114 | 115 | ### Bindings: Aliasing and Sizing 116 | 117 | In the above examples we used keyword arguments to give sizes to split axes, which is the same 118 | as in einops. However, we can also use bindings to alias axes. For example: 119 | 120 | ```python 121 | # this produces the same result as the previous example 122 | y2 = hax.rearrange(x, "N C (patch_h foo) (patch_w bar) -> N (P: patch_h patch_w) (C: C foo bar)", foo=hax.Axis("H", 4), bar=hax.Axis("W", 4)) 123 | assert y.axes == y2.axes 124 | ``` 125 | 126 | This example is a bit contrived, but the point is that this syntax lets us use shorter or different names in the string, 127 | which is occasionally useful. 128 | 129 | You can actually pass in a string alias instead of an axis object, and it will be converted to an axis object for you: 130 | For instance, if we wanted "P" to actually be called "patch", but wanted to keep the short syntax, we could do: 131 | 132 | ```python 133 | y3 = hax.rearrange(x, "N C (nh ph) (nw pw) -> N (P: nh nw) (C: C ph pw)", P="patch", pw=4, ph=4) 134 | ``` 135 | 136 | 137 | ### Differences from einops 138 | 139 | As you may have noticed, there are some differences from einops: 140 | 141 | * Merged axes must have a name: `(C: C H W)` instead of `(C H W)`. 
142 | * The unordered syntax with `{ }` is new: you select dimensions by name instead of by position. 143 | * As discussed immediately above, you can use bindings to specify axis objects for names as well as sizes. 144 | 145 | ### Syntax 146 | 147 | Semiformally, the syntax is an `lhs -> rhs` pair, where the `lhs` is either ordered or unordered, and the `rhs` is always ordered. 148 | For the `lhs`: 149 | 150 | * An *ordered lhs* is a sequence of axis variables (e.g. `x`) or (named or anonymous) split axes (e.g. `(x y)`), separated by spaces or commas, and up to one ellipsis 151 | * An *unordered lhs* is a sequence of axis names (e.g. `x`, where `x` is an axis name in the input array) or named split axes (e.g. `(x: y z)`), surrounded by `{}`, separated by spaces or commas. 152 | 153 | * An *axis variable* is an identifier (that need not correspond to an axis name in the input or output.) 154 | * An *axis name* is an identifier that corresponds to an axis name in the input or output. 155 | * An *anonymous split axis* is a parenthesized expression of the form `(ident*)`, e.g. `(N C)`. 156 | * A *named split axis* is a parenthesized expression of the form `(name: ident*)`, where `name` is the name of an axis and the `ident` are axis variable names, e.g. `(N: s mb)` 157 | 158 | A note on "axis variable" vs "axis name": the former is an identifier that can refer to any axis and is matched 159 | by **position** in the input, while the latter is an identifier that refers to a specific axis and is matched by **name** in the input 160 | (or used to name an axis in the output). 161 | 162 | The `rhs` is always ordered. Its syntax is similar to an ordered `lhs`, except that merged axes must always be named and there may be more than one ellipsis. 163 | 164 | * *rhs* is a sequence of axis variable names or named merged axes, separated by spaces or commas, and some number of ellipses. 165 | 166 | * *Named merged axes* are parenthesized expressions of the form `(name: ident ident ...)`, where `name` is an axis name and `ident` is an identifier. 167 | The name is used to name the merged axis in the output, and the `ident` are axis variable names that are merged from the input. 168 | 169 | Identifiers in the `rhs` must be "bound" by an identifier in the `lhs`, that is, they must appear in the `lhs` as an *axis variable name*. 170 | 171 | As in einops: split and merged axes are processed in "C-order": the first dimension is the most significant, and the 172 | last dimension is the least significant. 173 | 174 | 175 | ## Error Handling 176 | 177 | `rearrange` attempts to be as helpful as possible when it encounters errors. For example: 178 | 179 | ```python 180 | y = hax.rearrange(x, "N C H W Z -> N H W C") 181 | # ValueError: Error while parsing: 182 | # N C H W Z -> N H W C 183 | # ^ 184 | # Too many axes in lhs 185 | ``` 186 | 187 | In general, it will try to give you a helpful error message that points to the problem in the string. 188 | 189 | 190 | ## API Documentation 191 | 192 | See [haliax.rearrange][]. 
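As a concrete illustration of the C-order rule described in the Syntax section, here is a small sketch (not part of the original examples; the axis names `X`, `outer`, and `inner` are invented for illustration) showing that the first factor of a split axis is the most significant:

```python
import jax.numpy as jnp
import haliax as hax

X = hax.Axis("X", 6)
x = hax.named(jnp.arange(6), X)

# split X into (outer, inner) in C-order: outer is the most significant factor
y = hax.rearrange(x, "(X: outer inner) -> outer inner", outer=2)
# y.axes is (Axis("outer", 2), Axis("inner", 3))
# y.array is [[0, 1, 2], [3, 4, 5]]
```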
193 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.6.1 2 | mkdocs-autorefs==1.4.1 3 | mkdocs-bootswatch==1.1 4 | mkdocs-gen-files==0.5.0 5 | mkdocs-get-deps==0.2.0 6 | mkdocs-include-markdown-plugin==7.1.5 7 | mkdocs-literate-nav==0.6.1 8 | mkdocs-macros-plugin==1.3.7 9 | mkdocs-material==9.6.7 10 | mkdocs-material-extensions==1.3.1 11 | mkdocstrings==0.29.0 12 | mkdocstrings-python==1.16.4 13 | -------------------------------------------------------------------------------- /docs/tutorial.md: -------------------------------------------------------------------------------- 1 | Haliax's tutorials are hosted on Google Colab. You can run the tutorials in your browser without installing anything. 2 | 3 | Here are our current tutorials: 4 | 5 | {% 6 | include-markdown "../README.md" 7 | start="" 8 | end="" 9 | %} 10 | -------------------------------------------------------------------------------- /docs/vmap.md: -------------------------------------------------------------------------------- 1 | ## Vectorization 2 | 3 | (This is a work in progress. Please contact dlwh for more information.) 4 | 5 | ::: haliax.vmap 6 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Haliax 2 | repo_url: https://github.com/stanford-crfm/haliax/ 3 | edit_uri: blob/main/docs/ 4 | theme: 5 | name: material 6 | highlightjs: false 7 | features: 8 | - content.code.copy 9 | markdown_extensions: 10 | - attr_list 11 | - admonition 12 | #- callouts 13 | - footnotes 14 | - codehilite 15 | - pymdownx.details # Allowing hidden expandable regions denoted by ??? 16 | - pymdownx.magiclink 17 | - pymdownx.superfences 18 | - pymdownx.arithmatex: # Render LaTeX via MathJax 19 | generic: true 20 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. 
21 | - pymdownx.snippets: # Include one Markdown file into another 22 | base_path: docs 23 | - pymdownx.inlinehilite 24 | - pymdownx.snippets: 25 | check_paths: true 26 | - pymdownx.superfences 27 | - toc: 28 | permalink: "¤" 29 | toc_depth: "2-3" 30 | 31 | plugins: 32 | - search 33 | - autorefs 34 | - mkdocstrings: 35 | handlers: 36 | python: 37 | # setup_commands: 38 | # - import pytkdocs_tweaks 39 | # - pytkdocs_tweaks.main() 40 | paths: [src] 41 | import: 42 | - https://docs.python.org/3/objects.inv 43 | - https://jax.readthedocs.io/en/latest/objects.inv 44 | - https://docs.kidger.site/equinox/objects.inv 45 | - https://einops.rocks/objects.inv 46 | options: 47 | docstring_options: 48 | ignore_init_summary: true 49 | show_source: false 50 | filters: 51 | - "!^_" 52 | heading_level: 4 53 | inherited_members: true 54 | members_order: source 55 | merge_init_into_class: true 56 | parameter_headings: true 57 | separate_signature: false 58 | load_external_modules: true 59 | preload_modules: [haliax, haliax.core] 60 | show_if_no_docstring: true 61 | show_root_heading: true 62 | show_root_full_path: false 63 | show_signature_annotations: true 64 | docstring_section_style: list 65 | show_symbol_type_heading: true 66 | show_symbol_type_toc: false 67 | signature_crossrefs: true 68 | line_length: 100 69 | summary: true 70 | 71 | - include-markdown 72 | extra_css: 73 | - css/material.css 74 | - css/mkdocstrings.css 75 | 76 | 77 | watch: 78 | - src 79 | - docs 80 | nav: 81 | - Home: 'index.md' 82 | - Tutorials: 83 | - "Introduction to Haliax": https://colab.research.google.com/drive/1TiTcQQ4V5mopbgCu1SVl-oqJtXn7rFnC 84 | - "Distributed Training and FSDP": https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz 85 | - "Tensor Parallelism": https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi 86 | - "Mixed Precision with `jmp`": https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing 87 | - Cheatsheet: 'cheatsheet.md' 88 | - Named Arrays: 89 | - Broadcasting: 'broadcasting.md' 90 | - Indexing and Slicing: 'indexing.md' 91 | - Rearrange: 'rearrange.md' 92 | - Matrix Multiplication: 'matmul.md' 93 | - Higher Order Functions: 94 | - Scan and Fold: 'scan.md' 95 | - Vectorization: 'vmap.md' 96 | - Neural Networks: 'nn.md' 97 | - Partitioning: 'partitioning.md' 98 | - FP8: 'fp8.md' 99 | - Serialization: 'state-dict.md' 100 | - API Reference: 'api.md' 101 | - FAQ: 'faq.md' 102 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "haliax" 7 | # Set for release builds 8 | #version = "1.3" 9 | authors = [ 10 | { name="David Hall", email="dlwh@cs.stanford.edu" }, 11 | ] 12 | description = "Named Tensors for Legible Deep Learning in JAX" 13 | readme = "README.md" 14 | requires-python = ">=3.10" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Operating System :: POSIX :: Linux", 19 | "Operating System :: MacOS :: MacOS X", 20 | "Development Status :: 4 - Beta", 21 | "Intended Audience :: Science/Research", 22 | ] 23 | dependencies = [ 24 | # we require that you install jax yourself, since the extras vary by system. 
25 | # jax = {version = ">=0.4.19,<0.5.0"} 26 | "equinox>=0.10.6", 27 | "jaxtyping>=0.2.20", 28 | "jmp>=0.0.4", 29 | "safetensors>=0.4.3", 30 | "aqtp>=0.8.2", 31 | ] 32 | dynamic =[ "version" ] 33 | 34 | [project.optional-dependencies] 35 | dev=["pytest >= 7.4.0", "mypy >= 0.910", "mkdocs >= 1.4.3", "mkdocs-material >= 7.3.3", "mkdocstrings >= 0.22.0", 36 | "mkdocs-literate-nav >= 0.6.0", "mkdocs-macros-plugin >= 0.7.0", "mkdocstrings-python >= 1.1.2", 37 | "mkdocs-include-markdown-plugin", 38 | "pymdown-extensions", 39 | "pygments", 40 | "pymdown-extensions", 41 | "chex>=0.1.86" 42 | ] 43 | 44 | 45 | [tool.hatch.version] 46 | path = "src/haliax/__about__.py" 47 | 48 | [options] 49 | packages = ["haliax", "haliax.*"] 50 | 51 | [options.package_data] 52 | haliax = ["src/haliax/*"] 53 | 54 | [tool.black] 55 | line-length = 119 56 | target-version = ["py310"] 57 | preview = true 58 | 59 | [tool.isort] 60 | profile = "black" 61 | multi_line_output = 3 62 | lines_after_imports = 2 63 | include_trailing_comma = true 64 | force_grid_wrap = 0 65 | use_parentheses = true 66 | ensure_newline_before_comments = true 67 | line_length = 119 68 | src_paths = ["src", "tests"] 69 | 70 | [project.urls] 71 | "Homepage" = "https://github.com/stanford-crfm/haliax" 72 | "Bug Tracker" = "https://github.com/stanford-crfm/haliax/issues/" 73 | "Documentation" = "https://haliax.readthedocs.io/en/latest/" 74 | -------------------------------------------------------------------------------- /src/haliax/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.4" 2 | -------------------------------------------------------------------------------- /src/haliax/_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/haliax/af29b9e5ed08056250300df7bc1e908e9a1ba8de/src/haliax/_src/__init__.py -------------------------------------------------------------------------------- /src/haliax/_src/compile_utils.py: -------------------------------------------------------------------------------- 1 | # This whole file is copied from Equinox. 2 | # (c) 2023, Google LLC. and/or Patrick Kidger. Apache 2.0 licensed. 3 | # Patrick doesn't like that I depend on Equinox internals, so I copied this stuff 4 | import functools as ft 5 | import types 6 | import warnings 7 | import weakref 8 | from typing import Any 9 | 10 | import jax.tree_util as jtu 11 | 12 | 13 | def _strip_wrapped_partial(fun): 14 | if hasattr(fun, "__wrapped__"): # ft.wraps 15 | return _strip_wrapped_partial(fun.__wrapped__) 16 | if isinstance(fun, ft.partial): 17 | return _strip_wrapped_partial(fun.func) 18 | return fun 19 | 20 | 21 | internal_caches = [] # type: ignore 22 | internal_lru_caches = [] # type: ignore 23 | 24 | 25 | def clear_caches(): 26 | """Clears internal Equinox caches. 27 | 28 | Best used before calling `jax.clear_caches()` or `jax.clear_backends()`. 29 | 30 | **Arguments:** 31 | 32 | None. 33 | 34 | **Returns:** 35 | 36 | None. 
37 | """ 38 | for cache in internal_caches: 39 | cache.clear() 40 | for cache in internal_lru_caches: 41 | cache.cache_clear() 42 | 43 | 44 | def _default_cache_key(): 45 | assert False 46 | 47 | 48 | def compile_cache(cacheable_fn): 49 | cache = weakref.WeakKeyDictionary() # type: ignore 50 | internal_caches.append(cache) 51 | 52 | def cached_fn_impl(leaves, treedef): 53 | user_fn_names, args, kwargs = jtu.tree_unflatten(treedef, leaves) 54 | return cacheable_fn(user_fn_names, *args, **kwargs) 55 | 56 | @ft.wraps(cacheable_fn) 57 | def wrapped_cacheable_fn(user_fn, *args, **kwargs): 58 | user_fn = _strip_wrapped_partial(user_fn) 59 | # Best-effort attempt to clear the cache of old and unused entries. 60 | cache_key: Any 61 | if type(user_fn) is types.FunctionType: # noqa: E721 62 | cache_key = user_fn 63 | else: 64 | cache_key = _default_cache_key 65 | 66 | try: 67 | user_fn_names = user_fn.__name__, user_fn.__qualname__ 68 | except AttributeError: 69 | user_fn_names = type(user_fn).__name__, type(user_fn).__qualname__ 70 | leaves, treedef = jtu.tree_flatten((user_fn_names, args, kwargs)) 71 | leaves = tuple(leaves) 72 | 73 | try: 74 | cached_fn = cache[cache_key] 75 | except KeyError: 76 | cached_fn = cache[cache_key] = ft.lru_cache(maxsize=None)(cached_fn_impl) 77 | return cached_fn(leaves, treedef) 78 | 79 | def delete(user_fn): 80 | user_fn = _strip_wrapped_partial(user_fn) 81 | if type(user_fn) is types.FunctionType: # noqa: E721 82 | try: 83 | del cache[user_fn] 84 | except KeyError: 85 | warnings.warn(f"Could not delete cache for function {user_fn}. Has it already been deleted?") 86 | else: 87 | warnings.warn("Could not delete non-function from cache.") 88 | 89 | wrapped_cacheable_fn.delete = delete # type: ignore 90 | return wrapped_cacheable_fn 91 | -------------------------------------------------------------------------------- /src/haliax/_src/dot.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import typing 3 | import warnings 4 | from typing import Dict, Optional, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | import haliax 10 | from haliax.axis import ( 11 | Axis, 12 | AxisSelection, 13 | PartialAxisSpec, 14 | axis_name, 15 | eliminate_axes, 16 | rearrange_for_partial_order, 17 | union_axes, 18 | ) 19 | from haliax.core import NamedArray 20 | from haliax.jax_utils import _jittable_dg_einsum 21 | from haliax.types import DTypeLike, PrecisionLike 22 | from haliax.util import ensure_tuple 23 | 24 | 25 | # deprecated overload 26 | @typing.overload 27 | def dot( 28 | axis: Optional[AxisSelection], 29 | *arrays: NamedArray, 30 | precision: PrecisionLike = None, 31 | preferred_element_type: Optional[DTypeLike] = None, 32 | out_axes: Optional[PartialAxisSpec] = ..., 33 | dot_general=jax.lax.dot_general, 34 | ) -> NamedArray: 35 | ... 36 | 37 | 38 | @typing.overload 39 | def dot( 40 | *arrays: NamedArray, 41 | axis: Optional[AxisSelection], 42 | precision: PrecisionLike = None, 43 | preferred_element_type: Optional[DTypeLike] = None, 44 | out_axes: Optional[PartialAxisSpec] = ..., 45 | dot_general=jax.lax.dot_general, 46 | ) -> NamedArray: 47 | ... 48 | 49 | 50 | def dot( 51 | *arrays, 52 | precision: PrecisionLike = None, 53 | preferred_element_type: Optional[DTypeLike] = None, 54 | out_axes: Optional[PartialAxisSpec] = None, 55 | dot_general=jax.lax.dot_general, 56 | **kwargs, 57 | ) -> NamedArray: 58 | """Returns the tensor product of two NamedArrays. 
The axes `axis` are contracted over, 59 | and any other axes that are shared between the arrays are batched over. Non-contracted Axes in one 60 | that are not in the other are preserved. 61 | 62 | Note that if `axis` is None, the result will be a scalar, not a NamedArray. The semantics of `axis=None` are 63 | similar to e.g. how `sum` and other reduction functions work in numpy. If `axis=()`, then the result will be 64 | an "outer product" of the arrays, i.e. a tensor with shape equal to the concatenation of the shapes of the arrays. 65 | 66 | By default, the order of output axes is determined by the order of the input axes, such that each output axis 67 | occurs in the same order as it first occurs in the concatenation of the input axes. 68 | 69 | If `out_axes` is provided, the output will be transposed to match the provided axes. `out_axes` may be a partial 70 | specification of the output axes (using ellipses), in which case the output will be rearranged to be consistent 71 | with the partial specification. For example, if `out_axes=(..., Height, Width)` and the output axes are 72 | `(Width, Height, Depth)`, the output will be transposed to `(Depth, Height, Width)`. Multiple ellipses 73 | are supported, in which case axes will be inserted according to a greedy heuristic that prefers to place 74 | unconstrained axes as soon as all prior axes in the "natural" order are covered. 75 | 76 | Args: 77 | *arrays (NamedArray): The arrays to contract. 78 | axis (AxisSelection): The axes to contract over. 79 | precision (PrecisionLike, optional): The precision to use. Defaults to None. This argument is passed to `jax.numpy.einsum`, 80 | which in turn passes it to jax.lax.dot_general. 81 | preferred_element_type (DTypeLike, optional): The preferred element type of the result. Defaults to None. 82 | This argument is passed to `jax.numpy.einsum`. 83 | out_axes (Optional[PartialAxisSpec], optional): a potentially partial specification of the output axes. 84 | If provided, the output will be transposed to match the provided axes. Defaults to None. 85 | 86 | 87 | Returns: 88 | NamedArray: The result of the contraction. 89 | """ 90 | if len(arrays) == 0: 91 | raise ValueError("Must provide at least one array to dot") 92 | 93 | if "axis" in kwargs: 94 | axis = kwargs["axis"] 95 | else: 96 | axis = arrays[0] 97 | arrays = arrays[1:] 98 | if isinstance(axis, NamedArray): 99 | raise ValueError("Must provide an axis to dot") 100 | 101 | warnings.warn("Axis has been changed to a keyword argument. Please update your code.", DeprecationWarning) 102 | 103 | _ensure_no_mismatched_axes(*arrays) 104 | 105 | # to call dot_general we need two things: 106 | # list of contractions and list of arrays 107 | 108 | all_axes: Tuple[Axis, ...] = ft.reduce(union_axes, (a.axes for a in arrays), ()) # type: ignore 109 | output_axes: Tuple[Axis, ...] 
110 | if axis is None: 111 | # we want to contract over all the axes 112 | output_axes = () 113 | else: 114 | output_axes = eliminate_axes(all_axes, axis) 115 | 116 | if out_axes is not None: 117 | output_axes = rearrange_for_partial_order(out_axes, output_axes) 118 | 119 | array_specs = [] 120 | 121 | next_index = 0 122 | axis_mappings: Dict[str, int] = {} 123 | 124 | for a in arrays: 125 | spec = "" 126 | for ax in a.axes: 127 | if ax.name in axis_mappings: 128 | spec += f"{axis_mappings[ax.name]} " 129 | else: 130 | axis_mappings[ax.name] = next_index 131 | spec += f"{next_index} " 132 | next_index += 1 133 | 134 | array_specs.append(spec) 135 | 136 | # now compute the output axes: 137 | output_spec = " ".join(str(axis_mappings[ax.name]) for ax in output_axes) 138 | 139 | # get a name for jax so it's easier to interpret logs 140 | if axis is None: 141 | jax_str = f"contract {', '.join(axis_name(ax) for ax in all_axes)} -> " 142 | else: 143 | axis = ensure_tuple(axis) 144 | jax_str = f"contract {', '.join(axis_name(ax) for ax in axis)} -> {', '.join(a.name for a in output_axes)}" 145 | 146 | with jax.named_scope(jax_str): 147 | output = _jittable_dg_einsum( 148 | ", ".join(array_specs) + "-> " + output_spec, 149 | *[a.array for a in arrays], 150 | precision=precision, 151 | preferred_element_type=preferred_element_type, 152 | _dot_general=dot_general, 153 | ) 154 | 155 | out = NamedArray(output, output_axes) 156 | return haliax.auto_sharded(out) 157 | 158 | 159 | def _ensure_no_mismatched_axes(*arrays: NamedArray): 160 | """Ensure that all the arrays have no axes with the same name but different sizes""" 161 | if len(arrays) <= 1: 162 | return 163 | 164 | known_sizes: dict[str, int] = {} 165 | for a in arrays: 166 | for ax in a.axes: 167 | if ax.name in known_sizes: 168 | if known_sizes[ax.name] != ax.size: 169 | raise ValueError(f"Axis {ax.name} has multiple sizes: {known_sizes[ax.name]} and {ax.size}") 170 | else: 171 | known_sizes[ax.name] = ax.size 172 | -------------------------------------------------------------------------------- /src/haliax/_src/fp8.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import partial 3 | 4 | from jax import custom_jvp, custom_vjp, lax 5 | from jax import numpy as jnp 6 | 7 | 8 | # All of this is copy paste from flax/linen/fp8_ops.py 9 | # (Until we get to the module) 10 | 11 | # Copyright 2024 The Flax Authors. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 
24 | 25 | 26 | def quantize_dequantize(x, q_dtype, scale, compute_dtype): 27 | qx = quantize(x, q_dtype, scale, compute_dtype) 28 | return dequantize(qx, x.dtype, scale) 29 | 30 | 31 | def get_fp8_max(fp8_dtype, out_dtype): 32 | assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2) 33 | return jnp.finfo(fp8_dtype).max.astype(out_dtype) 34 | 35 | 36 | def quantize(x, q_dtype, scale, compute_dtype): 37 | # Explicitly cast the max values to the compute dtype to avoid unnecessary 38 | # casting to FP32 during the subsequent math operations." 39 | dtype_max = get_fp8_max(q_dtype, compute_dtype) 40 | scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape) 41 | clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) 42 | return clipped_x.astype(q_dtype) 43 | 44 | 45 | def dequantize(x, dq_dtype, scale): 46 | return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) 47 | 48 | 49 | def compute_scale(amax, scale, fp8_max, margin=0): 50 | # The algorithm for computing the new scale is sourced from 51 | # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas 52 | # wherein the `original_scale` corresponds to the reciprocal of the `scale` 53 | # passed in this function. 54 | scale = 1.0 / scale 55 | 56 | sf = (fp8_max / amax) / (2**margin) 57 | sf = jnp.where(amax > 0.0, sf, scale) 58 | sf = jnp.where(jnp.isfinite(amax), sf, scale) 59 | 60 | return 1.0 / sf 61 | 62 | 63 | def compute_amax_history(x, amax_history): 64 | amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype) 65 | new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) 66 | return new_history 67 | 68 | 69 | def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): 70 | dtype_max = get_fp8_max(q_dtype, jnp.float32) 71 | amax_from_history = jnp.max(amax_history, axis=0) 72 | new_scale = compute_scale(amax_from_history, scale, dtype_max) 73 | 74 | qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype) 75 | 76 | new_history = compute_amax_history(x, amax_history) 77 | 78 | return qx, new_scale, new_history 79 | 80 | 81 | @partial(custom_vjp, nondiff_argnums=(0,)) 82 | def in_qdq(compute_dtype, inp, scale, amax_history): 83 | qin, _, _ = qdq_and_return(inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) 84 | return qin 85 | 86 | 87 | def in_qdq_fwd(compute_dtype, inp, scale, amax_history): 88 | qin, new_scale, new_history = qdq_and_return(inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype) 89 | return qin, (new_scale, new_history) 90 | 91 | 92 | def in_qdq_bwd(compute_dtype, res, g): 93 | new_scale, new_history = res 94 | q_g = g 95 | return q_g, new_scale, new_history 96 | 97 | 98 | in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd) 99 | 100 | 101 | @partial(custom_vjp, nondiff_argnums=(0,)) 102 | def out_qdq(compute_dtype, out, scale, amax_history): 103 | return out 104 | 105 | 106 | def out_qdq_fwd(compute_dtype, out, scale, amax_history): 107 | return out, (scale, amax_history) 108 | 109 | 110 | def out_qdq_bwd(compute_dtype, res, g): 111 | scale, amax_history = res 112 | q_g, new_scale, new_history = qdq_and_return(g, jnp.float8_e5m2, scale, amax_history, compute_dtype) 113 | return q_g, new_scale, new_history 114 | 115 | 116 | out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) 117 | 118 | 119 | @partial(custom_jvp, nondiff_argnums=(2, 3, 4)) 120 | def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): 121 | if precision is not None or 
preferred_element_type is not None: 122 | # einsum sets preferred_element_type and so this is just noisy 123 | # warnings.warn( 124 | # "The function dot_general_with_precision will set the " 125 | # "precision/preferred_element_type and disregard any provided " 126 | # "values." 127 | # ) 128 | pass 129 | return lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT) 130 | 131 | 132 | @dot_general_with_precision.defjvp 133 | def dot_general_with_precision_jvp(dimension_numbers, precision, preferred_element_type, primals, tangents): 134 | lhs, rhs = primals 135 | lhs_dot, rhs_dot = tangents 136 | 137 | out = lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT) 138 | grad_out = lax.dot_general(lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST) + lax.dot_general( 139 | lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST 140 | ) 141 | return out, grad_out 142 | -------------------------------------------------------------------------------- /src/haliax/_src/util.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, MutableMapping, Sequence, TypeAlias, TypeVar 2 | 3 | 4 | T = TypeVar("T") 5 | U = TypeVar("U") 6 | py_slice = slice 7 | slice_t: TypeAlias = slice 8 | 9 | 10 | def index_where(pred: Callable[[T], bool], xs: Sequence[T], start: int = 0) -> int: 11 | for i in range(start, len(xs)): 12 | if pred(xs[i]): 13 | return i 14 | raise ValueError("No element satisfies predicate") 15 | 16 | 17 | class IdentityMap(MutableMapping[T, U]): 18 | """Map that compares keys by identity. 19 | 20 | This is a map that compares keys by identity instead of equality. It is 21 | useful for storing objects that are not hashable or that should be compared 22 | by identity. 23 | 24 | This is a mutable mapping, but it does not support the ``__hash__`` method 25 | and therefore cannot be used as a dictionary key or as an element of another 26 | set. 
27 | """ 28 | 29 | def __init__(self, iterable=None): 30 | self._data = {} 31 | if iterable is not None: 32 | self.update(iterable) 33 | 34 | def __contains__(self, key): 35 | return id(key) in self._data 36 | 37 | def __getitem__(self, key): 38 | return self._data[id(key)][1] 39 | 40 | def __setitem__(self, key, value): 41 | self._data[id(key)] = [key, value] 42 | 43 | def __delitem__(self, key): 44 | del self._data[id(key)] 45 | 46 | def __iter__(self): 47 | return (x[0] for x in self._data.values()) 48 | 49 | def __len__(self): 50 | return len(self._data) 51 | 52 | def __repr__(self): 53 | return f"IdentityMap({list(repr(x) for x in self._data.values())})" 54 | 55 | def __str__(self): 56 | return f"IdentityMap({list(str(x) for x in self._data.values())})" 57 | -------------------------------------------------------------------------------- /src/haliax/debug.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List, Tuple, Union 3 | 4 | import equinox as eqx 5 | import jax.numpy as jnp 6 | import jax.tree_util as jtu 7 | 8 | from haliax.core import NamedArray 9 | from haliax.util import is_jax_or_hax_array_like 10 | 11 | from ._src.util import IdentityMap 12 | 13 | 14 | ArrayLike = Union[jnp.ndarray, NamedArray] 15 | 16 | 17 | def describe_array(arr): 18 | if isinstance(arr, NamedArray): 19 | return f"NamedArray(axes={arr.axes}, dtype={arr.dtype})" 20 | else: 21 | return f"ndarray(shape={arr.shape}, dtype={arr.dtype})" 22 | 23 | 24 | class ModuleProblems(Exception): 25 | def __init__(self): 26 | self.reused_arrays: List[Tuple[ArrayLike, List]] = [] 27 | self.static_arrays: List[str] = [] 28 | 29 | def __bool__(self): 30 | return bool(self.reused_arrays or self.static_arrays) 31 | 32 | def __str__(self): 33 | if not self: 34 | return "No problems found" 35 | else: 36 | return "\n".join( 37 | [ 38 | "Found some problems with your module:", 39 | *self._format_reused_arrays(), 40 | *self._format_static_arrays(), 41 | ] 42 | ) 43 | 44 | def _format_reused_arrays(self): 45 | return [f" Reused array {describe_array(arr)} at paths {paths}" for arr, paths in self.reused_arrays] 46 | 47 | def _format_static_arrays(self): 48 | return [f" Static array at field {field}" for field in self.static_arrays] 49 | 50 | 51 | def diagnose_common_issues(module: eqx.Module): 52 | """ 53 | Checks for common issues in a module, such as reused arrays and static arrays. 54 | Equinox modules (and therefore Haliax modules) should not have arrays that are stored 55 | in multiple places, and should not have arrays stored as static fields. 56 | 57 | We'll add more checks here as we find them. 
58 | 59 | Args: 60 | module: The module to check for problems 61 | 62 | Returns: 63 | None 64 | 65 | Raises: 66 | ModuleProblems: if any problems are found 67 | 68 | """ 69 | 70 | problems = ModuleProblems() 71 | _check_for_reused_arrays(problems, module) 72 | _check_for_static_arrays(problems, module) 73 | 74 | if problems: 75 | raise problems 76 | 77 | # just in case we missed anything, raise equinox's errors: 78 | eqx.tree_check(module) 79 | 80 | 81 | def _check_for_reused_arrays(problems, module): 82 | used_arrays = IdentityMap[ArrayLike, List[str]]() 83 | 84 | path_leaves, _ = jtu.tree_flatten_with_path(module, is_leaf=is_jax_or_hax_array_like) 85 | 86 | for path, leaf in path_leaves: 87 | if is_jax_or_hax_array_like(leaf): 88 | if leaf in used_arrays: 89 | used_arrays[leaf].append(jtu.keystr(path)) 90 | else: 91 | used_arrays[leaf] = [jtu.keystr(path)] 92 | 93 | for arr, paths in used_arrays.items(): 94 | if len(paths) > 1: 95 | problems.reused_arrays.append((arr, paths)) 96 | 97 | 98 | def _check_for_static_arrays(problems, module): 99 | static_arrays = [] 100 | 101 | def recurse(module, path): 102 | if isinstance(module, eqx.Module): 103 | for field in dataclasses.fields(module): 104 | value = getattr(module, field.name) 105 | if field.metadata.get("static", False) and is_jax_or_hax_array_like(value): 106 | static_arrays.append(f"{path}.{field.name}") 107 | else: 108 | recurse(value, f"{path}.{field.name}") 109 | else: 110 | leaves, _ = eqx.tree_flatten_one_level(module) 111 | if leaves != [module]: 112 | leaves_with_names = jtu.tree_leaves_with_path(module, is_leaf=lambda x: x in leaves) 113 | for name, leaf in leaves_with_names: 114 | recurse(leaf, f"{path}{name}") 115 | 116 | recurse(module, "") 117 | 118 | if static_arrays: 119 | problems.static_arrays.extend(static_arrays) 120 | -------------------------------------------------------------------------------- /src/haliax/hof.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from functools import wraps 3 | 4 | import equinox as eqx 5 | import jax 6 | from jaxtyping import PyTree 7 | 8 | import haliax.tree_util as htu 9 | 10 | from ._src.scan import ( 11 | UnnamedAxisSpec, 12 | _infer_axis_size_from_tree, 13 | _is_passive_array, 14 | _pacify_named_arrays, 15 | _PassiveNamedArray, 16 | _prepend_named_batch_axis, 17 | _zero_if_array_else_none, 18 | fold, 19 | map, 20 | scan, 21 | ) 22 | from .axis import Axis, AxisSelector, selects_axis 23 | from .core import NamedArray 24 | from .jax_utils import Static, broadcast_prefix, is_jax_array_like 25 | from .partitioning import physical_axis_name 26 | from .util import is_named_array 27 | 28 | 29 | def vmap( 30 | fn, 31 | axis: AxisSelector, 32 | *, 33 | default: PyTree[UnnamedAxisSpec] = _zero_if_array_else_none, 34 | args: PyTree[UnnamedAxisSpec] = (), 35 | kwargs: PyTree[UnnamedAxisSpec] = None, 36 | ): 37 | """ 38 | [haliax.NamedArray][]-aware version of [jax.vmap][]. Normal arrays are mapped according to the specs as in 39 | [equinox.filter_vmap][] 40 | 41 | Because of NamedArrays, vmap is typically less useful than in vanilla JAX, but it is sometimes 42 | useful for initializing modules that will be scanned over. See [haliax.nn.Stacked][] for an example. 43 | 44 | Args: 45 | fn (Callable): function to vmap over 46 | axis (Axis): axis to vmap over 47 | default: how to handle (unnamed) arrays by default. 
Should be either an integer or None, or a callable that takes a PyTree leaf 48 | and returns an integer or None, or a PyTree prefix of the same. If an integer, the array will be mapped over that axis. If None, the array will not be mapped over. 49 | args: optional per-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default. 50 | kwargs: optional per-keyword-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default. 51 | """ 52 | 53 | if kwargs is None: 54 | kwargs = {} 55 | 56 | signature = inspect.signature(fn) 57 | 58 | # this mirrors equinox's filter_vmap, but it's not really documented there so: 59 | # we use inspect.signature to align args/kwargs specified in vmap to what actually gets passed in 60 | # axis_spec_bound_sig's job is to hold that mapping 61 | signature_default = signature.replace( 62 | parameters=[ 63 | p 64 | if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) 65 | else p.replace(default=default) 66 | for p in signature.parameters.values() 67 | ] 68 | ) 69 | axis_spec_bound_sig = signature_default.bind_partial(*args, **kwargs) 70 | axis_spec_bound_sig.apply_defaults() 71 | del args, kwargs 72 | 73 | def _index_of_batch_axis(array, default): 74 | if isinstance(array, NamedArray): 75 | return array._lookup_indices(axis) 76 | elif callable(default): 77 | return default(array) 78 | else: 79 | return default 80 | 81 | # TODO: tests to exercise this more 82 | @wraps(fn) 83 | def wrapped_vmap_fn(*args, **kwargs): 84 | # TODO: this probably results in a lot of compilation misses. Need to think about it. 85 | actual_bound = signature.bind(*args, **kwargs) 86 | actual_bound.apply_defaults() 87 | 88 | # now that we have args, we can figure out what the axis spec is for each arg 89 | padded_spec_args = axis_spec_bound_sig.args + (default,) * ( 90 | len(actual_bound.args) - len(axis_spec_bound_sig.args) 91 | ) 92 | 93 | padded_spec_kwargs = { 94 | **axis_spec_bound_sig.kwargs, 95 | **{k: default for k in actual_bound.kwargs.keys() - axis_spec_bound_sig.kwargs.keys()}, 96 | } 97 | 98 | # want to support padded_spec_args being a tree prefix of the actual args, which this enables 99 | padded_spec_args = broadcast_prefix(padded_spec_args, actual_bound.args, is_leaf=is_named_array) 100 | padded_spec_kwargs = broadcast_prefix(padded_spec_kwargs, actual_bound.kwargs) 101 | 102 | arg_axis_specs = htu.tree_map(_index_of_batch_axis, actual_bound.args, padded_spec_args) 103 | 104 | kwarg_axis_specs = htu.tree_map(_index_of_batch_axis, actual_bound.kwargs, padded_spec_kwargs) 105 | 106 | # now we can actually vmap. We used "pacified" versions of NamedArrays that don't check 107 | # invariants, because intermediates creating during tracing won't have the axes right 108 | arg_axis_specs = htu.tree_map(_pacify_named_arrays, arg_axis_specs) 109 | kwarg_axis_specs = htu.tree_map(_pacify_named_arrays, kwarg_axis_specs) 110 | 111 | def wrapped_fn(args, kwargs): 112 | # the args that come in here are pacified. Their names will still have the batch axis even though the array 113 | # itself will already have that one removed. 
We need to turn them back into NamedArrays by removing the axis 114 | unchilled_args = jax.tree_util.tree_map(_to_unbatched_named_array(axis), args, is_leaf=_is_passive_array) 115 | unchilled_kwargs = jax.tree_util.tree_map( 116 | _to_unbatched_named_array(axis), kwargs, is_leaf=_is_passive_array 117 | ) 118 | 119 | out = fn(*unchilled_args, **unchilled_kwargs) 120 | 121 | # now we need to pacify the output, which may include NamedArrays, and add the batch axis back at the end 122 | chilled = htu.tree_map(_pacify_named_arrays, out) 123 | arrays, nonarrays = eqx.partition(chilled, is_jax_array_like) 124 | return arrays, Static(nonarrays) 125 | 126 | spmd_axis_name = physical_axis_name(axis) 127 | 128 | args = htu.tree_map(_pacify_named_arrays, actual_bound.args) 129 | kwargs = htu.tree_map(_pacify_named_arrays, actual_bound.kwargs) 130 | 131 | result_dynamic, result_static = jax.vmap( 132 | wrapped_fn, 133 | in_axes=(arg_axis_specs, kwarg_axis_specs), 134 | out_axes=0, 135 | axis_size=axis.size if isinstance(axis, Axis) else None, 136 | spmd_axis_name=spmd_axis_name, 137 | )(args, kwargs) 138 | 139 | result = eqx.combine(result_dynamic, result_static.value) 140 | 141 | # if we were passed in a string arg, we need to get its axis size out from some result 142 | true_axis = _infer_axis_size_from_tree(result, axis) 143 | if true_axis is None: 144 | raise ValueError("vmap failed to infer axis size from result") 145 | 146 | result = jax.tree_util.tree_map(_prepend_named_batch_axis(true_axis), result, is_leaf=_is_passive_array) 147 | return result 148 | 149 | return wrapped_vmap_fn 150 | 151 | 152 | def _to_unbatched_named_array(axis_to_strip: AxisSelector): 153 | def to_unbatched_named_array(leaf): 154 | if isinstance(leaf, _PassiveNamedArray): 155 | if selects_axis(leaf.main_axes, axis_to_strip): 156 | return leaf.strip_axis(axis_to_strip) 157 | else: 158 | return leaf.to_named_array() 159 | else: 160 | return leaf 161 | 162 | return to_unbatched_named_array 163 | 164 | 165 | __all__ = ["scan", "fold", "vmap", "map"] 166 | -------------------------------------------------------------------------------- /src/haliax/nn/__init__.py: -------------------------------------------------------------------------------- 1 | import jax.nn as jnn 2 | import jax.numpy as jnp 3 | 4 | import haliax 5 | import haliax as hax 6 | import haliax.nn.activations 7 | import haliax.nn.attention as attention 8 | import haliax.nn.normalization 9 | 10 | from ..axis import Axis 11 | from ..core import NamedArray 12 | from .activations import ( 13 | celu, 14 | elu, 15 | gelu, 16 | glu, 17 | hard_sigmoid, 18 | hard_silu, 19 | hard_swish, 20 | hard_tanh, 21 | leaky_relu, 22 | log_sigmoid, 23 | quick_gelu, 24 | relu, 25 | relu6, 26 | selu, 27 | sigmoid, 28 | silu, 29 | soft_sign, 30 | softplus, 31 | swish, 32 | ) 33 | from .conv import Conv, ConvTranspose 34 | from .dropout import Dropout, dropout 35 | from .embedding import Embedding 36 | from .linear import Linear, MoELinear 37 | from .loss import binary_cross_entropy_loss, cross_entropy_loss, cross_entropy_loss_and_log_normalizers, reduce_loss 38 | from .mlp import MLP 39 | from .normalization import LayerNorm, RmsNorm, log_softmax, logsumexp, softmax, standardize 40 | from .pool import max_pool, mean_pool, min_pool 41 | from .scan import BlockSeq, ScanCheckpointPolicy, Stacked 42 | 43 | 44 | def one_hot(x: NamedArray | int, class_axis: Axis, *, dtype=None) -> NamedArray: 45 | """ 46 | Convert an integer to a one-hot vector. 
This is basically a generalization of [jax.nn.one_hot][] 47 | for NamedArrays. 48 | 49 | Args: 50 | x: the integer or NamedArray of integers to convert 51 | class_axis: the axis to convert to one-hot 52 | dtype: the dtype of the result. If None, it will default to jax's default (currently float_) 53 | Returns: 54 | a NamedArray with the same axes as `x` plus `class_axis`, with 1s in the appropriate places 55 | """ 56 | if isinstance(x, NamedArray): 57 | array = jnn.one_hot(x.array, num_classes=class_axis.size, dtype=dtype) 58 | # Disabling this to prevent a crash in XLA on GPU 59 | # return hax.auto_sharded(hax.named(array, x.axes + (class_axis,))) 60 | return hax.named(array, x.axes + (class_axis,)) 61 | else: 62 | assert isinstance(x, int) 63 | assert class_axis.size > x >= -class_axis.size 64 | 65 | one = 1 66 | if dtype is not None: 67 | one = dtype(one) 68 | 69 | array = jnp.zeros(class_axis.size, dtype=dtype).at[x].set(one) 70 | return hax.auto_sharded(haliax.named(array, class_axis)) 71 | 72 | 73 | __all__ = [ 74 | "attention", 75 | "one_hot", 76 | "binary_cross_entropy_loss", 77 | "reduce_loss", 78 | "cross_entropy_loss", 79 | "cross_entropy_loss_and_log_normalizers", 80 | "Conv", 81 | "ConvTranspose", 82 | "Dropout", 83 | "dropout", 84 | "LayerNorm", 85 | "Linear", 86 | "MoELinear", 87 | "Embedding", 88 | "RmsNorm", 89 | "Stacked", 90 | "BlockSeq", 91 | "MLP", 92 | "relu", 93 | "gelu", 94 | "quick_gelu", 95 | "glu", 96 | "relu6", 97 | "sigmoid", 98 | "soft_sign", 99 | "softplus", 100 | "swish", 101 | "silu", 102 | "log_sigmoid", 103 | "leaky_relu", 104 | "hard_sigmoid", 105 | "hard_silu", 106 | "hard_swish", 107 | "hard_tanh", 108 | "logsumexp", 109 | "softmax", 110 | "log_softmax", 111 | "standardize", 112 | "elu", 113 | "celu", 114 | "selu", 115 | "max_pool", 116 | "mean_pool", 117 | "min_pool", 118 | "ScanCheckpointPolicy", 119 | ] 120 | -------------------------------------------------------------------------------- /src/haliax/nn/activations.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from jax import nn as jnn 4 | from jax import numpy as jnp 5 | 6 | from ..axis import Axis 7 | from ..core import NamedArray 8 | from ..types import Scalar 9 | from ..wrap import wrap_elemwise_unary 10 | 11 | 12 | A = typing.TypeVar("A", Scalar, NamedArray, jnp.ndarray) 13 | 14 | 15 | def relu(a: A) -> A: 16 | return wrap_elemwise_unary(jnn.relu, a) 17 | 18 | 19 | def relu6(a: A) -> A: 20 | return wrap_elemwise_unary(jnn.relu6, a) 21 | 22 | 23 | def sigmoid(a: A) -> A: 24 | return wrap_elemwise_unary(jnn.sigmoid, a) 25 | 26 | 27 | def softplus(a: A) -> A: 28 | return wrap_elemwise_unary(jnn.softplus, a) 29 | 30 | 31 | def soft_sign(a: A) -> A: 32 | return wrap_elemwise_unary(jnn.soft_sign, a) 33 | 34 | 35 | def silu(a: A) -> A: 36 | return wrap_elemwise_unary(jnn.silu, a) 37 | 38 | 39 | def swish(a: A) -> A: 40 | return wrap_elemwise_unary(jnn.swish, a) 41 | 42 | 43 | def log_sigmoid(a: A) -> A: 44 | return wrap_elemwise_unary(jnn.log_sigmoid, a) 45 | 46 | 47 | def leaky_relu(a: A) -> A: 48 | return wrap_elemwise_unary(jnn.leaky_relu, a) 49 | 50 | 51 | def hard_sigmoid(a: A) -> A: 52 | return wrap_elemwise_unary(jnn.hard_sigmoid, a) 53 | 54 | 55 | def hard_silu(a: A) -> A: 56 | return wrap_elemwise_unary(jnn.hard_silu, a) 57 | 58 | 59 | def hard_swish(a: A) -> A: 60 | return wrap_elemwise_unary(jnn.hard_swish, a) 61 | 62 | 63 | def hard_tanh(a: A) -> A: 64 | return wrap_elemwise_unary(jnn.hard_tanh, a) 65 | 66 | 67 | def elu(a: 
A) -> A: 68 | return wrap_elemwise_unary(jnn.elu, a) 69 | 70 | 71 | def celu(a: A) -> A: 72 | return wrap_elemwise_unary(jnn.celu, a) 73 | 74 | 75 | def selu(a: A) -> A: 76 | return wrap_elemwise_unary(jnn.selu, a) 77 | 78 | 79 | def gelu(a: A, approximate: bool = True) -> A: 80 | return wrap_elemwise_unary(jnn.gelu, a, approximate=approximate) 81 | 82 | 83 | def glu(x: NamedArray, axis: Axis) -> NamedArray: 84 | axis_index = x.axes.index(axis) 85 | return jnn.glu(x.array, axis_index) 86 | 87 | 88 | def quick_gelu(x): 89 | return x * sigmoid(1.702 * x) 90 | -------------------------------------------------------------------------------- /src/haliax/nn/dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import equinox as eqx 4 | import jax 5 | from jaxtyping import PRNGKeyArray 6 | 7 | import haliax 8 | from haliax.axis import AxisSpec 9 | from haliax.core import NamedArray 10 | from haliax.util import ensure_tuple 11 | 12 | 13 | def dropout(x, pdrop, broadcast_axes=None, *, inference, key=None): 14 | """Applies dropout. 15 | 16 | **Arguments:** 17 | 18 | - `x`: An any-dimensional JAX array to dropout. 19 | - `pdrop`: The fraction of entries to set to zero. 20 | - `broadcast_axes`: The dimensions to broadcast the dropout mask over. If set, these axes will share the same mask 21 | - `key`: A `jax.random.PRNGKey` used to provide randomness for calculating 22 | which elements to dropout. (Keyword only argument.) 23 | - `inference`: As per [`equinox.nn.Dropout.__init__`][]. If `True` or 24 | `False` then it will take priority over `self.inference`. If `None` 25 | then the value from `self.inference` will be used. 26 | """ 27 | if inference: 28 | return x 29 | elif isinstance(pdrop, (int, float)) and pdrop == 0: 30 | return x 31 | elif isinstance(pdrop, (int, float)) and pdrop == 1: 32 | return haliax.zeros_like(x) 33 | elif key is None: 34 | raise RuntimeError("Dropout requires a key when running in non-deterministic mode.") 35 | else: 36 | with jax.named_scope(name="dropout"): 37 | if broadcast_axes is None: 38 | if isinstance(x, NamedArray): 39 | shape_to_generate = x.axes 40 | else: 41 | shape_to_generate = x.shape 42 | else: 43 | axes = ensure_tuple(broadcast_axes) 44 | shape_to_generate = tuple(ax for ax in x.axes if ax not in axes) 45 | 46 | q = 1 - pdrop 47 | mask: NamedArray = haliax.random.bernoulli(key, shape_to_generate, q) 48 | q = x.dtype.type(q) 49 | 50 | out = haliax.where(mask, x / q, 0) 51 | assert out.dtype == x.dtype 52 | return out 53 | 54 | 55 | class Dropout(eqx.Module): 56 | """Applies dropout. 57 | 58 | Attributes: 59 | pdrop: The fraction of entries to set to zero. 60 | broadcast_axes: The dimensions to broadcast the dropout mask over. 
If set, these axes will share the same mask 61 | """ 62 | 63 | # key difference from equinox: these are static fields 64 | pdrop: float = eqx.static_field() 65 | broadcast_axes: Optional[AxisSpec] = eqx.static_field() 66 | inference: bool = False # note: not static 67 | 68 | def __init__( 69 | self, 70 | pdrop: float = 0.5, 71 | broadcast_axes: Optional[AxisSpec] = None, 72 | inference: bool = False, 73 | ): 74 | self.pdrop = pdrop 75 | self.broadcast_axes = broadcast_axes 76 | self.inference = inference 77 | 78 | @property 79 | def is_active(self): 80 | """Returns `True` if dropout is active (and therefore needs a key), `False` otherwise.""" 81 | return not self.inference and self.pdrop > 0 82 | 83 | def __call__( 84 | self, 85 | x: NamedArray, 86 | *, 87 | inference: Optional[bool] = None, 88 | key: Optional[PRNGKeyArray] = None, 89 | ) -> NamedArray: 90 | """**Arguments:** 91 | 92 | - `x`: An any-dimensional JAX array to dropout. 93 | - `key`: A `jax.random.PRNGKey` used to provide randomness for calculating 94 | which elements to dropout. (Keyword only argument.) 95 | - `inference`: As per [`equinox.nn.Dropout.__init__`][]. If `True` or 96 | `False` then it will take priority over `self.inference`. If `None` 97 | then the value from `self.inference` will be used. 98 | """ 99 | if inference is None: 100 | inference = self.inference 101 | 102 | return dropout( 103 | x, 104 | self.pdrop, 105 | broadcast_axes=self.broadcast_axes, 106 | inference=inference, 107 | key=key, 108 | ) 109 | -------------------------------------------------------------------------------- /src/haliax/nn/embedding.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import math 3 | import warnings 4 | from typing import Optional 5 | 6 | import equinox as eqx 7 | from jaxtyping import PRNGKeyArray 8 | 9 | import haliax as hax 10 | 11 | from ..axis import Axis, AxisSpec 12 | from ..core import NamedArray 13 | from ..jax_utils import named_call 14 | from ..tree_util import resize_axis 15 | from ..util import ensure_tuple 16 | 17 | 18 | class Embedding(eqx.Module): 19 | weight: NamedArray 20 | 21 | # axes 22 | Vocab: Axis = eqx.static_field() 23 | Embed: AxisSpec = eqx.static_field() 24 | 25 | @staticmethod 26 | def init(Vocab: Axis, Embed: AxisSpec, *, init_scale: float = 1, key, initializer_range: Optional[float] = None): 27 | """ 28 | Initialize an Embedding module. 29 | 30 | An embedding module is a simple lookup table that maps integer indices to vectors or tensors. 31 | Weights are initialized with a truncated normal distribution with a standard deviation of 32 | `init_scale / output_size`. 33 | 34 | Args: 35 | Vocab: Size of the vocabulary 36 | Embed: Shape of the embedding vectors. May be a single axis or a full AxisSpec 37 | init_scale: Scale of the initialization 38 | key: PRNG key 39 | initializer_range: Deprecated. Use init_scale instead. 40 | """ 41 | if initializer_range is not None: 42 | warnings.warn("initializer_range is deprecated. Use init_std instead.", DeprecationWarning) 43 | init_scale = initializer_range 44 | 45 | all_axes = (Vocab,) + ensure_tuple(Embed) 46 | output_size = hax.axis_size(Embed) 47 | weight = hax.random.truncated_normal(key, all_axes, -3, 3) * (init_scale / output_size) 48 | return Embedding(weight=weight, Vocab=Vocab, Embed=Embed) 49 | 50 | def __call__(self, input_ids: NamedArray, *, key: Optional[PRNGKeyArray] = None): 51 | """Alias for `embed`. 
key is ignored.""" 52 | return self.embed(input_ids) 53 | 54 | @named_call 55 | def embed(self, input_ids: NamedArray): 56 | """ 57 | Args: 58 | input_ids: token IDs with shape > {Vocab} 59 | """ 60 | input_embeds = self.weight.take(self.Vocab, input_ids) 61 | return input_embeds 62 | 63 | def unembed(self, input_embeds: NamedArray): 64 | """ 65 | Unembed the input embeddings back to the vocabulary space. 66 | 67 | Equivalent to `input_embeds.dot(self.weight, axis=self.Embed)`. 68 | """ 69 | return input_embeds.dot(self.weight, axis=self.Embed) 70 | 71 | def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): 72 | """ 73 | Resize the embedding layer to a new size. 74 | Args: 75 | new_size: New size of the vocabulary 76 | key: PRNG key for initialization of any new weights 77 | 78 | Returns: 79 | Embedding: Resized embedding layer 80 | 81 | """ 82 | new_weights = resize_axis(self.weight, self.Vocab, new_size, key=key) 83 | return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), weight=new_weights) # type: ignore 84 | -------------------------------------------------------------------------------- /src/haliax/nn/loss.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import warnings 3 | from typing import Optional 4 | 5 | from jax import numpy as jnp 6 | 7 | import haliax as hax 8 | from haliax.axis import AxisSelection, AxisSelector 9 | from haliax.core import NamedArray 10 | from haliax.util import UNSPECIFIED, Unspecified 11 | from haliax.wrap import ReductionFunction 12 | 13 | 14 | @typing.overload 15 | def cross_entropy_loss( 16 | logits: NamedArray, 17 | Label: AxisSelector, 18 | targets: NamedArray, 19 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 20 | where: Optional[NamedArray] = None, 21 | reduction_axis: None = None, 22 | ) -> jnp.ndarray | NamedArray: 23 | ... 24 | 25 | 26 | @typing.overload 27 | def cross_entropy_loss( 28 | logits: NamedArray, 29 | Label: AxisSelector, 30 | targets: NamedArray, 31 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 32 | where: Optional[NamedArray] = None, 33 | reduction_axis: AxisSelection = ..., 34 | ) -> NamedArray: 35 | ... 36 | 37 | 38 | def cross_entropy_loss( 39 | logits: NamedArray, 40 | Label: AxisSelector, 41 | targets: NamedArray, 42 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 43 | where: Optional[NamedArray] = None, 44 | reduction_axis: Optional[AxisSelection] = None, 45 | ) -> jnp.ndarray | NamedArray: 46 | loss, _ = cross_entropy_loss_and_log_normalizers(logits, Label, targets) 47 | 48 | # if target_y isn't some kind of floating point, something is wrong, so warn 49 | if not jnp.issubdtype(targets.dtype, jnp.floating): 50 | warnings.warn( 51 | f"target_y has dtype {targets.dtype}, which is not a floating point type. This is probably a mistake." 52 | ) 53 | 54 | loss = maybe_reduce_loss(loss, reduction, reduction_axis, where) 55 | 56 | return loss 57 | 58 | 59 | @typing.overload 60 | def binary_cross_entropy_loss( 61 | logits: NamedArray, 62 | targets: NamedArray, 63 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 64 | where: Optional[NamedArray] = None, 65 | reduction_axis: None = None, 66 | ) -> jnp.ndarray | NamedArray: 67 | ... 
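# A usage sketch for `cross_entropy_loss` above (illustrative only: the `Batch`/`Vocab`
# axis names and sizes are invented, and `hax.nn.softmax` is used here merely to build
# floating-point targets):
#
#     import jax.random as jrandom
#     import haliax as hax
#
#     Batch, Vocab = hax.make_axes(Batch=8, Vocab=32)
#     logits = hax.random.normal(jrandom.PRNGKey(0), (Batch, Vocab))
#     # targets should be floating point: one-hot vectors or a distribution over Vocab
#     targets = hax.nn.softmax(hax.random.normal(jrandom.PRNGKey(1), (Batch, Vocab)), axis=Vocab)
#
#     mean_loss = cross_entropy_loss(logits, Vocab, targets)                    # mean over all axes
#     per_example = cross_entropy_loss(logits, Vocab, targets, reduction=None)  # keeps the Batch axis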
68 | 69 | 70 | @typing.overload 71 | def binary_cross_entropy_loss( 72 | logits: NamedArray, 73 | targets: NamedArray, 74 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 75 | where: Optional[NamedArray] = None, 76 | reduction_axis: AxisSelection = ..., 77 | ) -> NamedArray: 78 | ... 79 | 80 | 81 | def binary_cross_entropy_loss( 82 | logits: NamedArray, 83 | targets: NamedArray, 84 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 85 | where: Optional[NamedArray] = None, 86 | reduction_axis: Optional[AxisSelection] = None, 87 | ) -> jnp.ndarray | NamedArray: 88 | log_p = hax.nn.log_sigmoid(logits) 89 | log_not_p = hax.nn.log_sigmoid(-logits) # == log(1-sigmoid(x)) 90 | targets = targets.astype(logits.dtype) 91 | loss = -targets * log_p - (1.0 - targets) * log_not_p 92 | 93 | loss = maybe_reduce_loss(loss, reduction, reduction_axis, where) 94 | return loss 95 | 96 | 97 | def reduce_loss( 98 | arr, 99 | reduction: Optional[ReductionFunction] | Unspecified = UNSPECIFIED, 100 | reduction_axis: Optional[AxisSelection] = None, 101 | where: Optional[NamedArray] = None, 102 | ): 103 | """ 104 | Reduce a loss array according to the given reduction and reduction axis. 105 | If reduction is None, the loss is not reduced. 106 | If reduction is UNSPECIFIED, the default reduction is used (mean). 107 | If reduction_axis is None (default), the loss is reduced over all axes. 108 | """ 109 | return maybe_reduce_loss(arr, reduction, reduction_axis, where) 110 | 111 | 112 | def maybe_reduce_loss( 113 | arr, 114 | reduction: Optional[ReductionFunction] | Unspecified, 115 | reduction_axis: Optional[AxisSelection], 116 | where: Optional[NamedArray], 117 | ): 118 | if reduction is not None and reduction_axis != (): 119 | if reduction is UNSPECIFIED: 120 | reduction = hax.mean 121 | arr = reduction(arr, where=where, axis=reduction_axis) 122 | elif where is not None: 123 | arr = hax.where(where, arr, 0) 124 | return arr 125 | 126 | 127 | def cross_entropy_loss_and_log_normalizers( 128 | pred_y: NamedArray, 129 | Label: AxisSelector, 130 | target_y: NamedArray, 131 | ) -> tuple[NamedArray, NamedArray]: 132 | """ 133 | Compute the cross entropy loss and log normalizers for a batch of predictions and targets. 134 | 135 | :param pred_y: a NamedArray with the Label axis (and possibly others for e.g. 
batch and seq) containing the logits 136 | :param Label: the Label axis 137 | :param target_y: a NamedArray with the Label axis (and possibly others) containing the targets 138 | 139 | :return: tuple of two named arrays, with "per position" losses and log normalizers 140 | """ 141 | log_normalizers = hax.nn.logsumexp(pred_y, Label) 142 | neg_log_normalized = log_normalizers - pred_y 143 | 144 | loss = hax.dot(target_y, neg_log_normalized, axis=Label) 145 | 146 | return loss, log_normalizers 147 | -------------------------------------------------------------------------------- /src/haliax/nn/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import equinox as eqx 4 | import jax 5 | from jaxtyping import PRNGKeyArray 6 | 7 | from ..axis import Axis, AxisSpec 8 | from ..core import NamedArray 9 | from ..jax_utils import maybe_rng_split 10 | from ..quantization import DotGeneralOp 11 | from .activations import relu 12 | from .linear import Linear 13 | 14 | 15 | DEFAULT_WIDTH_NAME = "mlp" 16 | 17 | 18 | class MLP(eqx.Module): 19 | """ 20 | A multilayer perceptron (MLP) / feed-forward neural network (FFNN). 21 | 22 | MLPs, whose stacked linear layers typically use non-semantic axes for their hidden 23 | dims, are not a particular strength of Haliax's named-axis design philosophy. Nonetheless, they are a useful tool for 24 | many tasks, and so we provide this module. 25 | 26 | In Haliax, all axes must have names, and names must be unique within an array. We considered a few strategies 27 | for naming the axes of an MLP, and settled on the following: by default, we alternate hidden axis names between "mlp" 28 | and "mlp2". Input and output names must be specified, and are not repeated. This naming scheme is not perfect, 29 | but does mean that model parallelism works reasonably well. 30 | 31 | NB: unlike Equinox's MLP, this MLP uses a static field for activation. If you want a learnable activation, you 32 | likely want a unique activation per layer, which neither version provides. Instead, you should use a 33 | [haliax.nn.Stacked][] with a suitable block. 
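    A minimal usage sketch (the axis names and sizes below are purely illustrative):

        import jax.random as jrandom
        import haliax as hax

        Embed = hax.Axis("Embed", 64)
        Out = hax.Axis("Out", 16)
        mlp = MLP.init(Embed, Out, width=128, depth=2, key=jrandom.PRNGKey(0))

        Batch = hax.Axis("Batch", 4)
        x = hax.random.normal(jrandom.PRNGKey(1), (Batch, Embed))
        y = mlp(x)  # output carries the Batch and Out axes; the hidden "mlp"/"mlp2" axes never escape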
34 | """ 35 | 36 | activation: Callable = eqx.field(static=True) 37 | layers: Sequence[Linear] 38 | 39 | @staticmethod 40 | def init( 41 | Input: AxisSpec, 42 | Output: AxisSpec, 43 | width: int | Axis, 44 | depth: int, 45 | activation: Callable = relu, 46 | *, 47 | out_first: bool = True, 48 | use_bias: bool = True, 49 | use_final_bias: bool = True, 50 | key: PRNGKeyArray, 51 | dot_general: Optional[DotGeneralOp] = None, 52 | init_scale: float = 1.0, 53 | ): 54 | Width = _get_width(width) 55 | Width2 = Width.alias(Width.name + "2") 56 | 57 | keys = jax.random.split(key, depth + 1) 58 | 59 | layers = [] 60 | 61 | kwargs: dict = { 62 | "use_bias": use_bias, 63 | "dot_general": dot_general, 64 | "init_scale": init_scale, 65 | "out_first": out_first, 66 | } 67 | 68 | last_kwargs: dict = { 69 | "use_bias": use_final_bias, 70 | "dot_general": dot_general, 71 | "init_scale": init_scale, 72 | "out_first": out_first, 73 | } 74 | 75 | if depth == 0: 76 | # special case: no hidden layers 77 | layers.append(Linear.init(Input, Output, key=keys[0], **last_kwargs)) 78 | else: 79 | # first hidden layer 80 | layers.append(Linear.init(Input, Width, key=keys[0], **kwargs)) 81 | # middle hidden layers 82 | cur = Width 83 | next = Width2 84 | for i in range(1, depth): 85 | layers.append(Linear.init(cur, next, key=keys[i], **kwargs)) 86 | cur, next = next, cur 87 | # final layer 88 | layers.append(Linear.init(cur, Output, key=keys[-1], **last_kwargs)) 89 | 90 | return MLP( 91 | layers=tuple(layers), 92 | activation=activation, 93 | ) 94 | 95 | @property 96 | def In(self) -> AxisSpec: 97 | return self.layers[0].In 98 | 99 | @property 100 | def Out(self) -> AxisSpec: 101 | return self.layers[-1].Out 102 | 103 | def __call__(self, x: NamedArray, *, key=None) -> NamedArray: 104 | keys = maybe_rng_split(key, len(self.layers)) 105 | for layer, k in zip(self.layers[:-1], keys): 106 | x = self.activation(layer(x, key=k)) 107 | return self.layers[-1](x, key=keys[-1]) 108 | 109 | 110 | def _get_width(Width: int | Axis) -> Axis: 111 | if isinstance(Width, int): 112 | return Axis(DEFAULT_WIDTH_NAME, Width) 113 | else: 114 | return Width 115 | -------------------------------------------------------------------------------- /src/haliax/nn/normalization.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from abc import abstractmethod 3 | from typing import Optional, TypeVar 4 | 5 | import equinox as eqx 6 | from jax import nn as jnn 7 | from jax import numpy as jnp 8 | 9 | import haliax 10 | import haliax as hax 11 | 12 | from .._src.state_dict import Mod, ModuleWithStateDictSerialization 13 | from ..axis import AxisSelection, AxisSpec 14 | from ..core import NamedArray 15 | from ..types import Scalar 16 | from ..wrap import unwrap_namedarrays, wrap_axiswise_call, wrap_reduction_call 17 | 18 | 19 | A = TypeVar("A", Scalar, NamedArray, jnp.ndarray) 20 | 21 | 22 | class LayerNormBase(ModuleWithStateDictSerialization): 23 | axis: AxisSpec = eqx.static_field() 24 | weight: Optional[NamedArray] 25 | bias: Optional[NamedArray] 26 | eps: float = eqx.static_field(default=1e-5) 27 | dtype: Optional[jnp.dtype] = eqx.field(default=None, static=True) 28 | 29 | @abstractmethod 30 | def __call__(self, x: NamedArray) -> NamedArray: 31 | pass 32 | 33 | @classmethod 34 | def init( 35 | cls, 36 | axis: AxisSpec, 37 | eps: float = 1e-5, 38 | *, 39 | use_weight: bool = True, 40 | use_bias: bool = True, 41 | dtype: Optional[jnp.dtype] = None, 42 | ): 43 | if use_weight: 44 | weight = 
hax.ones(axis) 45 | else: 46 | weight = None 47 | 48 | if use_bias: 49 | bias = hax.zeros(axis) 50 | else: 51 | bias = None 52 | 53 | return cls(axis, weight, bias, eps, dtype) 54 | 55 | def flatten_for_export(self: Mod) -> Mod: 56 | if isinstance(self.axis, hax.Axis): 57 | return self 58 | 59 | if self.weight is not None: 60 | weight = self.weight.flatten("__OUT") 61 | else: 62 | weight = None 63 | 64 | if self.bias is not None: 65 | bias = self.bias.flatten("__OUT") 66 | else: 67 | bias = None 68 | 69 | return dataclasses.replace(self, weight=weight, bias=bias, axis=hax.flatten_axes(self.axis, "__OUT")) 70 | 71 | def unflatten_from_export(self: Mod, template: Mod) -> Mod: 72 | if template.axis == self.axis: 73 | return self 74 | 75 | if self.weight is not None: 76 | assert isinstance(self.axis, hax.Axis), "Cannot unflatten weight with non-axis axis" 77 | weight = hax.unflatten_axis(self.weight, self.axis, template.axis) 78 | else: 79 | weight = None 80 | 81 | if self.bias is not None: 82 | assert isinstance(self.axis, hax.Axis), "Cannot unflatten weight with non-axis axis" 83 | bias = hax.unflatten_axis(self.bias, self.axis, template.axis) 84 | 85 | else: 86 | bias = None 87 | 88 | return dataclasses.replace(self, weight=weight, bias=bias, axis=template.axis) 89 | 90 | 91 | class LayerNorm(LayerNormBase): 92 | r""" 93 | Normalises the input along the specified axis (or axes), using the mean and variance of the 94 | input along that axis. 95 | """ 96 | axis: AxisSpec = eqx.field(static=True) 97 | weight: Optional[NamedArray] 98 | bias: Optional[NamedArray] 99 | 100 | eps: float = eqx.field(default=1e-5, static=True) 101 | dtype: Optional[jnp.dtype] = eqx.field(default=None, static=True) 102 | 103 | def __call__(self, x: NamedArray) -> NamedArray: 104 | dtype = x.dtype 105 | mean = x.mean(self.axis) 106 | var = x.var(self.axis) 107 | inv = hax.rsqrt(var + self.eps) 108 | out = (x - mean) * inv 109 | out = out.astype(dtype) 110 | 111 | if self.weight is not None: 112 | out = self.weight * out 113 | if self.bias is not None: 114 | out = out + self.bias 115 | return out 116 | 117 | 118 | class RmsNorm(LayerNormBase): 119 | r""" 120 | Implements RMS normalization, which normalizes the input by dividing by the root mean square of the input. 121 | """ 122 | 123 | def __call__(self, x: NamedArray) -> NamedArray: 124 | in_dtype = x.dtype 125 | x = x.astype(self.dtype) 126 | var = hax.mean(hax.square(x), axis=self.axis) 127 | inv = hax.rsqrt(var + self.eps) 128 | out = x * inv 129 | out = out.astype(in_dtype) 130 | 131 | if self.weight is not None: 132 | out = self.weight * out 133 | if self.bias is not None: 134 | out = out + self.bias 135 | return out 136 | 137 | 138 | def logsumexp(a: A, axis: Optional[AxisSelection] = None) -> A: 139 | # TODO: logsumexp indirectly supports where via `b`. 
we should support it directly 140 | return wrap_reduction_call(jnn.logsumexp, a, axis=axis, single_axis_only=False, supports_where=False) 141 | 142 | 143 | # TODO: support where in softmax, etc 144 | 145 | 146 | def softmax(a: A, axis: Optional[AxisSelection] = None) -> A: 147 | return wrap_axiswise_call(jnn.softmax, a, axis=axis, single_axis_only=False) 148 | 149 | 150 | def log_softmax(a: A, axis: Optional[AxisSelection] = None) -> A: 151 | return wrap_axiswise_call(jnn.log_softmax, a, axis=axis, single_axis_only=False) 152 | 153 | 154 | def standardize( 155 | x: NamedArray, 156 | axis: AxisSpec, 157 | *, 158 | mean: Optional[NamedArray] = None, 159 | variance: Optional[NamedArray] = None, 160 | epsilon: float = 1e-5, 161 | where: Optional[NamedArray] = None, 162 | ) -> NamedArray: 163 | """Analogous to [jax.nn.standardize][], but with support for NamedArrays.""" 164 | x, mean, variance, where = haliax.broadcast_arrays(x, mean, variance, where) # type: ignore 165 | raw_x, mean, variance, where = unwrap_namedarrays(x, mean, variance, where) 166 | axis_indices = x._lookup_indices(axis) 167 | 168 | plain = jnn.standardize(raw_x, axis_indices, mean=mean, variance=variance, epsilon=epsilon, where=where) 169 | return NamedArray(plain, x.axes) 170 | -------------------------------------------------------------------------------- /src/haliax/ops.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Optional, Union 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from .axis import Axis, AxisSelector 8 | from .core import NamedArray, NamedOrNumeric, broadcast_arrays, broadcast_arrays_and_return_axes, named 9 | from .jax_utils import is_scalarish 10 | 11 | 12 | def trace(array: NamedArray, axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> NamedArray: 13 | """Compute the trace of an array along two named axes.""" 14 | a1_index = array._lookup_indices(axis1) 15 | a2_index = array._lookup_indices(axis2) 16 | 17 | if a1_index is None: 18 | raise ValueError(f"Axis {axis1} not found in array. Available axes: {array.axes}") 19 | if a2_index is None: 20 | raise ValueError(f"Axis {axis2} not found in array. Available axes: {array.axes}") 21 | 22 | if a1_index == a2_index: 23 | raise ValueError(f"Cannot trace along the same axis. Got {axis1} and {axis2}") 24 | 25 | inner = jnp.trace(array.array, offset=offset, axis1=a1_index, axis2=a2_index, dtype=dtype) 26 | # remove the two indices 27 | axes = tuple(a for i, a in enumerate(array.axes) if i not in (a1_index, a2_index)) 28 | return NamedArray(inner, axes) 29 | 30 | 31 | @typing.overload 32 | def where( 33 | condition: NamedOrNumeric | bool, 34 | x: NamedOrNumeric, 35 | y: NamedOrNumeric, 36 | ) -> NamedArray: 37 | ... 38 | 39 | 40 | @typing.overload 41 | def where( 42 | condition: NamedArray, 43 | *, 44 | fill_value: int, 45 | new_axis: Axis, 46 | ) -> tuple[NamedArray, ...]: 47 | ... 
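# An illustrative sketch of the two call forms of `where` (overloads above, implementation
# below); the "X" and "Hits" axis names are invented for the example:
#
#     import jax.numpy as jnp
#     import haliax as hax
#
#     x = hax.named(jnp.array([1.0, -2.0, 3.0, -4.0]), "X")
#
#     # three-argument form: elementwise select, as in jnp.where(cond, x, y)
#     non_negative = hax.where(x > 0, x, 0.0)
#
#     # one-argument form: indices where the condition holds, padded with `fill_value`
#     # out to the fixed-size `new_axis` (JAX requires a static output shape)
#     Hits = hax.Axis("Hits", 4)
#     (idx,) = hax.where(x > 0, fill_value=-1, new_axis=Hits)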
48 | 49 | 50 | def where( 51 | condition: Union[NamedOrNumeric, bool], 52 | x: Optional[NamedOrNumeric] = None, 53 | y: Optional[NamedOrNumeric] = None, 54 | fill_value: Optional[int] = None, 55 | new_axis: Optional[Axis] = None, 56 | ) -> NamedArray | tuple[NamedArray, ...]: 57 | """Like jnp.where, but with named axes.""" 58 | 59 | if (x is None) != (y is None): 60 | raise ValueError("Must either specify both x and y, or neither") 61 | 62 | # one argument form 63 | if (x is None) and (y is None): 64 | if not isinstance(condition, NamedArray): 65 | raise ValueError(f"condition {condition} must be a NamedArray in single argument mode") 66 | if fill_value is None or new_axis is None: 67 | raise ValueError("Must specify both fill_value and new_axis") 68 | return tuple( 69 | NamedArray(idx, (new_axis,)) 70 | for idx in jnp.where(condition.array, size=new_axis.size, fill_value=fill_value) 71 | ) 72 | 73 | # if x or y is a NamedArray, the other must be as well. wrap as needed for scalars 74 | 75 | if is_scalarish(condition): 76 | if x is None or y is None: 77 | raise ValueError("Must specify x and y when condition is a scalar") 78 | 79 | if isinstance(x, NamedArray) and not isinstance(y, NamedArray): 80 | if not is_scalarish(y): 81 | raise ValueError("y must be a NamedArray or scalar if x is a NamedArray") 82 | y = named(y, ()) 83 | elif isinstance(y, NamedArray) and not isinstance(x, NamedArray): 84 | if not is_scalarish(x): 85 | raise ValueError("x must be a NamedArray or scalar if y is a NamedArray") 86 | x = named(x, ()) 87 | x, y = broadcast_arrays(x, y) 88 | return jax.lax.cond(condition, lambda _: x, lambda _: y, None) 89 | 90 | condition, x, y = broadcast_arrays(condition, x, y) # type: ignore 91 | 92 | assert isinstance(condition, NamedArray) 93 | 94 | def _array_if_named(x): 95 | if isinstance(x, NamedArray): 96 | return x.array 97 | return x 98 | 99 | raw = jnp.where(condition.array, _array_if_named(x), _array_if_named(y)) 100 | return NamedArray(raw, condition.axes) 101 | 102 | 103 | def clip(array: NamedOrNumeric, a_min: NamedOrNumeric, a_max: NamedOrNumeric) -> NamedArray: 104 | """Like jnp.clip, but with named axes. 
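    For example (an illustrative sketch):

        import jax.numpy as jnp
        import haliax as hax

        x = hax.named(jnp.array([-1.0, 0.5, 2.0]), "X")
        hax.clip(x, 0.0, 1.0)  # -> NamedArray [0.0, 0.5, 1.0] over axis "X"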
This version currently only accepts the three argument form.""" 105 | (array, a_min, a_max), axes = broadcast_arrays_and_return_axes(array, a_min, a_max) 106 | array = raw_array_or_scalar(array) 107 | a_min = raw_array_or_scalar(a_min) 108 | a_max = raw_array_or_scalar(a_max) 109 | 110 | return NamedArray(jnp.clip(array, a_min, a_max), axes) 111 | 112 | 113 | def tril(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray: 114 | """Compute the lower triangular part of an array along two named axes.""" 115 | array = array.rearrange((..., axis1, axis2)) 116 | 117 | inner = jnp.tril(array.array, k=k) 118 | return NamedArray(inner, array.axes) 119 | 120 | 121 | def triu(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray: 122 | """Compute the upper triangular part of an array along two named axes.""" 123 | array = array.rearrange((..., axis1, axis2)) 124 | 125 | inner = jnp.triu(array.array, k=k) 126 | return NamedArray(inner, array.axes) 127 | 128 | 129 | def isclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> NamedArray: 130 | """Returns a boolean array where two arrays are element-wise equal within a tolerance.""" 131 | a, b = broadcast_arrays(a, b) 132 | # TODO: numpy supports an array atol and rtol, but we don't yet 133 | return NamedArray(jnp.isclose(a.array, b.array, rtol=rtol, atol=atol, equal_nan=equal_nan), a.axes) 134 | 135 | 136 | def pad_left(array: NamedArray, axis: Axis, new_axis: Axis, value=0) -> NamedArray: 137 | """Pad an array along named axes.""" 138 | amount_to_pad_to = new_axis.size - axis.size 139 | if amount_to_pad_to < 0: 140 | raise ValueError(f"Cannot pad {axis} to {new_axis}") 141 | 142 | idx = array._lookup_indices(axis) 143 | 144 | padding = [(0, 0)] * array.ndim 145 | if idx is None: 146 | raise ValueError(f"Axis {axis} not found in array. Available axes: {array.axes}") 147 | padding[idx] = (amount_to_pad_to, 0) 148 | 149 | padded = jnp.pad(array.array, padding, constant_values=value) 150 | return NamedArray(padded, array.axes[:idx] + (new_axis,) + array.axes[idx + 1 :]) 151 | 152 | 153 | def raw_array_or_scalar(x: NamedOrNumeric): 154 | if isinstance(x, NamedArray): 155 | return x.array 156 | return x 157 | 158 | 159 | __all__ = ["trace", "where", "tril", "triu", "isclose", "pad_left", "clip"] 160 | -------------------------------------------------------------------------------- /src/haliax/specialized_fns.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from .axis import Axis, AxisSelector 7 | from .core import NamedArray 8 | 9 | 10 | def top_k( 11 | arr: NamedArray, axis: AxisSelector, k: int, new_axis: Optional[AxisSelector] = None 12 | ) -> Tuple[NamedArray, NamedArray]: 13 | """ 14 | Select the top k elements along the given axis. 
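    For example (an illustrative sketch; the axis names are invented):

        import jax.random as jrandom
        import haliax as hax

        Batch, Vocab = hax.make_axes(Batch=4, Vocab=100)
        scores = hax.random.normal(jrandom.PRNGKey(0), (Batch, Vocab))
        values, indices = top_k(scores, Vocab, k=5, new_axis="TopK")
        # values and indices both have axes (Batch, Axis("TopK", 5))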
15 | Args: 16 | arr (NamedArray): array to select from 17 | axis (AxisSelector): axis to select from 18 | k (int): number of elements to select 19 | new_axis (Optional[AxisSelector]): new axis name, if none, the original axis will be resized to k 20 | 21 | Returns: 22 | NamedArray: array with the top k elements along the given axis 23 | NamedArray: array with the top k elements' indices along the given axis 24 | """ 25 | pos = arr._lookup_indices(axis) 26 | if pos is None: 27 | raise ValueError(f"Axis {axis} not found in {arr}") 28 | new_array = jnp.moveaxis(arr.array, pos, -1) # move axis to the last position 29 | values, indices = jax.lax.top_k(new_array, k=k) 30 | values = jnp.moveaxis(values, -1, pos) # move axis back to its original position 31 | indices = jnp.moveaxis(indices, -1, pos) 32 | 33 | if new_axis is None: 34 | axis = arr.resolve_axis(axis) 35 | new_axis = axis.resize(k) 36 | elif isinstance(new_axis, str): 37 | new_axis = Axis(new_axis, k) 38 | 39 | updated_axes = arr.axes[:pos] + (new_axis,) + arr.axes[pos + 1 :] 40 | return NamedArray(values, updated_axes), NamedArray(indices, updated_axes) 41 | -------------------------------------------------------------------------------- /src/haliax/state_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, TypeVar 2 | 3 | import equinox 4 | 5 | from haliax.jax_utils import is_jax_array_like 6 | from haliax.types import FilterSpec 7 | 8 | from ._src.state_dict import ( 9 | ModuleWithStateDictSerialization, 10 | StateDict, 11 | flatten_modules_for_export, 12 | from_state_dict, 13 | from_torch_compatible_state_dict, 14 | load_state_dict, 15 | save_state_dict, 16 | to_numpy_state_dict, 17 | to_state_dict, 18 | unflatten_modules_from_export, 19 | with_prefix, 20 | ) 21 | 22 | 23 | T = TypeVar("T") 24 | 25 | 26 | def to_torch_compatible_state_dict( 27 | t: T, *, flatten: bool = True, prefix: Optional[str] = None, filter: FilterSpec = is_jax_array_like 28 | ) -> StateDict: 29 | """ 30 | Convert a tree to a state dict that is compatible with torch-style state dicts. 31 | 32 | This applies the same logic as [to_state_dict][] but also uses [haliax.state_dict.ModuleWithStateDictSerialization.flatten_for_export][] to flatten 33 | 34 | Args: 35 | t: The tree to convert 36 | flatten: Whether to flatten axes using flatten_for_export 37 | prefix: The prefix to use for the state dict keys 38 | filter: The filter to use for selecting which nodes to include in the state dict. By default, this includes only 39 | array-like objects (e.g. JAX and NumPy arrays). 
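    A sketch of typical usage (illustrative: the module is arbitrary, `hax.nn.Linear` is assumed to be
    re-exported from `haliax.nn`, and the exact keys depend on the module's state-dict serialization):

        import jax.random as jrandom
        import haliax as hax

        In, Out = hax.make_axes(In=3, Out=4)
        linear = hax.nn.Linear.init(In, Out, key=jrandom.PRNGKey(0))
        state = to_torch_compatible_state_dict(linear)
        # `state` behaves like a flat dict of numpy arrays, e.g. {"weight": ..., "bias": ...}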
40 | """ 41 | t = equinox.filter(t, filter) 42 | if flatten: 43 | t = flatten_modules_for_export(t) 44 | return to_numpy_state_dict(t, prefix=prefix) 45 | 46 | 47 | __all__ = [ 48 | "ModuleWithStateDictSerialization", 49 | "from_torch_compatible_state_dict", 50 | "load_state_dict", 51 | "save_state_dict", 52 | "from_state_dict", 53 | "with_prefix", 54 | "to_state_dict", 55 | "to_numpy_state_dict", 56 | "StateDict", 57 | "to_torch_compatible_state_dict", 58 | "flatten_modules_for_export", 59 | "unflatten_modules_from_export", 60 | ] 61 | -------------------------------------------------------------------------------- /src/haliax/tree_util.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import functools 3 | from typing import Optional 4 | 5 | import equinox as eqx 6 | import jax 7 | import jax.tree_util as jtu 8 | from jaxtyping import PRNGKeyArray, PyTree 9 | 10 | import haliax.nn 11 | 12 | from .axis import AxisSelector 13 | from .core import NamedArray 14 | from .jax_utils import maybe_rng_split 15 | from .util import is_named_array 16 | 17 | 18 | def tree_map(fn, tree, *rest, is_leaf=None): 19 | """ 20 | Version of [jax.tree_util.tree_map][] that automatically treats NamedArrays as leaves. 21 | """ 22 | old_is_leaf = is_leaf 23 | if is_leaf is None: 24 | is_leaf = lambda x: isinstance(x, NamedArray) 25 | else: 26 | is_leaf = lambda x: old_is_leaf(x) or is_named_array(x) 27 | 28 | return jax.tree.map(fn, tree, *rest, is_leaf=is_leaf) 29 | 30 | 31 | def scan_aware_tree_map(fn, tree, *rest, is_leaf=None): 32 | """ 33 | Version of [haliax.tree_util.tree_map][] that is aware of the scan-layer pattern, specifically as implemented 34 | in hax.nn.Stacked. This function will (implicitly) apply the transform to each layer in each Stacked module 35 | (using vmap). If there are no Stacked modules in the tree, this function is equivalent to [haliax.tree_util.tree_map][]. 36 | 37 | """ 38 | old_is_leaf = is_leaf 39 | if is_leaf is None: 40 | is_leaf = lambda x: isinstance(x, haliax.nn.Stacked) 41 | else: 42 | is_leaf = lambda x: old_is_leaf(x) or isinstance(x, haliax.nn.Stacked) 43 | 44 | mapped_fn = functools.partial(scan_aware_tree_map, fn, is_leaf=is_leaf) 45 | 46 | def rec_fn(x, *rest): 47 | if isinstance(x, haliax.nn.Stacked): 48 | new_inner = haliax.vmap(mapped_fn, x.Block)(x.stacked, *[r.stacked for r in rest]) 49 | return dataclasses.replace(x, stacked=new_inner) # type: ignore 50 | else: 51 | return fn(x, *rest) 52 | 53 | return tree_map(rec_fn, tree, *rest, is_leaf=is_leaf) 54 | 55 | 56 | def tree_flatten(tree, is_leaf=None): 57 | """ 58 | Version of [jax.tree_util.tree_flatten][] that automatically treats NamedArrays as leaves. 59 | """ 60 | if is_leaf is None: 61 | is_leaf = lambda x: isinstance(x, NamedArray) 62 | else: 63 | is_leaf = lambda x: is_leaf(x) or is_named_array(x) 64 | 65 | return jax.tree_util.tree_flatten(tree, is_leaf=is_leaf) 66 | 67 | 68 | def tree_unflatten(treedef, leaves): 69 | """ 70 | Provided for consistency with tree_flatten. 71 | """ 72 | return jax.tree_util.tree_unflatten(treedef, leaves) 73 | 74 | 75 | def tree_leaves(tree, is_leaf=None): 76 | """ 77 | Version of [jax.tree_util.tree_leaves][] that automatically treats NamedArrays as leaves. 
78 | """ 79 | if is_leaf is None: 80 | is_leaf = lambda x: isinstance(x, NamedArray) 81 | else: 82 | is_leaf = lambda x: is_leaf(x) or is_named_array(x) 83 | 84 | return jax.tree_util.tree_leaves(tree, is_leaf=is_leaf) 85 | 86 | 87 | def tree_structure(tree, is_leaf=None): 88 | """ 89 | Version of [jax.tree_util.tree_structure][] that automatically treats NamedArrays as leaves. 90 | """ 91 | if is_leaf is None: 92 | is_leaf = lambda x: isinstance(x, NamedArray) 93 | else: 94 | is_leaf = lambda x: is_leaf(x) or is_named_array(x) 95 | 96 | return jax.tree_util.tree_structure(tree, is_leaf=is_leaf) 97 | 98 | 99 | def resize_axis(tree: PyTree[NamedArray], old_axis: AxisSelector, new_size: int, key: Optional[PRNGKeyArray] = None): 100 | """Resizes the NamedArrays of a PyTree along a given axis. If the array needs to grow and key is not none, then the 101 | new elements are sampled from a truncated normal distribution with the same mean and standard deviation as the 102 | existing elements. If the key is none, they're just initialized to the mean. If the array needs to shrink, then it's 103 | truncated. 104 | 105 | Note: if you have a module that stores a reference to the old axis, then you'll need to update that reference 106 | manually. 107 | 108 | """ 109 | import haliax.random 110 | 111 | def _resize_one(x, key): 112 | if not is_named_array(x): 113 | return x 114 | 115 | assert isinstance(x, NamedArray) 116 | 117 | try: 118 | current_axis = x.resolve_axis(old_axis) 119 | except ValueError: 120 | return x 121 | 122 | if new_size == current_axis.size: 123 | return x 124 | elif current_axis.size > new_size: 125 | return x.slice(current_axis, start=0, length=new_size) 126 | else: 127 | num_padding = new_size - current_axis.size 128 | 129 | mean = x.mean(current_axis) 130 | std = x.std(current_axis) 131 | 132 | # the shape of the padding is the same as the original array, except with the axis size changed 133 | padding_axes = list(x.axes) 134 | padding_axes[padding_axes.index(current_axis)] = current_axis.resize(num_padding) 135 | 136 | if key is None: 137 | padding = mean.broadcast_axis(padding_axes) 138 | else: 139 | padding = haliax.random.truncated_normal(key, padding_axes, lower=-2, upper=2) * std + mean 140 | 141 | return haliax.concatenate(current_axis.name, [x, padding]) 142 | 143 | leaves, structure = jax.tree_util.tree_flatten(tree, is_leaf=is_named_array) 144 | keys = maybe_rng_split(key, len(leaves)) 145 | 146 | new_leaves = [_resize_one(x, key) for x, key in zip(leaves, keys)] 147 | 148 | return jax.tree_util.tree_unflatten(structure, new_leaves) 149 | 150 | 151 | # old version of eqx's partition functions 152 | def hashable_partition(pytree, filter_spec): 153 | dynamic, static = eqx.partition(pytree, filter_spec) 154 | static_leaves, static_treedef = jtu.tree_flatten(static) 155 | static_leaves = tuple(static_leaves) 156 | return dynamic, (static_leaves, static_treedef) 157 | 158 | 159 | def hashable_combine(dynamic, static): 160 | static_leaves, static_treedef = static 161 | static = jtu.tree_unflatten(static_treedef, static_leaves) 162 | return eqx.combine(dynamic, static) 163 | -------------------------------------------------------------------------------- /src/haliax/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Literal, Protocol, Tuple, TypeAlias, Union 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from jax.lax import Precision 6 | from jaxtyping import PyTree 7 | 8 | 9 | DType: 
TypeAlias = np.dtype 10 | 11 | try: 12 | from jax.typing import DTypeLike 13 | except ImportError: 14 | # Cribbed from jax.typing, for older versions of JAX 15 | class SupportsDType(Protocol): 16 | @property 17 | def dtype(self) -> DType: 18 | ... 19 | 20 | DTypeLike = Union[ 21 | str, # like 'float32', 'int32' 22 | type, # like np.float32, np.int32, float, int 23 | np.dtype, # like np.dtype('float32'), np.dtype('int32') 24 | SupportsDType, # like jnp.float32, jnp.int32 25 | ] 26 | 27 | 28 | Scalar = Union[float, int, jnp.ndarray] # ndarray b/c array(1) is a scalar 29 | IntScalar = Union[int, jnp.ndarray] 30 | 31 | PrecisionLike = Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]] 32 | 33 | GatherScatterModeStr = Literal["promise_in_bounds", "clip", "drop", "fill"] 34 | 35 | 36 | FilterSpec = Union[bool, Callable[[Any], bool]] 37 | """ 38 | A filter specification. Typically used on a pytree to filter out certain subtrees. Boolean values are 39 | treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element 40 | is kept, otherwise it is filtered out. 41 | """ 42 | 43 | FilterTree = FilterSpec | PyTree[FilterSpec] 44 | -------------------------------------------------------------------------------- /src/haliax/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Sequence, Tuple, TypeAlias, TypeVar, Union 3 | 4 | import equinox 5 | 6 | from haliax.jax_utils import is_jax_array_like 7 | 8 | 9 | T = TypeVar("T") 10 | 11 | Unspecified: TypeAlias = type("NotSpecified", (), {}) # type: ignore 12 | UNSPECIFIED = Unspecified() 13 | 14 | 15 | def is_named_array(leaf): 16 | from .core import NamedArray 17 | 18 | "Typically used as the is_leaf predicate in tree_map" 19 | return isinstance(leaf, NamedArray) 20 | 21 | 22 | def ensure_tuple(x: Union[Sequence[T], T]) -> Tuple[T, ...]: 23 | if isinstance(x, str): 24 | return (x,) # type: ignore 25 | elif isinstance(x, Sequence): 26 | return tuple(x) 27 | return (x,) 28 | 29 | 30 | def maybe_untuple(x: Union[Sequence[T], T]) -> Union[T, Sequence[T]]: 31 | """ 32 | If x is a tuple with one element, return that element. Otherwise return x. 33 | """ 34 | if isinstance(x, tuple) and len(x) == 1: 35 | return x[0] 36 | return x 37 | 38 | 39 | class StringHolderEnum(type): 40 | """Like a python enum but just holds string constants, as opposed to wrapped string constants""" 41 | 42 | # https://stackoverflow.com/questions/62881486/a-group-of-constants-in-python 43 | 44 | def __new__(cls, name, bases, members): 45 | cls.members = [v for k, v in members.items() if not k.startswith("__") and not callable(v)] 46 | return super().__new__(cls, name, bases, members) 47 | 48 | # giving your class an __iter__ method gives you membership checking 49 | # and the ability to easily convert to another iterable 50 | @classmethod 51 | def __iter__(cls): 52 | yield from cls.members 53 | 54 | 55 | def is_jax_or_hax_array_like(x): 56 | return is_jax_array_like(x) or is_named_array(x) 57 | 58 | 59 | def safe_wraps(fn): 60 | """ 61 | Equinox has a special [equinox.module_update_wrapper][] that works with [equinox.Module][]s, but 62 | doesn't work with regular functions. Likewise, functools.update_wrapper doesn't work with [equinox.Module][]s. 63 | 64 | This function is a wrapper around both of them that works with both [equinox.Module][]s and regular functions. 
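    For example, it can be used the way `functools.wraps` usually is (a sketch; `fn` stands in for
    whatever function or module is being wrapped):

        def make_wrapper(fn):
            @safe_wraps(fn)
            def wrapper(*args, **kwargs):
                # forward to the wrapped callable; metadata is copied onto `wrapper`
                return fn(*args, **kwargs)

            return wrapper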
65 | 66 | Use this if you get this exception: `dataclasses.FrozenInstanceError: cannot assign to field '__module__'` 67 | """ 68 | return functools.partial(safe_update_wrapper, wrapped=fn) 69 | 70 | 71 | def safe_update_wrapper(wrapper, wrapped): 72 | """ 73 | As [safe_wraps][] but not a decorator. 74 | Args: 75 | wrapper: 76 | wrapped: 77 | 78 | Returns: 79 | 80 | """ 81 | if isinstance(wrapper, equinox.Module): 82 | return equinox.module_update_wrapper(wrapper, wrapped) 83 | else: 84 | return functools.update_wrapper(wrapper, wrapped) 85 | 86 | 87 | __all__ = [ 88 | "is_named_array", 89 | "ensure_tuple", 90 | "StringHolderEnum", 91 | "is_jax_or_hax_array_like", 92 | ] 93 | -------------------------------------------------------------------------------- /src/haliax/wrap.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Optional, Protocol 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from haliax.core import NamedArray, _broadcast_order, broadcast_to 8 | from haliax.jax_utils import is_scalarish 9 | from haliax.util import ensure_tuple 10 | 11 | from .axis import AxisSelection, AxisSelector, selects_axis 12 | 13 | 14 | def wrap_elemwise_unary(f, a, *args, **kwargs): 15 | if isinstance(a, NamedArray): 16 | return NamedArray(f(a.array, *args, **kwargs), a.axes) 17 | else: 18 | return f(a, *args, **kwargs) 19 | 20 | 21 | def wrap_reduction_call( 22 | fn, 23 | a, 24 | axis: Optional[AxisSelection], 25 | where: Optional[NamedArray] = None, 26 | single_axis_only: bool = False, 27 | supports_where: bool = True, 28 | **kwargs, 29 | ): 30 | kwargs = dict(kwargs) 31 | if where is not None and not supports_where: 32 | raise ValueError(f"where is not supported by {fn.__name__}") 33 | 34 | if kwargs.get("out", None) is not None: 35 | raise ValueError("out is not supported yet for NamedArray") 36 | if kwargs.get("keepdims", False): 37 | raise ValueError("keepdims is not supported for NamedArray") 38 | 39 | def reduce_one_leaf(a): 40 | nonlocal axis, where 41 | if isinstance(a, NamedArray): 42 | if where is not None: 43 | if not isinstance(where, NamedArray): 44 | raise TypeError(f"where must be a NamedArray if a is a NamedArray, but is {where}") 45 | where = broadcast_to(where, a.axes) 46 | kwargs["where"] = where.array 47 | 48 | if axis is None: 49 | result = fn(a.array, axis=None, **kwargs) 50 | return NamedArray(result, ()) 51 | else: 52 | axis = ensure_tuple(axis) 53 | if single_axis_only and len(axis) > 1: 54 | raise ValueError(f"{fn.__name__} only supports a single axis") 55 | indices = a._lookup_indices(axis) 56 | if indices is None or any(x is None for x in indices): 57 | raise ValueError(f"axis {axis} is not in {a.axes}") 58 | new_axes = [ax for ax in a.axes if not selects_axis(axis, ax)] 59 | if single_axis_only: 60 | result = fn(a.array, axis=indices[0], **kwargs) 61 | else: 62 | result = fn(a.array, axis=indices, **kwargs) 63 | return NamedArray(result, tuple(new_axes)) 64 | else: 65 | if where is not None: 66 | kwargs["where"] = where 67 | return fn(a, axis=axis, **kwargs) 68 | 69 | return jax.tree_util.tree_map(reduce_one_leaf, a, is_leaf=lambda x: isinstance(x, NamedArray)) 70 | 71 | 72 | def wrap_axiswise_call(fn, a, axis: Optional[AxisSelection], *, single_axis_only: bool, **kwargs): 73 | if isinstance(a, NamedArray): 74 | if axis is None: 75 | return fn(a.array, axis=None, **kwargs) 76 | else: 77 | indices = ensure_tuple(a._lookup_indices(axis)) 78 | if any(x is None for x in indices): 79 | raise 
ValueError(f"axis {axis} is not in {a.axes}") 80 | if len(indices) == 1: 81 | return NamedArray(fn(a.array, axis=indices[0], **kwargs), a.axes) 82 | elif single_axis_only: 83 | raise ValueError(f"{fn.__name__} only supports a single axis") 84 | else: 85 | return NamedArray(fn(a.array, axis=indices, **kwargs), a.axes) 86 | 87 | else: 88 | return fn(a, axis=axis, **kwargs) 89 | 90 | 91 | def wrap_elemwise_binary(op): 92 | def binop(a, b): 93 | if isinstance(a, NamedArray) and isinstance(b, NamedArray): 94 | axes = _broadcast_order(a, b) 95 | a = broadcast_to(a, axes) 96 | b = broadcast_to(b, axes) 97 | return NamedArray(op(a.array, b.array), axes) 98 | elif isinstance(a, NamedArray): 99 | return NamedArray(op(a.array, b), a.axes) 100 | elif isinstance(b, NamedArray): 101 | return NamedArray(op(a, b.array), b.axes) 102 | else: 103 | return op(a, b) 104 | 105 | return binop 106 | 107 | 108 | def unwrap_namedarrays(*a): 109 | return tuple(x.array if isinstance(x, NamedArray) else x for x in a) 110 | 111 | 112 | class ReductionFunction(Protocol): 113 | def __call__( 114 | self, 115 | array: NamedArray, 116 | axis: Optional[AxisSelection] = None, 117 | where: Optional[NamedArray] = None, 118 | **kwargs, 119 | ) -> NamedArray: 120 | ... 121 | 122 | 123 | class SimpleReductionFunction(Protocol): 124 | def __call__(self, array: NamedArray, axis: Optional[AxisSelector] = None, **kwargs) -> NamedArray: 125 | ... 126 | 127 | 128 | __all__ = [ 129 | "wrap_elemwise_unary", 130 | "wrap_reduction_call", 131 | "wrap_axiswise_call", 132 | "wrap_elemwise_binary", 133 | "unwrap_namedarrays", 134 | "ReductionFunction", 135 | "SimpleReductionFunction", 136 | ] 137 | -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from jax.random import PRNGKey 4 | 5 | import haliax as hax 6 | from haliax.nn.attention import ( 7 | alibi_attention_bias, 8 | causal_mask, 9 | dot_product_attention, 10 | dot_product_attention_weights, 11 | forgetful_causal_mask, 12 | self_attention, 13 | ) 14 | from test_utils import skip_if_no_torch 15 | 16 | 17 | def test_dot_product_attention_requires_axis_to_be_present(): 18 | Pos = hax.Axis("Pos", 20) 19 | KeyPos = hax.Axis("Pos_key", 20) 20 | NumHeads = hax.Axis("NumHeads", 1) 21 | Hid = hax.Axis("Hid", 8) 22 | 23 | query = hax.ones((NumHeads, KeyPos, Hid)) # NB: KeyPos not Pos 24 | key = hax.ones((KeyPos, NumHeads, Hid)) 25 | value = hax.ones((KeyPos, NumHeads, Hid)) 26 | 27 | try: 28 | dot_product_attention(Pos, Hid, query, key, value) 29 | except ValueError as e: 30 | assert "not found" in str(e) 31 | else: 32 | raise AssertionError("Should have raised an error") 33 | 34 | 35 | def test_attention_doesnt_allow_overlapping_axes(): 36 | KeyPos = hax.Axis("Pos_key", 20) 37 | NumHeads = hax.Axis("NumHeads", 1) 38 | Hid = hax.Axis("Hid", 8) 39 | 40 | query = hax.ones((NumHeads, KeyPos, Hid)) # NB: KeyPos not Pos 41 | key = hax.ones((KeyPos, NumHeads, Hid)) 42 | value = hax.ones((KeyPos, NumHeads, Hid)) 43 | 44 | try: 45 | dot_product_attention(KeyPos, Hid, query, key, value) 46 | except ValueError as e: 47 | assert "must be distinct" in str(e) 48 | else: 49 | raise AssertionError("Should have raised an error") 50 | 51 | 52 | def test_self_attention_basically_works(): 53 | Pos = hax.Axis("Pos", 20) 54 | KeyPos = hax.Axis("Pos_key", 20) 55 | NumHeads = hax.Axis("NumHeads", 1) 56 | Hid = hax.Axis("Hid", 8) 57 
| 58 | query = hax.ones((NumHeads, Pos, Hid)) 59 | 60 | result = self_attention(Pos, Hid, query, query, query, is_causal=True) 61 | assert result.axes == (NumHeads, Pos, Hid) 62 | 63 | k = query.rename({Pos: KeyPos}) 64 | cmask = causal_mask(Pos, KeyPos) 65 | result2 = dot_product_attention(KeyPos, Hid, query, k, k, mask=cmask) 66 | assert result2.axes == (NumHeads, Pos, Hid) 67 | 68 | # tight tolerances because it should be exactly the same computation 69 | assert jnp.allclose(result.array, result2.array) 70 | 71 | 72 | def test_alibi_attention_bias(): 73 | KeyPos = hax.Axis("KeyPos", 20) 74 | NumHeads = hax.Axis("NumHeads", 1) 75 | Hid = hax.Axis("Hid", 8) 76 | 77 | bias = alibi_attention_bias(NumHeads, KeyPos) 78 | 79 | query = hax.ones((NumHeads, Hid)) 80 | key = hax.ones((KeyPos, NumHeads, Hid)) 81 | 82 | weights_bias = dot_product_attention_weights(Hid, KeyPos, query, key, bias=bias) 83 | weights_no_bias = dot_product_attention_weights(Hid, KeyPos, query, key) 84 | 85 | assert weights_bias[KeyPos, -1] > weights_bias[KeyPos, -2] 86 | assert weights_bias[KeyPos, -1] > weights_no_bias[KeyPos, -1] 87 | 88 | assert weights_no_bias[KeyPos, -1] == weights_no_bias[KeyPos, -2] 89 | 90 | 91 | @skip_if_no_torch 92 | def test_alibi_attention_compared_to_hf(): 93 | import torch 94 | from transformers.models.bloom.modeling_bloom import build_alibi_tensor 95 | 96 | L, H = hax.make_axes(L=1, H=16) 97 | 98 | # Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) 99 | torch_tensor = ( 100 | build_alibi_tensor(torch.ones(1, L.size), H.size, dtype=torch.float32).numpy().reshape(H.size, L.size) 101 | ) 102 | 103 | hax_tensor = np.array(alibi_attention_bias(H, L).array) 104 | 105 | assert np.allclose(torch_tensor, hax_tensor) 106 | 107 | 108 | def test_fcm_attention_mask(): 109 | KeyPos, QueryPos, Head = hax.make_axes(KeyPos=20, QueryPos=10, Head=8) 110 | 111 | mask = forgetful_causal_mask(KeyPos, mask_prob=0.6, sample_prob=False, key=PRNGKey(0)) 112 | 113 | assert mask.axes == (KeyPos,) 114 | assert mask.array[0].item() == 1 115 | 116 | assert mask.astype(float).sum().item() <= KeyPos.size 117 | 118 | query = hax.arange(QueryPos).broadcast_axis(Head) 119 | key = hax.arange(KeyPos).broadcast_axis(Head) 120 | 121 | weights = dot_product_attention_weights(Head, KeyPos, query, key, mask=mask) 122 | 123 | # check that all masked out values are zero 124 | weights = weights.rearrange((KeyPos, QueryPos)) 125 | 126 | assert (weights * (mask == 0)).sum() == 0 127 | assert (weights * (mask == 1)).sum() > 0 128 | -------------------------------------------------------------------------------- /tests/test_axis.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from haliax.axis import Axis, eliminate_axes, make_axes, overlapping_axes, rearrange_for_partial_order 4 | 5 | 6 | def test_eliminate_axes(): 7 | H, W, C = make_axes(H=3, W=4, C=5) 8 | 9 | assert eliminate_axes((H, W), (H,)) == (W,) 10 | assert eliminate_axes((H, W), (W,)) == (H,) 11 | assert eliminate_axes((H, W), (H, W)) == () 12 | 13 | with pytest.raises(ValueError): 14 | eliminate_axes((H, W), (C,)) 15 | 16 | with pytest.raises(ValueError): 17 | eliminate_axes((H, W), (H, C)) 18 | 19 | # test string references 20 | assert eliminate_axes((H, W), ("H",)) == (W,) 21 | assert eliminate_axes(("H", W), (H,)) == (W,) 22 | assert eliminate_axes(("H", W), ("H",)) == (W,) 23 | assert eliminate_axes(("H", W), ("H", "W")) == () 24 | 25 | 26 | def assert_partial_order_respected(partial_order, output): 
27 | positions = {el: i for i, el in enumerate(output)} 28 | 29 | last_pos = -1 30 | for el in partial_order: 31 | if el is ...: 32 | # Reset last_pos for flexible positions 33 | last_pos = -1 34 | else: 35 | # Check if the element is in the correct order 36 | assert el in positions, f"{el} is missing in the output" 37 | assert positions[el] > last_pos, f"Partial order not respected for {el}" 38 | last_pos = positions[el] 39 | 40 | 41 | def test_basic_order(): 42 | partial_order = ("apple", ..., "banana") 43 | candidates = ("banana", "apple", "cherry") 44 | expected_output = ("apple", "cherry", "banana") 45 | actual_output = rearrange_for_partial_order(partial_order, candidates) 46 | assert actual_output == expected_output 47 | assert_partial_order_respected(partial_order, actual_output) 48 | 49 | 50 | def test_start_with_ellipsis(): 51 | partial_order = (..., "apple", "banana") 52 | candidates = ("banana", "apple", "cherry") 53 | actual_output = rearrange_for_partial_order(partial_order, candidates) 54 | assert_partial_order_respected(partial_order, actual_output) 55 | assert actual_output == ("cherry", "apple", "banana") 56 | 57 | 58 | def test_end_with_ellipsis(): 59 | partial_order = ("apple", ..., "banana", ...) 60 | candidates = ("banana", "apple", "cherry") 61 | actual_output = rearrange_for_partial_order(partial_order, candidates) 62 | assert_partial_order_respected(partial_order, actual_output) 63 | 64 | # this one could be either but we'll assert the order so we catch changes 65 | assert actual_output == ("apple", "banana", "cherry") 66 | 67 | 68 | def test_full_flexibility(): 69 | partial_order = (...,) 70 | candidates = ("banana", "apple", "cherry") 71 | actual_output = rearrange_for_partial_order(partial_order, candidates) 72 | assert_partial_order_respected(partial_order, actual_output) 73 | 74 | 75 | def test_no_flexibility(): 76 | partial_order = ("apple", "banana") 77 | candidates = ("banana", "apple", "cherry") 78 | with pytest.raises(ValueError): 79 | rearrange_for_partial_order(partial_order, candidates) 80 | 81 | 82 | def test_final_ellipsis(): 83 | partial_order = ("apple", "banana", ...) 84 | candidates = ("banana", "apple", "cherry") 85 | actual_output = rearrange_for_partial_order(partial_order, candidates) 86 | assert_partial_order_respected(partial_order, actual_output) 87 | assert actual_output == ("apple", "banana", "cherry") 88 | 89 | 90 | def test_lots_of_ellipsis(): 91 | partial_order = ("apple", ..., "banana", ..., "cherry", ...) 92 | candidates = ("banana", "orange", "cherry", "apple", "grape") 93 | actual_output = rearrange_for_partial_order(partial_order, candidates) 94 | assert_partial_order_respected(partial_order, actual_output) 95 | assert actual_output == ("apple", "banana", "orange", "cherry", "grape") 96 | 97 | 98 | def test_no_ellipsis(): 99 | partial_order = ("apple", "banana", "cherry") 100 | candidates = ("banana", "apple", "cherry") 101 | actual_output = rearrange_for_partial_order(partial_order, candidates) 102 | assert_partial_order_respected(partial_order, actual_output) 103 | assert actual_output == ("apple", "banana", "cherry") 104 | 105 | 106 | def test_no_elements(): 107 | partial_order = (...,) 108 | candidates = () 109 | actual_output = rearrange_for_partial_order(partial_order, candidates) 110 | assert_partial_order_respected(partial_order, actual_output) 111 | assert actual_output == () 112 | 113 | 114 | def test_missing_elements_errors(): 115 | partial_order = ("qux", ...) 
116 | candidates = ("banana", "apple", "cherry") 117 | with pytest.raises(ValueError): 118 | rearrange_for_partial_order(partial_order, candidates) 119 | 120 | 121 | def test_duplicate_elements_errors(): 122 | partial_order: tuple = ("apple", "apple", ...) 123 | candidates = ("banana", "apple", "cherry") 124 | with pytest.raises(ValueError): 125 | rearrange_for_partial_order(partial_order, candidates) 126 | 127 | candidates = ("banana", "apple", "apple") 128 | 129 | with pytest.raises(ValueError): 130 | rearrange_for_partial_order(partial_order, candidates) 131 | 132 | partial_order = ("apple", "banana", "grape", ...) 133 | 134 | with pytest.raises(ValueError): 135 | rearrange_for_partial_order(partial_order, candidates) 136 | 137 | 138 | def test_overlapping_axes_with_different_sizes(): 139 | A1 = Axis("A", 10) 140 | A2 = Axis("A", 12) 141 | B = Axis("B", 14) 142 | C = Axis("C", 16) 143 | D = Axis("D", 18) 144 | 145 | ax1 = (A1, B, C) 146 | ax2 = (A2, C, D) 147 | 148 | overlapping_names = overlapping_axes(ax1, ax2) # Should not error 149 | assert overlapping_names == ("A", "C") 150 | -------------------------------------------------------------------------------- /tests/test_conv.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | import haliax as hax 6 | from haliax.nn.conv import Conv, ConvTranspose 7 | 8 | 9 | def test_conv_basic_equiv_to_eqx(): 10 | key = jax.random.PRNGKey(0) 11 | In = hax.Axis("In", 3) 12 | Out = hax.Axis("Out", 4) 13 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, key=key) 14 | eqx_conv = eqx.nn.Conv(2, 3, 4, kernel_size=3, key=key) 15 | 16 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape 17 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1] 18 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight) 19 | 20 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6))) 21 | hax_output = hax_conv(input) 22 | eqx_output = eqx_conv(input.array) 23 | 24 | assert hax_output.array.shape == eqx_output.shape 25 | assert jnp.all(hax_output.array == eqx_output) 26 | 27 | # test batched 28 | input = hax.random.normal( 29 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6)) 30 | ) 31 | hax_output = hax_conv(input) 32 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array) 33 | 34 | assert hax_output.array.shape == eqx_output.shape 35 | assert jnp.all(hax_output.array == eqx_output) 36 | 37 | # test multibatch 38 | input = hax.random.normal( 39 | jax.random.PRNGKey(1), 40 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)), 41 | ) 42 | hax_output = hax_conv(input) 43 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array) 44 | 45 | assert hax_output.array.shape == eqx_output.shape 46 | assert jnp.allclose(hax_output.array, eqx_output) 47 | 48 | input = hax.random.normal( 49 | jax.random.PRNGKey(1), 50 | ( 51 | hax.Axis("Batch", 2), 52 | In, 53 | hax.Axis("Height", 5), 54 | hax.Axis("Width", 6), 55 | hax.Axis("Batch2", 3), 56 | ), 57 | ) 58 | hax_output = hax_conv(input).rearrange(("Batch", "Batch2", "Out", "Height", "Width")) 59 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))( 60 | input.rearrange(("Batch", "Batch2", "In", "Height", "Width")).array 61 | ) 62 | 63 | assert hax_output.array.shape == eqx_output.shape 64 | assert jnp.allclose(hax_output.array, eqx_output) 65 | 
66 | 67 | def test_conv_grouped_equiv_to_eqx(): 68 | key = jax.random.PRNGKey(0) 69 | In = hax.Axis("In", 4) 70 | Out = hax.Axis("Out", 6) 71 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, groups=2, key=key) 72 | eqx_conv = eqx.nn.Conv(2, 4, 6, kernel_size=3, groups=2, key=key) 73 | 74 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape 75 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1] 76 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight) 77 | 78 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6))) 79 | eqx_output = eqx_conv(input.array) 80 | hax_output = hax_conv(input) 81 | 82 | assert hax_output.array.shape == eqx_output.shape 83 | assert jnp.all(hax_output.array == eqx_output) 84 | 85 | # test batched 86 | input = hax.random.normal( 87 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6)) 88 | ) 89 | hax_output = hax_conv(input) 90 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array) 91 | 92 | assert hax_output.array.shape == eqx_output.shape 93 | assert jnp.all(hax_output.array == eqx_output) 94 | 95 | # test multibatch 96 | input = hax.random.normal( 97 | jax.random.PRNGKey(1), 98 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)), 99 | ) 100 | hax_output = hax_conv(input) 101 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array) 102 | 103 | assert hax_output.array.shape == eqx_output.shape 104 | assert jnp.allclose(hax_output.array, eqx_output) 105 | 106 | 107 | def test_conv_weird_order(): 108 | key = jax.random.PRNGKey(0) 109 | In = hax.Axis("In", 3) 110 | Out = hax.Axis("Out", 4) 111 | hax_conv = Conv.init(("Height", "Width"), In, Out, kernel_size=3, key=key) 112 | eqx_conv = eqx.nn.Conv(2, 3, 4, kernel_size=3, key=key) 113 | 114 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape 115 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1] 116 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight) 117 | 118 | input = hax.random.normal( 119 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6)) 120 | ) 121 | hax_output = hax_conv(input) 122 | 123 | # test weird orders 124 | input = input.rearrange(("In", "Height", "Width", "Batch")) 125 | hax_output2 = hax_conv(input).rearrange(("Batch", "Out", "Height", "Width")) 126 | 127 | assert jnp.allclose(hax_output.array, hax_output2.array) 128 | 129 | 130 | def test_conv_transpose_basic_equiv_to_eqx(): 131 | key = jax.random.PRNGKey(0) 132 | In = hax.Axis("In", 3) 133 | Out = hax.Axis("Out", 4) 134 | hax_conv = ConvTranspose.init( 135 | ("Height", "Width"), In, Out, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key 136 | ) 137 | eqx_conv = eqx.nn.ConvTranspose(2, 3, 4, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key) 138 | 139 | assert hax_conv.weight.array.shape == eqx_conv.weight.shape 140 | assert hax_conv.bias.array.shape == eqx_conv.bias.shape[0:1] 141 | assert jnp.all(hax_conv.weight.array == eqx_conv.weight) 142 | 143 | input = hax.random.normal(jax.random.PRNGKey(1), (In, hax.Axis("Height", 5), hax.Axis("Width", 6))) 144 | hax_output = hax_conv(input) 145 | eqx_output = eqx_conv(input.array) 146 | 147 | assert hax_output.array.shape == eqx_output.shape 148 | assert jnp.all(hax_output.array == eqx_output) 149 | 150 | # test batched 151 | input = hax.random.normal( 152 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, 
hax.Axis("Height", 5), hax.Axis("Width", 6)) 153 | ) 154 | hax_output = hax_conv(input) 155 | eqx_output = eqx.filter_vmap(eqx_conv)(input.array) 156 | 157 | assert hax_output.array.shape == eqx_output.shape 158 | assert jnp.all(hax_output.array == eqx_output) 159 | 160 | # test multibatch 161 | input = hax.random.normal( 162 | jax.random.PRNGKey(1), 163 | (hax.Axis("Batch", 2), hax.Axis("Batch2", 3), In, hax.Axis("Height", 5), hax.Axis("Width", 6)), 164 | ) 165 | hax_output = hax_conv(input) 166 | eqx_output = eqx.filter_vmap(eqx.filter_vmap(eqx_conv))(input.array) 167 | 168 | assert hax_output.array.shape == eqx_output.shape 169 | assert jnp.allclose(hax_output.array, eqx_output) 170 | 171 | 172 | def test_weird_orders_conv_transpose(): 173 | key = jax.random.PRNGKey(0) 174 | In = hax.Axis("In", 3) 175 | Out = hax.Axis("Out", 4) 176 | hax_conv = ConvTranspose.init( 177 | ("Height", "Width"), In, Out, kernel_size=3, dilation=2, output_padding=1, stride=2, key=key 178 | ) 179 | 180 | input = hax.random.normal( 181 | jax.random.PRNGKey(1), (hax.Axis("Batch", 2), In, hax.Axis("Height", 5), hax.Axis("Width", 6)) 182 | ) 183 | hax_output = hax_conv(input) 184 | 185 | # test weird orders 186 | input = input.rearrange(("In", "Height", "Width", "Batch")) 187 | hax_output2 = hax_conv(input).rearrange(("Batch", "Out", "Height", "Width")) 188 | 189 | assert jnp.allclose(hax_output.array, hax_output2.array) 190 | -------------------------------------------------------------------------------- /tests/test_debug.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax.numpy as jnp 3 | import pytest 4 | 5 | import haliax as hax 6 | 7 | 8 | def test_diagnose_common_issues_repeated(): 9 | class M(eqx.Module): 10 | a: jnp.ndarray = eqx.field() 11 | b: jnp.ndarray = eqx.field() 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.a = jnp.zeros(1) 16 | self.b = self.a 17 | 18 | try: 19 | hax.debug.diagnose_common_issues(M()) 20 | pytest.fail("Should have raised an exception") 21 | except hax.debug.ModuleProblems as e: 22 | assert len(e.reused_arrays) == 1 23 | assert len(e.static_arrays) == 0 24 | 25 | 26 | def test_diagnose_common_issues_repeated_nested(): 27 | class M(eqx.Module): 28 | a: jnp.ndarray = eqx.field() 29 | b: jnp.ndarray = eqx.field() 30 | 31 | def __init__(self): 32 | super().__init__() 33 | self.a = jnp.zeros(1) 34 | self.b = self.a 35 | 36 | class N(eqx.Module): 37 | m: M = eqx.field() 38 | c: jnp.ndarray = eqx.field() 39 | 40 | def __init__(self): 41 | super().__init__() 42 | self.m = M() 43 | self.c = self.m.a 44 | 45 | try: 46 | hax.debug.diagnose_common_issues(N()) 47 | pytest.fail("Should have raised an exception") 48 | except hax.debug.ModuleProblems as e: 49 | assert len(e.reused_arrays) == 1 50 | assert e.reused_arrays[0][1] == [".m.a", ".m.b", ".c"] 51 | assert len(e.static_arrays) == 0 52 | 53 | 54 | def test_diagnose_common_issues_static(): 55 | class M(eqx.Module): 56 | a: jnp.ndarray = eqx.static_field() 57 | b: hax.NamedArray = eqx.static_field() 58 | 59 | def __init__(self): 60 | super().__init__() 61 | self.a = jnp.zeros(1) 62 | self.b = hax.named(jnp.zeros(3), "a") 63 | 64 | try: 65 | hax.debug.diagnose_common_issues(M()) 66 | pytest.fail("Should have raised an exception") 67 | except hax.debug.ModuleProblems as e: 68 | assert len(e.reused_arrays) == 0 69 | assert len(e.static_arrays) == 2 70 | 71 | 72 | def test_diagnose_common_issues_static_nested(): 73 | class M(eqx.Module): 74 | a: 
jnp.ndarray = eqx.static_field() 75 | b: hax.NamedArray = eqx.static_field() 76 | 77 | def __init__(self): 78 | super().__init__() 79 | self.a = jnp.zeros(1) 80 | self.b = hax.named(jnp.zeros(3), "a") 81 | 82 | class N(eqx.Module): 83 | m: M = eqx.field() 84 | c: jnp.ndarray = eqx.field() 85 | 86 | def __init__(self): 87 | super().__init__() 88 | self.m = M() 89 | self.c = self.m.a 90 | 91 | try: 92 | hax.debug.diagnose_common_issues(N()) 93 | pytest.fail("Should have raised an exception") 94 | except hax.debug.ModuleProblems as e: 95 | assert len(e.reused_arrays) == 0 96 | assert len(e.static_arrays) == 2 97 | -------------------------------------------------------------------------------- /tests/test_dot.py: -------------------------------------------------------------------------------- 1 | # these test if the rearrange logic works for partial orders 2 | import pytest 3 | from jax import numpy as jnp 4 | 5 | import haliax as hax 6 | from haliax import Axis, NamedArray 7 | 8 | 9 | def test_dot(): 10 | Height = Axis("Height", 2) 11 | Width = Axis("Width", 3) 12 | Depth = Axis("Depth", 4) 13 | 14 | m1 = hax.ones((Height, Width, Depth)) 15 | m2 = hax.ones((Depth, Width, Height)) 16 | 17 | assert jnp.all(jnp.equal(hax.dot(m1, m2, axis=Height).array, jnp.einsum("ijk,kji->jk", m1.array, m2.array))) 18 | assert jnp.all( 19 | jnp.equal( 20 | hax.dot(m1, m2, axis=(Height, Width)).array, 21 | jnp.einsum("ijk,kji->k", m1.array, m2.array), 22 | ) 23 | ) 24 | assert jnp.all( 25 | jnp.equal( 26 | hax.dot(m1, m2, axis=(Height, Width, Depth)).array, 27 | jnp.einsum("ijk,kji->", m1.array, m2.array), 28 | ) 29 | ) 30 | 31 | # reduce to scalar 32 | assert jnp.all( 33 | jnp.equal( 34 | hax.dot(m1, m2, axis=None), 35 | jnp.einsum("ijk,kji->", m1.array, m2.array), 36 | ) 37 | ) 38 | 39 | 40 | def test_dot_string_selection(): 41 | Height = Axis("Height", 2) 42 | Width = Axis("Width", 3) 43 | Depth = Axis("Depth", 4) 44 | 45 | m1 = hax.ones((Height, Width, Depth)) 46 | m2 = hax.ones((Depth, Width, Height)) 47 | 48 | assert jnp.all(jnp.equal(hax.dot(m1, m2, axis="Height").array, jnp.einsum("ijk,kji->jk", m1.array, m2.array))) 49 | assert jnp.all( 50 | jnp.equal( 51 | hax.dot(m1, m2, axis=("Height", "Width")).array, 52 | jnp.einsum("ijk,kji->k", m1.array, m2.array), 53 | ) 54 | ) 55 | assert jnp.all( 56 | jnp.equal( 57 | hax.dot(m1, m2, axis=("Height", "Width", "Depth")).array, 58 | jnp.einsum("ijk,kji->", m1.array, m2.array), 59 | ) 60 | ) 61 | 62 | 63 | def test_dot_errors_if_different_sized_axes(): 64 | Height = Axis("Height", 2) 65 | Width = Axis("Width", 3) 66 | Depth = Axis("Depth", 4) 67 | 68 | H2 = Axis("Height", 4) 69 | 70 | m1 = hax.ones((Height, Width, Depth)) 71 | m2 = hax.ones((Depth, Width, H2)) 72 | 73 | with pytest.raises(ValueError): 74 | hax.dot(m1, m2, axis="Height") 75 | 76 | 77 | def test_dot_with_output_axes(): 78 | Height = Axis("Height", 2) 79 | Width = Axis("Width", 3) 80 | Depth = Axis("Depth", 4) 81 | 82 | m1 = hax.ones((Height, Width, Depth)) 83 | m2 = hax.ones((Depth, Width, Height)) 84 | 85 | assert jnp.all( 86 | jnp.equal( 87 | hax.dot(m1, m2, axis=Height, out_axes=(Width, ...)).array, 88 | jnp.einsum("ijk,kji->jk", m1.array, m2.array), 89 | ) 90 | ) 91 | 92 | assert jnp.all( 93 | jnp.equal( 94 | hax.dot(m1, m2, axis=Height, out_axes=(Depth, ...)).array, 95 | jnp.einsum("ijk,kji->kj", m1.array, m2.array), 96 | ) 97 | ) 98 | 99 | assert jnp.all( 100 | jnp.equal( 101 | hax.dot(m1, m2, axis=Height, out_axes=(Depth, Width)).array, 102 | jnp.einsum("ijk,kji->kj", m1.array, m2.array), 103 
| ) 104 | ) 105 | 106 | assert jnp.all( 107 | jnp.equal( 108 | hax.dot(m1, m2, axis=(), out_axes=(Depth, Height, ...)).array, 109 | jnp.einsum("ijk,kji->kij", m1.array, m2.array), 110 | ) 111 | ) 112 | 113 | assert jnp.all( 114 | jnp.equal( 115 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, Height)).array, 116 | jnp.einsum("ijk,kji->jki", m1.array, m2.array), 117 | ) 118 | ) 119 | 120 | assert jnp.all( 121 | jnp.equal( 122 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, ...)).array, 123 | jnp.einsum("ijk,kji->ijk", m1.array, m2.array), 124 | ) 125 | ) 126 | 127 | assert jnp.all( 128 | jnp.equal( 129 | hax.dot(m1, m2, axis=(), out_axes=(..., Depth, Height, ...)).array, 130 | jnp.einsum("ijk,kji->kij", m1.array, m2.array), 131 | ) 132 | ) 133 | -------------------------------------------------------------------------------- /tests/test_fp8.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import equinox as eqx 3 | import jax.numpy as jnp 4 | import jax.random as jrandom 5 | import jax.tree_util 6 | import numpy as np 7 | from chex import assert_trees_all_close 8 | 9 | import haliax as hax 10 | from haliax._src.fp8 import compute_scale 11 | from haliax.nn import Linear 12 | from haliax.quantization import ( 13 | Fp8DotGeneralOp, 14 | QuantizationConfig, 15 | apply_updates, 16 | partition_for_grad_overwrite, 17 | quantize_linear_layers, 18 | ) 19 | 20 | 21 | def test_fp8_is_reasonable(): 22 | In = hax.Axis("In", 8) 23 | Out = hax.Axis("Out", 8) 24 | linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), init_scale=0.1) 25 | 26 | fp8_linear = Linear.init( 27 | In, Out, key=jrandom.PRNGKey(0), dot_general=hax.quantization.Fp8DotGeneralOp.init(), init_scale=0.1 28 | ) 29 | 30 | input = hax.random.normal(jrandom.PRNGKey(3), In) 31 | output = linear(input) 32 | fp8_output = fp8_linear(input) 33 | 34 | assert output.shape == fp8_output.shape 35 | assert output.dtype == fp8_output.dtype 36 | 37 | assert_trees_all_close(output.array, fp8_output.array, atol=2e-2, rtol=5e-2) 38 | 39 | 40 | # https://github.com/google/flax/blob/6f2b08e024c2fd2f8cec42a6c82408cb35412319/tests/linen/linen_test.py#L1222 41 | def test_fp_loop(): 42 | key, init_key, random_key = jrandom.split(jrandom.PRNGKey(seed=123), 3) 43 | Batch = hax.Axis("Batch", 16) 44 | In = hax.Axis("In", 16) 45 | Out = hax.Axis("Out", 32) 46 | linear = Linear.init(In, Out, key=init_key, dot_general=Fp8DotGeneralOp.init()) 47 | 48 | def _roll_and_update(amax_h, update): 49 | return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) 50 | 51 | lr = 1e-3 52 | 53 | def apply_gradients(model, grads): 54 | overwrites, grads = partition_for_grad_overwrite(grads) 55 | updates = jax.tree_util.tree_map(lambda g: -lr * g, grads) 56 | model = apply_updates(model, updates, overwrites) 57 | return model 58 | 59 | def _train_step(model, x, dy): 60 | def loss_fn(lin): 61 | y = lin(x) 62 | loss = y * dy.astype(y.dtype) 63 | return hax.sum(loss).scalar() 64 | 65 | grad_fn = eqx.filter_grad(loss_fn) 66 | grads = grad_fn(model) 67 | return apply_gradients(model, grads) 68 | 69 | train_fn = eqx.filter_jit(_train_step) 70 | 71 | scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,)) 72 | scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,)) 73 | scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,)) 74 | e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) 75 | e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32) 76 | 77 | for _ in range(5): 78 | key, random_key = jrandom.split(key, 
2) 79 | # x = jrandom.normal(random_key, (16, 16), dtype=jnp.float32) 80 | # g = jrandom.normal(random_key, (16, 32), dtype=jnp.float32) 81 | x = hax.random.normal( 82 | random_key, 83 | ( 84 | Batch, 85 | In, 86 | ), 87 | ) 88 | g = hax.random.normal( 89 | random_key, 90 | ( 91 | Batch, 92 | Out, 93 | ), 94 | ) 95 | 96 | # Manually compute the expected amax history and scaling factors. 97 | amax_from_history_x = jnp.max(amax_history_x, axis=0) 98 | amax_from_history_k = jnp.max(amax_history_k, axis=0) 99 | amax_from_history_g = jnp.max(amax_history_g, axis=0) 100 | scale_x = compute_scale(amax_from_history_x, scale_x, e4m3_max) 101 | scale_k = compute_scale(amax_from_history_k, scale_k, e4m3_max) 102 | scale_g = compute_scale(amax_from_history_g, scale_g, e5m2_max) 103 | amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x.array))) 104 | amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(linear.weight.array))) 105 | amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g.array))) 106 | 107 | linear = train_fn(linear, x, g) 108 | 109 | rtol, atol = 0.001, 0.001 110 | np.testing.assert_allclose( 111 | linear.dot_general.input_amax_history, # type: ignore 112 | amax_history_x, 113 | rtol=rtol, 114 | atol=atol, 115 | ) 116 | np.testing.assert_allclose( 117 | linear.dot_general.kernel_amax_history, # type: ignore 118 | amax_history_k, 119 | rtol=rtol, 120 | atol=atol, 121 | ) 122 | np.testing.assert_allclose( 123 | linear.dot_general.output_grad_amax_history, # type: ignore 124 | amax_history_g, 125 | rtol=rtol, 126 | atol=atol, 127 | ) 128 | 129 | np.testing.assert_allclose(linear.dot_general.input_scale, scale_x, rtol=rtol, atol=atol) # type: ignore 130 | np.testing.assert_allclose(linear.dot_general.kernel_scale, scale_k, rtol=rtol, atol=atol) # type: ignore 131 | np.testing.assert_allclose(linear.dot_general.output_grad_scale, scale_g, rtol=rtol, atol=atol) # type: ignore 132 | 133 | 134 | def test_layer_splicing(): 135 | key, init_key, random_key = jrandom.split(jrandom.PRNGKey(seed=123), 3) 136 | Input = hax.Axis("Input", 16) 137 | Hidden = hax.Axis("Hidden", 64) 138 | Output = hax.Axis("Output", 32) 139 | mlp = hax.nn.MLP.init(Input, Output, Hidden, 3, key=init_key, init_scale=0.1) 140 | 141 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(fp8=True)) 142 | for layer in mlp_q.layers: 143 | assert isinstance(layer.dot_general, Fp8DotGeneralOp) 144 | 145 | input = hax.random.normal(jrandom.PRNGKey(0), Input) * 10 # 10 so we don't underflow 146 | output = mlp(input) 147 | output_q = mlp_q(input) 148 | chex.assert_trees_all_close(output.array, output_q.array, atol=1e-3, rtol=1e-3) 149 | assert not jnp.allclose(output_q.array, 0) # don't want them to all underflow 150 | 151 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(targets="layers.0", fp8=True)) 152 | for i, layer in enumerate(mlp_q.layers): 153 | if i == 0: 154 | assert isinstance(layer.dot_general, Fp8DotGeneralOp) 155 | else: 156 | assert not isinstance(layer.dot_general, Fp8DotGeneralOp) 157 | 158 | mlp_q = quantize_linear_layers(mlp, QuantizationConfig(targets=["0", "1"], fp8=True)) 159 | for i, layer in enumerate(mlp_q.layers): 160 | if i < 2: 161 | assert isinstance(layer.dot_general, Fp8DotGeneralOp) 162 | else: 163 | assert not isinstance(layer.dot_general, Fp8DotGeneralOp) 164 | 165 | 166 | def test_fp8ize_stacking(): 167 | class Block(eqx.Module): 168 | up_proj: hax.nn.Linear 169 | down_proj: hax.nn.Linear 170 | 171 | @staticmethod 172 | def init(In, Out, key): 173 | 
up_proj = hax.nn.Linear.init(In, Out, key=key) 174 | down_proj = hax.nn.Linear.init(Out, In, key=key) 175 | return Block(up_proj, down_proj) 176 | 177 | def __call__(self, x): 178 | return self.down_proj(self.up_proj(x)) 179 | 180 | Layer = hax.Axis("Layer", 3) 181 | 182 | class Tformer(eqx.Module): 183 | blocks: hax.nn.Stacked[Block] 184 | 185 | @staticmethod 186 | def init(In, Out, key): 187 | blocks = hax.nn.Stacked.init(Layer, Block)(In, Out, key=jax.random.split(key, Layer.size)) 188 | return Tformer(blocks) 189 | 190 | def __call__(self, x): 191 | return self.blocks.fold(x) 192 | 193 | In = hax.Axis("In", 16) 194 | Out = hax.Axis("Out", 32) 195 | tformer = Tformer.init(In, Out, key=jrandom.PRNGKey(0)) 196 | tformer_q = quantize_linear_layers(tformer, QuantizationConfig(fp8=True)) 197 | 198 | # want to be sure this vmaps the dot_general to the right places 199 | dg = tformer_q.blocks.stacked.up_proj.dot_general 200 | assert isinstance(dg, Fp8DotGeneralOp) 201 | assert dg.input_scale.shape == (Layer.size, 1) 202 | assert dg.input_amax_history.shape == (Layer.size, 1024) 203 | dg = tformer_q.blocks.stacked.down_proj.dot_general 204 | assert isinstance(dg, Fp8DotGeneralOp) 205 | 206 | # just stack the up_proj 207 | tformer_q = quantize_linear_layers(tformer, QuantizationConfig(targets=["up_proj"], fp8=True)) 208 | dg = tformer_q.blocks.stacked.up_proj.dot_general 209 | assert isinstance(dg, Fp8DotGeneralOp) 210 | dg = tformer_q.blocks.stacked.down_proj.dot_general 211 | assert not isinstance(dg, Fp8DotGeneralOp) 212 | -------------------------------------------------------------------------------- /tests/test_int8.py: -------------------------------------------------------------------------------- 1 | import jax.random as jrandom 2 | from chex import assert_trees_all_close 3 | 4 | import haliax as hax 5 | from haliax.nn import Linear 6 | from haliax.quantization import Int8DotGeneralOp 7 | 8 | 9 | def test_int8_is_reasonable(): 10 | In = hax.Axis("In", 8) 11 | Out = hax.Axis("Out", 8) 12 | linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), init_scale=0.1) 13 | 14 | int8_linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), dot_general=Int8DotGeneralOp.init(), init_scale=0.1) 15 | 16 | input = hax.random.normal(jrandom.PRNGKey(3), In) 17 | output = linear(input) 18 | int8_output = int8_linear(input) 19 | 20 | assert output.shape == int8_output.shape 21 | assert output.dtype == int8_output.dtype 22 | 23 | assert_trees_all_close(output.array, int8_output.array, atol=1e-2, rtol=5e-2) 24 | -------------------------------------------------------------------------------- /tests/test_nn.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import equinox as eqx 4 | import jax.nn 5 | import jax.random as jrandom 6 | import pytest 7 | from jax import numpy as jnp 8 | 9 | import haliax as hax 10 | from haliax import Axis, NamedArray 11 | 12 | 13 | def _compare_eqx_and_haliax(hax_mod: eqx.Module, eqx_mod: eqx.Module): 14 | def f(x: NamedArray, *args, **kwargs): 15 | unnamed_x = x.array 16 | hax_out = hax_mod(x, *args, **kwargs) # type: ignore 17 | eqx_out = eqx_mod(unnamed_x, *args, **kwargs) # type: ignore 18 | 19 | assert jnp.allclose(hax_out.array, eqx_out) 20 | return hax_out 21 | 22 | return f 23 | 24 | 25 | def test_layer_norm(): 26 | H = Axis("H", 10) 27 | hax_ln = hax.nn.LayerNorm.init(H) 28 | eqx_ln = eqx.nn.LayerNorm(shape=(H.size,)) 29 | 30 | f = _compare_eqx_and_haliax(hax_ln, eqx_ln) 31 | out = 
f(hax.random.uniform(jrandom.PRNGKey(0), (H,))) 32 | 33 | assert out.axes == (H,) 34 | 35 | 36 | def test_dropout(): 37 | H = Axis("H", 10) 38 | key = jrandom.PRNGKey(0) 39 | hax_dropout = hax.nn.Dropout(0.5) 40 | eqx_dropout = eqx.nn.Dropout(0.5) 41 | 42 | f = _compare_eqx_and_haliax(hax_dropout, eqx_dropout) 43 | out = f(hax.random.uniform(jrandom.PRNGKey(0), (H,)), key=key, inference=False) 44 | 45 | assert out.axes == (H,) 46 | 47 | 48 | def test_one_hot(): 49 | i, c = hax.make_axes(i=3, c=3) 50 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([0, 1, 2]), (i,)), c) 51 | expected = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 52 | 53 | assert actual.axes == (i, c) 54 | assert jnp.all(jnp.isclose(actual.array, expected)) 55 | 56 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([1, 2, 0]), (i,)), c) 57 | expected = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) 58 | assert actual.axes == (i, c) 59 | assert jnp.all(jnp.isclose(actual.array, expected)) 60 | 61 | 62 | def test_one_hot_out_of_bound(): 63 | i, c = hax.make_axes(i=2, c=3) 64 | actual = hax.nn.one_hot(hax.NamedArray(jnp.array([-1, 3]), (i,)), c) 65 | expected = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 66 | assert jnp.all(jnp.isclose(actual.array, expected)) 67 | 68 | 69 | def test_standardize(): 70 | b, c = hax.make_axes(b=2, c=3) 71 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([0, 1, 2]), (c,)), c) 72 | expected = jax.nn.standardize(jnp.array([0, 1, 2]), axis=0) 73 | 74 | assert actual.axes == (c,) 75 | assert jnp.all(jnp.isclose(actual.array, expected)) 76 | 77 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), c) 78 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=1) 79 | 80 | assert actual.axes == (b, c) 81 | assert jnp.all(jnp.isclose(actual.array, expected)) 82 | 83 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), b) 84 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=0) 85 | 86 | assert actual.axes == (b, c) 87 | assert jnp.all(jnp.isclose(actual.array, expected)) 88 | 89 | # test passing in where 90 | mask = hax.NamedArray(jnp.array([True, False, True]), (c,)) 91 | actual = hax.nn.standardize(hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)), b, where=mask) 92 | expected = jax.nn.standardize(jnp.array([[0, 1, 2], [3, 4, 5]]), axis=0, where=mask.array) 93 | 94 | assert actual.axes == (b, c) 95 | assert jnp.all(jnp.isclose(actual.array, expected) | jnp.isnan(expected)) 96 | 97 | # now mean/variance 98 | input = hax.NamedArray(jnp.array([[0, 1, 2], [3, 4, 5]]), (b, c)) 99 | mean = hax.mean(input, c) 100 | variance = hax.var(input, c) 101 | actual = hax.nn.standardize(input, c, mean=mean, variance=variance) 102 | expected = jax.nn.standardize( 103 | input.array, axis=1, mean=mean.array.reshape(-1, 1), variance=variance.array.reshape(-1, 1) 104 | ) 105 | 106 | assert actual.axes == (b, c) 107 | assert jnp.all(jnp.isclose(actual.array, expected)) 108 | 109 | 110 | @pytest.mark.parametrize("depth", [0, 1, 2, 3, 4, 5]) 111 | def test_mlp(depth): 112 | key = jrandom.PRNGKey(0) 113 | H, C, W, E = hax.make_axes(H=10, C=12, W=14, E=16) 114 | 115 | hax_mlp = hax.nn.MLP.init((H, C, W), E, width=8, depth=depth, key=key) 116 | x = hax.random.uniform(key, (H, C, W)) 117 | assert hax_mlp(x).axes == (E,) 118 | 119 | hax_mlp = hax.nn.MLP.init((H, W), E, width=8, depth=depth, key=key) 120 | assert hax_mlp(x).axes == (C, E) 121 | 122 | # with a named width 123 
| M = Axis("M", 18) 124 | hax_mlp = hax.nn.MLP.init((H, W), E, width=M, depth=depth, key=key) 125 | assert hax_mlp(x).axes == (C, E) 126 | 127 | # ensure that we actually use the name for the named width 128 | if depth > 0: 129 | assert hax_mlp.layers[0].Out == M 130 | 131 | if depth % 2 == 0: 132 | assert hax_mlp.layers[-1].In == M.alias("M2") 133 | else: 134 | assert hax_mlp.layers[-1].In == M 135 | 136 | for i in range(1, depth): 137 | if i % 2 == 0: 138 | assert hax_mlp.layers[i].In == M.alias("M2") 139 | assert hax_mlp.layers[i].Out == M 140 | else: 141 | assert hax_mlp.layers[i].In == M 142 | assert hax_mlp.layers[i].Out == M.alias("M2") 143 | 144 | 145 | def test_linear_has_no_function_leaves_by_default(): 146 | H, C, W, E = hax.make_axes(H=10, C=12, W=14, E=16) 147 | 148 | hax_linear = hax.nn.Linear.init((H, C, W), E, key=jrandom.PRNGKey(0)) 149 | assert all(not isinstance(v, Callable) for v in jax.tree_util.tree_leaves(hax_linear)) # type: ignore 150 | -------------------------------------------------------------------------------- /tests/test_parsing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from haliax._src.parsing import parse_einsum, parse_rearrangement 4 | 5 | 6 | def _simplify_captures(expr): 7 | def simplify_capture(capture): 8 | if capture == Ellipsis: 9 | return Ellipsis 10 | elif (capture.binding == capture.axes[0] or capture.binding is None) and len(capture.axes) == 1: 11 | return capture.axes[0] 12 | elif capture.binding is None: 13 | return capture.axes 14 | else: 15 | return {capture.binding: capture.axes} 16 | 17 | return [simplify_capture(capture) for capture in expr.captures] 18 | 19 | 20 | def test_parse_rearrangement_simple(): 21 | lhs, rhs = parse_rearrangement("a b c d -> b c a d") 22 | assert lhs.is_ordered 23 | assert _simplify_captures(lhs) == ["a", "b", "c", "d"] 24 | assert rhs.is_ordered 25 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"] 26 | 27 | lhs, rhs = parse_rearrangement("a ... c d -> b c a d") 28 | assert lhs.is_ordered 29 | assert _simplify_captures(lhs) == ["a", Ellipsis, "c", "d"] 30 | assert rhs.is_ordered 31 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"] 32 | 33 | # longer identifiers 34 | lhs, rhs = parse_rearrangement("a_longer b123 c d -> b123 c a_longer d") 35 | assert lhs.is_ordered 36 | assert _simplify_captures(lhs) == ["a_longer", "b123", "c", "d"] 37 | assert rhs.is_ordered 38 | assert _simplify_captures(rhs) == ["b123", "c", "a_longer", "d"] 39 | 40 | 41 | def test_parse_einsum_ordered(): 42 | # We could support this syntax for dot with something like: 43 | # 44 | # support normal einops syntax, including short name-capture: hax.dot("... c h w, h w d -> ... c d", a, b) 45 | # hax.dot("{h, w} -> ", a, b) means "contract h and w", analogous to hax.dot(a, b, axis=("h", "w")) 46 | # hax.dot("{h, w} -> ... channel embed", a, b) means "contract h and w and ensure that the result ends with [channel, embed]" (by transposing/einsum) 47 | # hax.dot(" -> batch channel embed", a, b) could mean "contract all but the named dims". 
Not entirely sure how I feel about that one, but used situationally it's probably ok 48 | 49 | lhses, rhs = parse_einsum("a b c d, b c e f -> a d e f") 50 | assert lhses is not None 51 | assert len(lhses) == 2 52 | assert all(lhs.is_ordered for lhs in lhses) 53 | lhs0_captures = _simplify_captures(lhses[0]) 54 | lhs1_captures = _simplify_captures(lhses[1]) 55 | assert lhs0_captures == ["a", "b", "c", "d"] 56 | assert lhs1_captures == ["b", "c", "e", "f"] 57 | assert rhs.is_ordered 58 | assert _simplify_captures(rhs) == ["a", "d", "e", "f"] 59 | 60 | lhses, rhs = parse_einsum("... c h w, h w d -> ... c d") 61 | assert lhses is not None 62 | assert len(lhses) == 2 63 | assert all(lhs.is_ordered for lhs in lhses) 64 | lhs0_captures = _simplify_captures(lhses[0]) 65 | lhs1_captures = _simplify_captures(lhses[1]) 66 | assert lhs0_captures == [..., "c", "h", "w"] 67 | assert lhs1_captures == ["h", "w", "d"] 68 | assert rhs.is_ordered 69 | assert _simplify_captures(rhs) == [..., "c", "d"] 70 | 71 | lhses, rhs = parse_einsum("{...} -> batch channel embed") 72 | assert lhses is not None 73 | assert len(lhses) == 1 74 | assert not lhses[0].is_ordered 75 | assert _simplify_captures(lhses[0]) == [...] 76 | assert rhs.is_ordered 77 | assert _simplify_captures(rhs) == ["batch", "channel", "embed"] 78 | 79 | # just lhs 80 | lhses, rhs = parse_einsum("batch channel embed -> ") 81 | assert lhses is not None 82 | assert len(lhses) == 1 83 | assert lhses[0].is_ordered 84 | assert _simplify_captures(lhses[0]) == ["batch", "channel", "embed"] 85 | assert rhs.is_ordered 86 | assert _simplify_captures(rhs) == [] 87 | 88 | # lhs x 2, empty rhs 89 | lhses, rhs = parse_einsum("batch channel embed, batch channel embed ->") 90 | assert lhses is not None 91 | assert len(lhses) == 2 92 | assert all(lhs.is_ordered for lhs in lhses) 93 | assert _simplify_captures(lhses[0]) == ["batch", "channel", "embed"] 94 | assert _simplify_captures(lhses[1]) == ["batch", "channel", "embed"] 95 | assert rhs.is_ordered 96 | assert _simplify_captures(rhs) == [] 97 | 98 | 99 | def test_parse_einsum_unordered(): 100 | lhses, rhs = parse_einsum("{a, b} -> ") 101 | assert lhses is not None 102 | assert len(lhses) == 1 103 | assert not lhses[0].is_ordered 104 | assert _simplify_captures(lhses[0]) == ["a", "b"] 105 | assert rhs.is_ordered 106 | assert _simplify_captures(rhs) == [] 107 | 108 | lhses, rhs = parse_einsum("{...} -> ") 109 | assert lhses is not None 110 | assert len(lhses) == 1 111 | assert not lhses[0].is_ordered 112 | assert _simplify_captures(lhses[0]) == [...] 113 | assert rhs.is_ordered 114 | assert _simplify_captures(rhs) == [] 115 | 116 | lhses, rhs = parse_einsum("{h, w} -> ... 
channel embed") 117 | assert lhses is not None 118 | assert len(lhses) == 1 119 | assert not lhses[0].is_ordered 120 | assert _simplify_captures(lhses[0]) == ["h", "w"] 121 | assert rhs.is_ordered 122 | assert _simplify_captures(rhs) == [..., "channel", "embed"] 123 | 124 | 125 | def test_parse_paren_groups(): 126 | lhs, rhs = parse_rearrangement("a (b c) d -> b c a d") 127 | assert lhs.is_ordered 128 | assert _simplify_captures(lhs) == ["a", ("b", "c"), "d"] 129 | assert rhs.is_ordered 130 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"] 131 | 132 | lhs, rhs = parse_rearrangement("a (b: c) (d: e f) -> b c a d") 133 | assert lhs.is_ordered 134 | assert _simplify_captures(lhs) == ["a", {"b": ("c",)}, {"d": ("e", "f")}] 135 | 136 | 137 | def test_parse_unordered(): 138 | lhs, rhs = parse_rearrangement("{a b c d} -> {b c a d}") 139 | assert not lhs.is_ordered 140 | assert _simplify_captures(lhs) == ["a", "b", "c", "d"] 141 | assert not rhs.is_ordered 142 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"] 143 | 144 | lhs, rhs = parse_rearrangement("{(c: a b) d e} -> (q: a d e) b") 145 | assert not lhs.is_ordered 146 | assert _simplify_captures(lhs) == [{"c": ("a", "b")}, "d", "e"] 147 | assert rhs.is_ordered 148 | assert _simplify_captures(rhs) == [{"q": ("a", "d", "e")}, "b"] 149 | 150 | 151 | def test_parse_quoted_identifiers(): 152 | lhs, rhs = parse_rearrangement("a \"b c\" d -> 'b c' a d") 153 | assert lhs.is_ordered 154 | assert _simplify_captures(lhs) == ["a", "b c", "d"] 155 | assert rhs.is_ordered 156 | assert _simplify_captures(rhs) == ["b c", "a", "d"] 157 | 158 | lhs, rhs = parse_rearrangement("{a \"b c\" (d: 'hello')} -> b c a d") 159 | assert not lhs.is_ordered 160 | assert _simplify_captures(lhs) == ["a", "b c", {"d": ("hello",)}] 161 | assert rhs.is_ordered 162 | assert _simplify_captures(rhs) == ["b", "c", "a", "d"] 163 | 164 | 165 | def test_parse_errors(): 166 | with pytest.raises(ValueError, match="Unexpected end of string"): 167 | parse_rearrangement("a b") 168 | 169 | with pytest.raises(ValueError, match="Expected }"): 170 | parse_rearrangement("{ a -> c") 171 | 172 | with pytest.raises(ValueError, match="Unexpected }"): 173 | parse_rearrangement("a } -> c") 174 | 175 | with pytest.raises(ValueError, match="Unexpected }"): 176 | parse_rearrangement("(a: b } -> c") 177 | 178 | with pytest.raises(ValueError, match="Unexpected character"): 179 | parse_rearrangement("(a b ! -> c d e") 180 | 181 | with pytest.raises(ValueError, match="Identifier expected"): 182 | parse_rearrangement("a b ! 
-> c d e") 183 | -------------------------------------------------------------------------------- /tests/test_pool.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | import haliax 6 | import haliax as hax 7 | from haliax.nn.pool import max_pool, mean_pool 8 | 9 | 10 | # Tests largely cribbed from equinox 11 | 12 | 13 | def test_maxpool1d(): 14 | D = hax.Axis("D", 14) 15 | x = hax.arange(D) 16 | output = max_pool((D.resize(2),), x, stride=(3,)) 17 | answer = jnp.array([1, 4, 7, 10, 13], dtype=jnp.int32) 18 | 19 | assert jnp.all(output.array == answer) 20 | 21 | answer = jnp.array([2, 5, 8, 11]) 22 | output = max_pool(D.resize(3), x, stride=(3,), padding=0) 23 | assert jnp.all(output.array == answer) 24 | 25 | # max_pool = eqx.nn.MaxPool1d(kernel_size=3, stride=3, padding=0, use_ceil=True) 26 | answer = jnp.array([2, 5, 8, 11, 13]) 27 | output = max_pool(D.resize(3), x, stride=(3,), padding=0, use_ceil=True) 28 | assert jnp.all(output.array == answer) 29 | 30 | # test batch axes 31 | B = hax.Axis("B", 2) 32 | x = x.rearrange("(B D) -> B D", B=B) 33 | output = max_pool(D.resize(2), x, stride=(3,), padding="VALID") 34 | answer = jnp.array([[1, 4], [8, 11]]) 35 | assert jnp.all(output.array == answer) 36 | 37 | output = max_pool(D.resize(2), x, stride=(3,), use_ceil=True) 38 | answer = jnp.array([[1, 4, 6], [8, 11, 13]]) 39 | assert jnp.all(output.array == answer) 40 | 41 | output = max_pool(D.resize(3), x, stride=(3,), padding=0) 42 | answer = jnp.array([[2, 5], [9, 12]]) 43 | assert jnp.all(output.array == answer) 44 | 45 | output = max_pool(D.resize(3), x, stride=(3,), padding=0, use_ceil=True) 46 | answer = jnp.array([[2, 5, 6], [9, 12, 13]]) 47 | assert jnp.all(output.array == answer) 48 | 49 | output = max_pool(D.resize(2), x, stride=(3,), padding="SAME") 50 | answer = jnp.array([[1, 4, 6], [8, 11, 13]]) 51 | assert jnp.all(output.array == answer) 52 | 53 | 54 | def test_maxpool2d(): 55 | _x = jnp.arange(36).reshape(6, 6) 56 | x = hax.named(_x, ("H", "W")) 57 | 58 | # max_pool = eqx.nn.MaxPool2d(2, (3, 2)) 59 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2)), x, stride=(3, 2)) 60 | answer = jnp.array([[7, 9, 11], [25, 27, 29]]) 61 | 62 | assert jnp.all(output.array == answer) 63 | 64 | output = max_pool((hax.Axis("H", 3), hax.Axis("W", 3)), x, stride=2, padding=1) 65 | answer = jnp.array([[7, 9, 11], [19, 21, 23], [31, 33, 35]]) 66 | 67 | assert jnp.all(output.array == answer) 68 | 69 | # test batch axes 70 | B = hax.Axis("B", 2) 71 | x = haliax.stack(B, [x, x]) 72 | 73 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2)), x, stride=(3, 2)) 74 | answer = jnp.array([[[7, 9, 11], [25, 27, 29]], [[7, 9, 11], [25, 27, 29]]]) 75 | 76 | assert jnp.all(output.array == answer) 77 | 78 | 79 | def test_maxpool3d(): 80 | _x = jnp.arange(64).reshape(4, 4, 4) 81 | x = hax.named(_x, ("H", "W", "D")) 82 | output = max_pool((hax.Axis("H", 2), hax.Axis("W", 2), hax.Axis("D", 2)), x, stride=(3, 2, 1)) 83 | 84 | answer = jnp.array([[[21, 22, 23], [29, 30, 31]]]) 85 | 86 | assert jnp.all(output.array == answer) 87 | 88 | answer = jnp.asarray( 89 | [ 90 | [[37, 39, 39], [45, 47, 47], [45, 47, 47]], 91 | [[53, 55, 55], [61, 63, 63], [61, 63, 63]], 92 | ] 93 | ) 94 | output = max_pool( 95 | (hax.Axis("H", 3), hax.Axis("W", 3), hax.Axis("D", 3)), 96 | x, 97 | stride=2, 98 | padding=((0, 1), (1, 1), (1, 1)), 99 | use_ceil=True, 100 | ) 101 | assert jnp.all(output.array == answer) 102 | 103 | 
104 | def test_mean_pool_1d(): 105 | D = hax.Axis("D", 14) 106 | x = hax.arange(D) 107 | output = mean_pool((D.resize(2),), x, stride=(3,)) 108 | answer = jnp.array([0.5, 3.5, 6.5, 9.5, 12.5]) 109 | 110 | assert jnp.all(output.array == answer) 111 | 112 | # no pad 113 | output = mean_pool(D.resize(3), x, stride=(3,), padding=0) 114 | answer = jnp.array([1, 4, 7, 10]) 115 | assert jnp.all(output.array == answer) 116 | 117 | # pad, no include pad in avg 118 | output = mean_pool(D.resize(3), x, stride=(3,), padding="SAME", count_include_pad=False) 119 | answer = jnp.array([1, 4, 7, 10, 12.5]) 120 | 121 | assert jnp.all(output.array == answer) 122 | 123 | output = mean_pool(D.resize(3), x, stride=(3,), padding="SAME", count_include_pad=True) 124 | answer = jnp.array([1, 4, 7, 10, (12 + 13) / 3.0]) 125 | 126 | assert jnp.all(output.array == answer) 127 | 128 | 129 | def test_mean_pool_2d(): 130 | _x = jnp.arange(36).reshape(6, 6) 131 | x = hax.named(_x, ("H", "W")) 132 | 133 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3)), x, stride=2) 134 | answer = jnp.array([[1, 3], [13, 15], [25, 27]]) 135 | 136 | assert jnp.all(output.array == answer) 137 | 138 | # test batch axes 139 | B = hax.Axis("B", 2) 140 | x = haliax.stack(B, [x, x]) 141 | 142 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3)), x, stride=2) 143 | answer = jnp.array([[[1, 3], [13, 15], [25, 27]], [[1, 3], [13, 15], [25, 27]]]) 144 | 145 | assert jnp.all(output.array == answer) 146 | 147 | 148 | def test_mean_pool3d(): 149 | _x = jnp.arange(64).reshape(4, 4, 4) 150 | x = hax.named(_x, ("H", "W", "D")) 151 | output = mean_pool((hax.Axis("H", 1), hax.Axis("W", 3), hax.Axis("D", 1)), x, stride=2) 152 | 153 | answer = jnp.array([[[4, 6]], [[36, 38]]]) 154 | 155 | assert jnp.all(output.array == answer) 156 | 157 | 158 | def test_pool_backprop(): 159 | def max_pool_mean(x): 160 | pooled = max_pool( 161 | (hax.Axis("H", 2), hax.Axis("W", 2), hax.Axis("D", 2)), x, stride=1, padding=((0, 1), (0, 1), (0, 1)) 162 | ) 163 | return hax.mean(pooled).scalar() 164 | 165 | _x = jnp.arange(64, dtype=jnp.float32).reshape(1, 4, 4, 4) 166 | x = hax.named(_x, ("B", "H", "W", "D")) 167 | grad_fn = jax.value_and_grad(max_pool_mean) 168 | 169 | hax_loss, hax_grad = grad_fn(x) 170 | 171 | # compare it to eqx 172 | 173 | eqx_max_pool = eqx.nn.MaxPool3d(2, (1, 1, 1), padding=((0, 1), (0, 1), (0, 1))) 174 | 175 | def eqx_max_pool_mean(x): 176 | pooled = eqx_max_pool(x) 177 | return pooled.mean() 178 | 179 | eqx_grad_fn = jax.value_and_grad(eqx_max_pool_mean) 180 | eqx_loss, eqx_grad = eqx_grad_fn(_x) 181 | 182 | assert jnp.allclose(hax_loss, eqx_loss) 183 | assert jnp.allclose(hax_grad.array, eqx_grad) 184 | -------------------------------------------------------------------------------- /tests/test_specialized_fns.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import haliax 5 | import haliax.specialized_fns as hfns 6 | from haliax import NamedArray 7 | 8 | 9 | def test_top_k(): 10 | H, W, D = haliax.make_axes(H=3, W=4, D=5) 11 | 12 | rand = jax.random.uniform(jax.random.PRNGKey(0), (H.size, W.size, D.size)) 13 | n_rand = NamedArray(rand, (H, W, D)) 14 | 15 | values, indices = hfns.top_k(n_rand, D, 2) 16 | 17 | assert values.array.shape == (H.size, W.size, 2) 18 | assert indices.array.shape == (H.size, W.size, 2) 19 | assert jnp.all( 20 | jnp.equal(jax.lax.top_k(rand, 2)[0], values.array) 21 | ) # test that selecting last axis is same as default 22 | assert 
jnp.all( 23 | jnp.equal(jnp.moveaxis(n_rand.take(D, indices).array, 0, -1), values.array) 24 | ) # test that indexing using indices is same as selected values 25 | 26 | for idx, i in enumerate(n_rand.axes): # then test selecting all axes 27 | t = jnp.transpose(rand, (*range(idx), *range(idx + 1, len(n_rand.axes)), idx)) 28 | t = jax.lax.top_k(t, 2)[0] 29 | t = jnp.moveaxis(t, -1, idx) 30 | values, indices = hfns.top_k(n_rand, i, 2) 31 | assert jnp.all(jnp.equal(t, values.array)) 32 | assert jnp.all(jnp.equal(jnp.moveaxis(n_rand.take(i, indices).array, 0, idx), values.array)) 33 | -------------------------------------------------------------------------------- /tests/test_state_dict.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any 3 | 4 | import equinox as eqx 5 | import jax 6 | import jax.numpy as jnp 7 | import pytest 8 | 9 | import haliax as hax 10 | from haliax._src.state_dict import flatten_modules_for_export, unflatten_modules_from_export 11 | from haliax.nn import Linear 12 | from haliax.nn.scan import Stacked, _stack_state_dict, _unstack_state_dict 13 | from haliax.state_dict import from_state_dict, to_state_dict 14 | 15 | 16 | @pytest.mark.parametrize("out_dims_first", [True, False]) 17 | def test_flatten_linear_layers(out_dims_first: bool): 18 | H = hax.Axis("H", 10) 19 | W = hax.Axis("W", 20) 20 | D = hax.Axis("D", 30) 21 | B = hax.Axis("B", 40) 22 | linear = hax.nn.Linear.init((H, W), (D, B), key=jax.random.PRNGKey(0), use_bias=True, out_first=out_dims_first) 23 | 24 | if out_dims_first: 25 | assert linear.weight.axes == (D, B, H, W) 26 | else: 27 | assert linear.weight.axes == (H, W, D, B) 28 | 29 | flat_linear = linear.flatten_for_export() 30 | 31 | flat_state_dict = to_state_dict(flat_linear) 32 | if out_dims_first: 33 | assert flat_state_dict["weight"].shape == (D.size * B.size, H.size * W.size) 34 | else: 35 | assert flat_state_dict["weight"].shape == (H.size * W.size, D.size * B.size) 36 | assert flat_state_dict["bias"].shape == (D.size * B.size,) 37 | assert flat_state_dict["weight"].dtype == flat_state_dict["bias"].dtype == linear.weight.dtype 38 | 39 | # now unflatten it 40 | linear2 = Linear.init((H, W), (D, B), key=jax.random.PRNGKey(1), use_bias=True, out_first=out_dims_first) 41 | new_linear = flat_linear.unflatten_from_export(linear2) 42 | 43 | if out_dims_first: 44 | assert new_linear.weight.axes == (D, B, H, W) 45 | else: 46 | assert new_linear.weight.axes == (H, W, D, B) 47 | assert new_linear.bias.axes == (D, B) # type: ignore 48 | 49 | assert linear == new_linear 50 | 51 | 52 | # Test cases for stack_state_dict 53 | @pytest.mark.parametrize( 54 | "input_dict, prefix, expected_output", 55 | [ 56 | # Single block stacking 57 | ( 58 | { 59 | "block.0.weight": jnp.array([1, 2]), 60 | "block.0.bias": jnp.array([3]), 61 | "block.1.weight": jnp.array([4, 5]), 62 | "block.1.bias": jnp.array([6]), 63 | }, 64 | "block", 65 | { 66 | "block.weight": jnp.array([[1, 2], [4, 5]]), 67 | "block.bias": jnp.array([[3], [6]]), 68 | }, 69 | ), 70 | # Mixed data types and unmatched items remain unchanged 71 | ( 72 | { 73 | "block.0.weight": jnp.array([1, 2]), 74 | "block.0.bias": jnp.array([3]), 75 | "block.1.weight": jnp.array([4, 5]), 76 | "block.1.bias": jnp.array([6.0]), 77 | "unrelated.item": jnp.array([7]), 78 | }, 79 | "block", 80 | { 81 | "block.weight": jnp.array([[1, 2], [4, 5]]), 82 | "block.bias": jnp.array([[3.0], [6.0]]), 83 | "unrelated.item": jnp.array([7]), 84 | }, 85 | ), 86 | # 
No items match prefix, all items should remain unchanged 87 | ( 88 | { 89 | "module.0.param": jnp.array([1]), 90 | "module.1.param": jnp.array([2]), 91 | }, 92 | "block", 93 | { 94 | "module.0.param": jnp.array([1]), 95 | "module.1.param": jnp.array([2]), 96 | }, 97 | ), 98 | ], 99 | ) 100 | def test_stack_state_dict(input_dict, prefix, expected_output): 101 | result = _stack_state_dict(input_dict, prefix) 102 | for key in expected_output: 103 | assert jnp.all(jnp.array_equal(result[key], expected_output[key])), f"Failed on key: {key}" 104 | 105 | # now unstack it 106 | unstacked = _unstack_state_dict(result, prefix) 107 | for key in input_dict: 108 | assert jnp.all(jnp.array_equal(unstacked[key], input_dict[key])), f"Failed on key: {key}" 109 | 110 | 111 | class M(eqx.Module): 112 | a: Any 113 | b: Any 114 | 115 | def __init__(self, a, b): 116 | self.a = a 117 | self.b = b 118 | 119 | 120 | def test_to_from_state_dict(): 121 | a = jnp.array([1, 2]) 122 | b = jnp.array([3, 4]) 123 | m = M(a, b) 124 | 125 | state_dict = to_state_dict(m) 126 | assert state_dict == {"a": a, "b": b} 127 | 128 | m2 = M(jnp.array([0, 0]), jnp.array([0, 0])) 129 | m2 = from_state_dict(m2, state_dict) 130 | assert jnp.all(m2.a == a) 131 | assert jnp.all(m2.b == b) 132 | 133 | 134 | def test_export_layer_norm(): 135 | D = hax.Axis("D", 10) 136 | E = hax.Axis("E", 20) 137 | layer_norm = hax.nn.LayerNorm.init((D, E), eps=1e-5, use_weight=True, use_bias=True) 138 | 139 | flat_layer_norm = layer_norm.flatten_for_export() 140 | 141 | flat_state_dict = to_state_dict(flat_layer_norm) 142 | 143 | assert flat_state_dict["weight"].shape == (D.size * E.size,) 144 | assert flat_state_dict["bias"].shape == (D.size * E.size,) 145 | assert flat_state_dict["weight"].dtype == flat_state_dict["bias"].dtype == layer_norm.weight.dtype 146 | 147 | # now unflatten it 148 | layer_norm2 = hax.nn.LayerNorm.init((D, E), eps=1e-5, use_weight=True, use_bias=True) 149 | # ensure we have different weights 150 | layer_norm2 = dataclasses.replace(layer_norm2, weight=layer_norm2.weight + 1, bias=layer_norm2.bias + 1) 151 | 152 | new_layer_norm = flat_layer_norm.unflatten_from_export(layer_norm2) 153 | 154 | assert layer_norm == new_layer_norm 155 | 156 | 157 | def test_stacked_layer_norm(): 158 | L = hax.Axis("L", 4) 159 | D = hax.Axis("D", 10) 160 | E = hax.Axis("E", 20) 161 | 162 | norms = Stacked.init(L, hax.nn.LayerNorm)((D, E), eps=1e-5, use_weight=True, use_bias=True) 163 | 164 | norms_flat = flatten_modules_for_export(norms) 165 | 166 | flat_state_dict = to_state_dict(norms_flat) 167 | 168 | assert flat_state_dict["0.weight"].shape == (D.size * E.size,) 169 | assert flat_state_dict["0.bias"].shape == (D.size * E.size,) 170 | assert flat_state_dict["1.weight"].shape == (D.size * E.size,) 171 | 172 | # now unflatten it 173 | norms2 = Stacked.init(L, hax.nn.LayerNorm)((D, E), eps=1e-5, use_weight=True, use_bias=True) 174 | 175 | new_norms = unflatten_modules_from_export(norms_flat, norms2) 176 | 177 | assert norms == new_norms 178 | -------------------------------------------------------------------------------- /tests/test_tree_util.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import equinox as eqx 4 | import jax 5 | import jax.numpy as jnp 6 | from chex import assert_trees_all_close 7 | 8 | import haliax as hax 9 | import haliax.tree_util as htu 10 | from haliax import Axis 11 | 12 | 13 | def test_resize_axis(): 14 | A, B, C = hax.make_axes(A=10, B=20, C=30) 15 | 16 | 
class Module(eqx.Module): 17 | name1: hax.NamedArray 18 | name2: hax.NamedArray 19 | name3: hax.NamedArray 20 | 21 | module = Module( 22 | name1=hax.random.normal(jax.random.PRNGKey(0), (B, A, C)), 23 | name2=hax.zeros((B, C)), 24 | name3=hax.zeros((Axis("A", 20),)), 25 | ) 26 | 27 | NewA = A.resize(15) 28 | 29 | module2 = htu.resize_axis(module, "A", 15, key=jax.random.PRNGKey(1)) 30 | 31 | assert module2.name1.axes == (B, NewA, C) 32 | assert module2.name2.axes == (B, C) 33 | assert module2.name3.axes == (NewA,) 34 | 35 | # we don't mess with the mean or std of the original array too much 36 | assert jnp.allclose(module2.name1.mean(), module.name1.mean(), rtol=1e-1, atol=1e-2) 37 | 38 | 39 | def test_scan_aware_tree_map(): 40 | Embed = hax.Axis("embed", 10) 41 | Up = hax.Axis("up", 20) 42 | Block = hax.Axis("block", 4) 43 | 44 | class Module(eqx.Module): 45 | up: hax.nn.Linear 46 | down: hax.nn.Linear 47 | 48 | def __call__(self, x, *, key): 49 | return self.down(self.up(x), key=key) 50 | 51 | @staticmethod 52 | def init(layer_idx, *, key): 53 | k1, k2 = jax.random.split(key) 54 | up = hax.nn.Linear.init(Embed, Up, key=k1) 55 | down = hax.nn.Linear.init(Up, Embed, key=k2) 56 | 57 | up = dataclasses.replace(up, weight=up.weight + layer_idx) # type: ignore 58 | down = dataclasses.replace(down, weight=down.weight + layer_idx) # type: ignore 59 | 60 | return Module(up=up, down=down) 61 | 62 | class Model(eqx.Module): 63 | layers: hax.nn.Stacked[eqx.Module] 64 | 65 | def __call__(self, x, *, key): 66 | return self.layers.fold(x, key=jax.random.split(key, self.layers.Block.size)) 67 | 68 | @staticmethod 69 | def init(Layers, *, key): 70 | stack = hax.nn.Stacked.init(Layers, Module)( 71 | layer_idx=hax.arange(Layers), key=jax.random.split(key, Layers.size) 72 | ) 73 | return Model(layers=stack) 74 | 75 | model = Model.init(Block, key=jax.random.PRNGKey(0)) 76 | 77 | def transform_linear(x): 78 | if not isinstance(x, hax.nn.Linear): 79 | return x 80 | 81 | # do something that distinguishes doing weights jointly from independently 82 | new_weight = x.weight - hax.mean(x.weight) 83 | return dataclasses.replace(x, weight=new_weight) # type: ignore 84 | 85 | model2 = htu.scan_aware_tree_map(transform_linear, model, is_leaf=lambda x: isinstance(x, hax.nn.Linear)) 86 | model3 = htu.tree_map(transform_linear, model, is_leaf=lambda x: isinstance(x, hax.nn.Linear)) 87 | 88 | assert hax.all(model2.layers.stacked.up.weight != model3.layers.stacked.up.weight) 89 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import jax 4 | import pytest 5 | 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | def skip_if_not_enough_devices(count: int): 11 | return pytest.mark.skipif(len(jax.devices()) < count, reason=f"Not enough devices ({len(jax.devices())})") 12 | 13 | 14 | def has_torch(): 15 | try: 16 | import torch # noqa F401 17 | 18 | return True 19 | except ImportError: 20 | return False 21 | 22 | 23 | def skip_if_no_torch(f): 24 | return pytest.mark.skipif(not has_torch(), reason="torch not installed")(f) 25 | --------------------------------------------------------------------------------
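As a hedged illustration of how the helpers in tests/test_utils.py compose with pytest, the hypothetical test module below applies them to two toy tests; the `from test_utils import ...` path, the two-device requirement, and the test bodies are assumptions made for exposition, not part of the repository.

import jax
import jax.numpy as jnp

from test_utils import skip_if_no_torch, skip_if_not_enough_devices


@skip_if_not_enough_devices(2)
def test_needs_two_devices():
    # skip_if_not_enough_devices(2) returns a pytest skipif marker, so this
    # body only runs when at least two JAX devices are visible
    assert len(jax.devices()) >= 2


@skip_if_no_torch
def test_only_runs_with_torch():
    # skip_if_no_torch wraps the function in a skipif marker keyed on whether
    # `import torch` succeeds
    import torch  # noqa: F401

    assert jnp.arange(3).shape == (3,)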