├── docs ├── .htaccess ├── _static │ ├── favicon.png │ ├── mathjax.js │ └── custom_css.css ├── api │ ├── linear_solve.md │ ├── solution.md │ ├── functions.md │ ├── operators.md │ ├── solvers.md │ └── tags.md ├── _overrides │ ├── bluesky.svg │ └── partials │ │ └── source.html ├── examples │ ├── classical_solve.ipynb │ ├── complex_solve.ipynb │ ├── no_materialisation.ipynb │ ├── structured_matrices.ipynb │ └── operators.ipynb ├── faq.md └── index.md ├── tests ├── requirements.txt ├── README.md ├── __init__.py ├── conftest.py ├── __main__.py ├── test_misc.py ├── test_lsmr.py ├── test_jvp_jvp1.py ├── test_jvp_jvp2.py ├── test_transpose.py ├── test_well_posed.py ├── test_norm.py ├── test_adjoint.py ├── test_vmap.py ├── test_jvp.py ├── test_vmap_vmap.py ├── test_vmap_jvp.py ├── test_solve.py └── test_singular.py ├── .gitignore ├── .github └── workflows │ ├── release.yml │ ├── build_docs.yml │ └── run_tests.yml ├── lineax ├── _custom_types.py ├── _solver │ ├── __init__.py │ ├── lu.py │ ├── cholesky.py │ ├── tridiagonal.py │ ├── svd.py │ ├── triangular.py │ ├── qr.py │ ├── diagonal.py │ ├── misc.py │ └── bicgstab.py ├── internal │ └── __init__.py ├── _tags.py ├── __init__.py ├── _solution.py ├── _misc.py └── _norm.py ├── .pre-commit-config.yaml ├── benchmarks ├── lstsq_gradients.py ├── gmres_fails_safely.py └── solver_speeds.py ├── CONTRIBUTING.md ├── pyproject.toml ├── README.md ├── mkdocs.yml └── LICENSE /docs/.htaccess: -------------------------------------------------------------------------------- 1 | ErrorDocument 404 /lineax/404.html 2 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | beartype 2 | equinox 3 | pytest 4 | pytest-xdist 5 | jaxlib 6 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrick-kidger/lineax/HEAD/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/api/linear_solve.md: -------------------------------------------------------------------------------- 1 | # linear_solve 2 | 3 | This is the main entry point. 4 | 5 | ::: lineax.linear_solve 6 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | Each file is run separately to avoid JAX out-of-memory'ing. 2 | 3 | As such, run tests using `python -m tests`, *not* by just running `pytest`. 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/.ipynb_checkpoints 3 | *.egg-info/ 4 | build/ 5 | dist/ 6 | site/ 7 | examples/data 8 | .all_objects.cache 9 | .pymon 10 | .idea 11 | .venv 12 | -------------------------------------------------------------------------------- /docs/api/solution.md: -------------------------------------------------------------------------------- 1 | # Solution 2 | 3 | ::: lineax.Solution 4 | options: 5 | members: [] 6 | 7 | --- 8 | 9 | ::: lineax.RESULTS 10 | options: 11 | members: [] 12 | -------------------------------------------------------------------------------- /docs/_static/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /docs/_overrides/bluesky.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Release 13 | uses: patrick-kidger/action_update_python_project@v6 14 | with: 15 | python-version: "3.11" 16 | test-script: | 17 | cp -r ${{ github.workspace }}/tests ./tests 18 | cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml 19 | python -m pip install -r ./tests/requirements.txt 20 | python -m tests 21 | pypi-token: ${{ secrets.pypi_token }} 22 | github-user: patrick-kidger 23 | github-token: ${{ github.token }} 24 | -------------------------------------------------------------------------------- /lineax/_custom_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google 
LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any 16 | 17 | import equinox.internal as eqxi 18 | 19 | 20 | sentinel: Any = eqxi.doc_repr(object(), "sentinel") 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import equinox.internal as eqxi 16 | import jax 17 | import pytest 18 | 19 | 20 | jax.config.update("jax_enable_x64", True) 21 | jax.config.update("jax_numpy_dtype_promotion", "strict") 22 | jax.config.update("jax_numpy_rank_promotion", "raise") 23 | 24 | 25 | @pytest.fixture 26 | def getkey(): 27 | return eqxi.GetKey() 28 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | strategy: 11 | matrix: 12 | python-version: [ 3.11 ] 13 | os: [ ubuntu-latest ] 14 | runs-on: ${{ matrix.os }} 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | python -m pip install '.[docs]' 28 | 29 | - name: Build docs 30 | run: | 31 | mkdocs build 32 | 33 | - name: Upload docs 34 | uses: actions/upload-artifact@v4 35 | with: 36 | name: docs 37 | path: site # where `mkdocs build` puts the built site 38 | -------------------------------------------------------------------------------- /docs/_overrides/partials/source.html: -------------------------------------------------------------------------------- 1 | {% import "partials/language.html" as lang with context %} 2 | 3 | 4 | {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} 5 | {% include ".icons/" ~ icon ~ ".svg" %} 6 | 7 | 8 | {{ config.repo_name }} 9 | 10 | 11 | 12 | 13 | {% include ".icons/fontawesome/brands/twitter.svg" %} 14 | 15 | 16 | 17 | 18 | {% include "bluesky.svg" %} 19 | 20 | 21 | {{ config.theme.twitter_bluesky_name }} 22 | 23 | 24 | 
-------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pathlib 16 | import subprocess 17 | import sys 18 | 19 | 20 | here = pathlib.Path(__file__).resolve().parent 21 | 22 | 23 | # Each file is run separately to avoid JAX out-of-memory'ing. 24 | running_out = 0 25 | for file in sorted(here.iterdir()):  # sorted for a deterministic run order 26 | if file.is_file() and file.name.startswith("test"): 27 | out = subprocess.run([sys.executable, "-m", "pytest", str(file)]).returncode  # no shell: safe with spaces in paths; same interpreter as `python -m tests` 28 | running_out = max(running_out, out) 29 | sys.exit(running_out) 30 | -------------------------------------------------------------------------------- /lineax/_solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .bicgstab import BiCGStab as BiCGStab 16 | from .cg import CG as CG, NormalCG as NormalCG 17 | from .cholesky import Cholesky as Cholesky 18 | from .diagonal import Diagonal as Diagonal 19 | from .gmres import GMRES as GMRES 20 | from .lsmr import LSMR as LSMR 21 | from .lu import LU as LU 22 | from .qr import QR as QR 23 | from .svd import SVD as SVD 24 | from .triangular import Triangular as Triangular 25 | from .tridiagonal import Tridiagonal as Tridiagonal 26 | -------------------------------------------------------------------------------- /docs/api/functions.md: -------------------------------------------------------------------------------- 1 | # Functions on linear operators 2 | 3 | We define a number of functions on [linear operators](./operators.md). 4 | 5 | ## Computational changes 6 | 7 | These do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.) 8 | 9 | ::: lineax.linearise 10 | 11 | --- 12 | 13 | ::: lineax.materialise 14 | 15 | ## Extract information from the operator 16 | 17 | ::: lineax.diagonal 18 | 19 | --- 20 | 21 | ::: lineax.tridiagonal 22 | 23 | ## Test the operator to see if it exhibits a certain property 24 | 25 | Note that these do *not* inspect the values of the operator -- instead, they typically use [tags](./tags.md). (Or in some cases, just the type of the operator: e.g. `is_diagonal(DiagonalLinearOperator(...)) == True`.) 
26 | 27 | ::: lineax.has_unit_diagonal 28 | 29 | --- 30 | 31 | ::: lineax.is_diagonal 32 | 33 | --- 34 | 35 | ::: lineax.is_tridiagonal 36 | 37 | --- 38 | 39 | ::: lineax.is_lower_triangular 40 | 41 | --- 42 | 43 | ::: lineax.is_upper_triangular 44 | 45 | --- 46 | 47 | ::: lineax.is_positive_semidefinite 48 | 49 | --- 50 | 51 | ::: lineax.is_negative_semidefinite 52 | 53 | --- 54 | 55 | ::: lineax.is_symmetric 56 | -------------------------------------------------------------------------------- /lineax/internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from .._misc import ( 17 | complex_to_real_dtype as complex_to_real_dtype, 18 | default_floating_dtype as default_floating_dtype, 19 | ) 20 | from .._norm import ( 21 | max_norm as max_norm, 22 | rms_norm as rms_norm, 23 | sum_squares as sum_squares, 24 | tree_dot as tree_dot, 25 | two_norm as two_norm, 26 | ) 27 | from .._solve import linear_solve_p as linear_solve_p 28 | from .._solver.misc import ( 29 | pack_structures as pack_structures, 30 | PackedStructures as PackedStructures, 31 | ravel_vector as ravel_vector, 32 | transpose_packed_structures as transpose_packed_structures, 33 | unravel_solution as unravel_solution, 34 | ) 35 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | repos: 16 | - repo: local 17 | hooks: 18 | - id: sort_pyproject 19 | name: sort_pyproject 20 | entry: toml-sort -i --sort-table-keys --sort-inline-tables 21 | language: python 22 | files: ^pyproject\.toml$ 23 | additional_dependencies: ["toml-sort==0.23.1"] 24 | - repo: https://github.com/astral-sh/ruff-pre-commit 25 | rev: v0.1.7 26 | hooks: 27 | - id: ruff-format 28 | types_or: [ python, pyi, jupyter ] 29 | - id: ruff 30 | types_or: [ python, pyi, jupyter ] 31 | args: [ --fix ] 32 | - repo: https://github.com/RobertCraigie/pyright-python 33 | rev: v1.1.330 34 | hooks: 35 | - id: pyright 36 | additional_dependencies: ["jax", "equinox", "pytest"] 37 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import lineax as lx 18 | import lineax._misc as lx_misc 19 | import pytest 20 | 21 | 22 | def test_inexact_asarray_no_copy(): 23 | x = jnp.array([1.0]) 24 | assert lx_misc.inexact_asarray(x) is x 25 | y = jnp.array([1.0, 2.0]) 26 | assert jax.vmap(lx_misc.inexact_asarray)(y) is y 27 | 28 | 29 | # See JAX issue #15676 30 | def test_inexact_asarray_jvp(): 31 | p, t = jax.jvp(lx_misc.inexact_asarray, (1.0,), (2.0,)) 32 | assert type(p) is not float 33 | assert type(t) is not float 34 | 35 | 36 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 37 | def test_zero_matrix(dtype): 38 | A = lx.MatrixLinearOperator(jnp.zeros((2, 2), dtype=dtype)) 39 | b = jnp.array([1.0, 2.0], dtype=dtype) 40 | lx.linear_solve(A, b, lx.SVD()) 41 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: Run tests 16 | 17 | on: 18 | pull_request: 19 | 20 | jobs: 21 | run-test: 22 | strategy: 23 | matrix: 24 | python-version: [ "3.10", "3.12" ] 25 | os: [ ubuntu-latest ] 26 | fail-fast: false 27 | runs-on: ${{ matrix.os }} 28 | steps: 29 | - name: Checkout code 30 | uses: actions/checkout@v2 31 | 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | python -m pip install -r ./tests/requirements.txt 41 | 42 | - name: Checks with pre-commit 43 | uses: pre-commit/action@v3.0.1 44 | 45 | - name: Test with pytest 46 | run: | 47 | python -m pip install . 48 | python -m tests 49 | -------------------------------------------------------------------------------- /tests/test_lsmr.py: -------------------------------------------------------------------------------- 1 | import equinox as ex 2 | import jax.numpy as jnp 3 | import lineax as lx 4 | import pytest 5 | 6 | 7 | solver = lx.LSMR(1e-10, 1e-10) 8 | Aill = lx.DiagonalLinearOperator(jnp.array([1e8, 1e6, 1e4, 1e2, 1])) 9 | Awell = lx.DiagonalLinearOperator(jnp.array([2.0, 4.0, 5.0, 8.0, 10.0])) 10 | 11 | 12 | def test_ill_conditioned(): 13 | # pytest.raises fails the test when no error is raised; the previous bare try/except passed silently in that case. 14 | with pytest.raises(ex.EquinoxRuntimeError, match="Condition number"): 15 | lx.linear_solve(Aill, jnp.ones(5), solver=solver) 16 | 17 | 18 | 19 | @pytest.mark.skip("Damp support is disabled.") 20 | def test_damp_regularizes(): 21 | solution_ill = lx.linear_solve(Aill, jnp.ones(5), solver=solver, options={}) 22 | assert solution_ill.stats["istop"] == 1 23 | 24 | solution_damped = lx.linear_solve( 25 | Aill, jnp.ones(5), solver=solver, options={"damp": 100.0} 26 | ) 27 | assert solution_damped.stats["istop"] == 2 28 | 29 | assert solution_damped.stats["num_steps"] < solution_ill.stats["num_steps"] 30 | 31 | 32 | @pytest.mark.skip("Damp support is disabled.") 33 | def
test_damp(): 34 | solution_damped = lx.linear_solve( 35 | Awell, jnp.ones(5), solver=solver, options={"damp": 1.0} 36 | ) 37 | assert jnp.allclose( 38 | solution_damped.value, 39 | jnp.array([0.4, 0.23529412, 0.19230769, 0.12307692, 0.0990099]), 40 | ) 41 | solution_damped = lx.linear_solve( 42 | Awell, jnp.ones(5), solver=solver, options={"damp": 1000.0} 43 | ) 44 | assert jnp.allclose( 45 | solution_damped.value, jnp.array([2e-6, 4e-6, 5e-6, 8e-6, 10.0e-6]) 46 | ) 47 | -------------------------------------------------------------------------------- /benchmarks/lstsq_gradients.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Core JAX has some numerical issues with their lstsq gradients. 16 | # See https://github.com/google/jax/issues/14868 17 | # This demonstrates that we don't have the same issue! 
18 | 19 | import sys 20 | 21 | import jax 22 | import jax.numpy as jnp 23 | import lineax as lx 24 | 25 | 26 | sys.path.append("../tests") 27 | from helpers import finite_difference_jvp # pyright: ignore 28 | 29 | 30 | a_primal = (jnp.eye(3),) 31 | a_tangent = (jnp.zeros((3, 3)),) 32 | 33 | 34 | def jax_solve(a): 35 | sol, _, _, _ = jnp.linalg.lstsq(a, jnp.arange(3)) # pyright: ignore 36 | return sol 37 | 38 | 39 | def lx_solve(a): 40 | op = lx.MatrixLinearOperator(a) 41 | return lx.linear_solve(op, jnp.arange(3)).value 42 | 43 | 44 | _, true_jvp = finite_difference_jvp(jax_solve, a_primal, a_tangent) 45 | _, jax_jvp = jax.jvp(jax_solve, a_primal, a_tangent) 46 | _, lx_jvp = jax.jvp(lx_solve, a_primal, a_tangent) 47 | assert jnp.isnan(jax_jvp).all() 48 | assert jnp.allclose(true_jvp, lx_jvp) 49 | -------------------------------------------------------------------------------- /tests/test_jvp_jvp1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import equinox as eqx 16 | import jax.numpy as jnp 17 | import pytest 18 | 19 | from .helpers import ( 20 | construct_matrix, 21 | construct_singular_matrix, 22 | jvp_jvp_impl, 23 | make_jac_operator, 24 | make_matrix_operator, 25 | solvers_tags_pseudoinverse, 26 | ) 27 | 28 | 29 | # Workaround for https://github.com/jax-ml/jax/issues/27201 30 | @pytest.fixture(autouse=True) 31 | def _clear_cache(): 32 | eqx.clear_caches() 33 | 34 | 35 | @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) 36 | @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) 37 | @pytest.mark.parametrize("use_state", (True, False)) 38 | @pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix)) 39 | @pytest.mark.parametrize("dtype", (jnp.float64,)) 40 | def test_jvp_jvp( 41 | getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype 42 | ): 43 | jvp_jvp_impl( 44 | getkey, 45 | solver, 46 | tags, 47 | pseudoinverse, 48 | make_operator, 49 | use_state, 50 | make_matrix, 51 | dtype, 52 | ) 53 | -------------------------------------------------------------------------------- /tests/test_jvp_jvp2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import equinox as eqx 16 | import jax.numpy as jnp 17 | import pytest 18 | 19 | from .helpers import ( 20 | construct_matrix, 21 | construct_singular_matrix, 22 | jvp_jvp_impl, 23 | make_jac_operator, 24 | make_matrix_operator, 25 | solvers_tags_pseudoinverse, 26 | ) 27 | 28 | 29 | # Workaround for https://github.com/jax-ml/jax/issues/27201 30 | @pytest.fixture(autouse=True) 31 | def _clear_cache(): 32 | eqx.clear_caches() 33 | 34 | 35 | @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) 36 | @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) 37 | @pytest.mark.parametrize("use_state", (True, False)) 38 | @pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix)) 39 | @pytest.mark.parametrize("dtype", (jnp.complex128,)) 40 | def test_jvp_jvp( 41 | getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype 42 | ): 43 | jvp_jvp_impl( 44 | getkey, 45 | solver, 46 | tags, 47 | pseudoinverse, 48 | make_operator, 49 | use_state, 50 | make_matrix, 51 | dtype, 52 | ) 53 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions (pull requests) are very welcome! Here's how to get started. 4 | 5 | --- 6 | 7 | **Getting started** 8 | 9 | First fork the library on GitHub. 10 | 11 | Then clone and install the library in development mode: 12 | 13 | ```bash 14 | git clone https://github.com/your-username-here/lineax.git 15 | cd lineax 16 | pip install -e . 17 | ``` 18 | 19 | Then install the pre-commit hook: 20 | 21 | ```bash 22 | pip install pre-commit 23 | pre-commit install 24 | ``` 25 | 26 | These hooks use Black to format the code, and ruff to lint it. 27 | 28 | --- 29 | 30 | **If you're making changes to the code:** 31 | 32 | Now make your changes. 
Make sure to include additional tests if necessary. 33 | 34 | Next verify the tests all pass: 35 | 36 | ```bash 37 | pip install -r tests/requirements.txt 38 | python -m tests 39 | ``` 40 | 41 | Then push your changes back to your fork of the repository: 42 | 43 | ```bash 44 | git push 45 | ``` 46 | 47 | Finally, open a pull request on GitHub! 48 | 49 | --- 50 | 51 | **If you're making changes to the documentation:** 52 | 53 | Make your changes. You can then build the documentation by doing 54 | 55 | ```bash 56 | pip install -r docs/requirements.txt 57 | mkdocs serve 58 | ``` 59 | Then doing `Control-C`, and running: 60 | ``` 61 | mkdocs serve 62 | ``` 63 | (So you run `mkdocs serve` twice.) 64 | 65 | You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. 66 | 67 | ## Contributor License Agreement 68 | 69 | Contributions to this project must be accompanied by a Contributor License 70 | Agreement (CLA). You (or your employer) retain the copyright to your 71 | contribution; this simply gives us permission to use and redistribute your 72 | contributions as part of the project. Head over to 73 | to see your current agreements on file or 74 | to sign a new one. 75 | 76 | You generally only need to submit a CLA once, so if you've already submitted one 77 | (even if it was for a different project), you probably don't need to do it 78 | again. 79 | -------------------------------------------------------------------------------- /docs/examples/classical_solve.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Classical solve\n", 9 | "\n", 10 | "We wish to solve the linear system $Ax = b$. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "cb3a7781-2358-40c4-82f3-e908bddeb578", 17 | "metadata": { 18 | "tags": [], 19 | "ExecuteTime": { 20 | "end_time": "2024-04-02T05:26:05.556701Z", 21 | "start_time": "2024-04-02T05:26:03.814599Z" 22 | } 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "A=\n", 30 | "[[-0.3721109 0.26423115 -0.18252768]\n", 31 | " [-0.7368197 0.44973662 -0.1521442 ]\n", 32 | " [-0.67135346 -0.5908641 0.73168886]]\n", 33 | "b=[ 0.17269018 -0.64765567 1.2229712 ]\n", 34 | "x=[-2.7321298 -8.52878 -7.7226872]\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "import jax.random as jr\n", 40 | "import lineax as lx\n", 41 | "\n", 42 | "\n", 43 | "matrix = jr.normal(jr.PRNGKey(0), (3, 3))\n", 44 | "vector = jr.normal(jr.PRNGKey(1), (3,))\n", 45 | "operator = lx.MatrixLinearOperator(matrix)\n", 46 | "solution = lx.linear_solve(operator, vector)\n", 47 | "print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")" 48 | ] 49 | } 50 | ], 51 | "metadata": { 52 | "kernelspec": { 53 | "display_name": "Python 3 (ipykernel)", 54 | "language": "python", 55 | "name": "python3" 56 | }, 57 | "language_info": { 58 | "codemirror_mode": { 59 | "name": "ipython", 60 | "version": 3 61 | }, 62 | "file_extension": ".py", 63 | "mimetype": "text/x-python", 64 | "name": "python", 65 | "nbconvert_exporter": "python", 66 | "pygments_lexer": "ipython3", 67 | "version": "3.9.16" 68 | } 69 | }, 70 | "nbformat": 4, 71 | "nbformat_minor": 5 72 | } 73 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ## How does this differ from `jax.numpy.solve`, `jax.scipy.{...}` etc.? 4 | 5 | Lineax offers several improvements. Most notably: 6 | 7 | - Several new solvers. For example, [`lineax.QR`][] has no counterpart in core JAX. 
(And it is much faster than `jax.numpy.linalg.lstsq`, which is the closest equivalent, and uses an SVD decomposition instead.) 8 | 9 | - Several new operators. For example, [`lineax.JacobianLinearOperator`][] has no counterpart in core JAX. 10 | 11 | - A consistent API. The built-in JAX operations all differ from each other slightly, and are split across `jax.numpy`, `jax.scipy`, and `jax.scipy.sparse`. 12 | 13 | - Numerically stable gradients. The existing JAX implementations will sometimes return `NaN`s! 14 | 15 | - Some faster compile times and run times in a few places. 16 | 17 | Most of these are because JAX aims to mimic the existing NumPy/SciPy APIs. (I.e. it's not JAX's fault that it doesn't take the approach that Lineax does!) 18 | 19 | ## How do I represent a {lower, upper} triangular matrix? 20 | 21 | Typically: create a full matrix, with the {lower, upper} part containing your values, and the converse {upper, lower} part containing all zeros. Then use, e.g., `operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag)`. 22 | 23 | This is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model. 24 | 25 | ## What about other operations from linear algebra? (Determinants, eigenvalues, etc.) 26 | 27 | See [`jax.numpy.linalg`](https://jax.readthedocs.io/en/latest/jax.numpy.html#module-jax.numpy.linalg) and [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg). 28 | 29 | ## How do I solve multiple systems of equations (i.e. `AX = B`)? 30 | 31 | Solvers implemented in Lineax target single systems of linear equations (i.e., `Ax = b`), however, using `jax.vmap` or `equinox.filter_vmap`, it can solve multiple systems with minimal effort. 
32 | 33 | ```python 34 | multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1)) 35 | # or 36 | multi_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1)) 37 | ``` 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | authors = [ 7 | {email = "raderjason@outlook.com", name = "Jason Rader"}, 8 | {email = "contact@kidger.site", name = "Patrick Kidger"} 9 | ] 10 | classifiers = [ 11 | "Development Status :: 3 - Alpha", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Financial and Insurance Industry", 14 | "Intended Audience :: Information Technology", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Natural Language :: English", 18 | "Programming Language :: Python :: 3", 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 20 | "Topic :: Scientific/Engineering :: Information Analysis", 21 | "Topic :: Scientific/Engineering :: Mathematics" 22 | ] 23 | dependencies = ["jax>=0.6.1", "jaxtyping>=0.2.24", "equinox>=0.11.10", "typing_extensions>=4.5.0"] 24 | description = "Linear solvers in JAX and Equinox." 
25 | keywords = ["jax", "neural-networks", "deep-learning", "equinox", "linear-solvers", "least-squares", "numerical-methods"] 26 | license = {file = "LICENSE"} 27 | name = "lineax" 28 | readme = "README.md" 29 | requires-python = "~=3.10" 30 | urls = {repository = "https://github.com/google/lineax"} 31 | version = "0.0.8" 32 | 33 | [project.optional-dependencies] 34 | docs = [ 35 | "hippogriffe==0.2.0", 36 | "mkdocs==1.6.1", 37 | "mkdocs-include-exclude-files==0.1.0", 38 | "mkdocs-ipynb==0.1.0", 39 | "mkdocs-material==9.6.7", 40 | "mkdocstrings[python]==0.28.3", 41 | "pymdown-extensions==10.14.3" 42 | ] 43 | 44 | [tool.hatch.build] 45 | include = ["lineax/*"] 46 | 47 | [tool.pyright] 48 | include = ["lineax", "tests"] 49 | reportIncompatibleMethodOverride = true 50 | 51 | [tool.pytest.ini_options] 52 | addopts = "--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" 53 | 54 | [tool.ruff] 55 | extend-include = ["*.ipynb"] 56 | fixable = ["I001", "F401", "UP"] 57 | ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] 58 | ignore-init-module-imports = true 59 | select = ["E", "F", "I001", "UP"] 60 | src = [] 61 | 62 | [tool.ruff.isort] 63 | combine-as-imports = true 64 | extra-standard-library = ["typing_extensions"] 65 | lines-after-imports = 2 66 | order-by-type = false 67 | -------------------------------------------------------------------------------- /docs/api/operators.md: -------------------------------------------------------------------------------- 1 | # Linear operators 2 | 3 | We often talk about solving a linear system $Ax = b$, where $A \in \mathbb{R}^{n \times m}$ is a matrix, $b \in \mathbb{R}^n$ is a vector, and $x \in \mathbb{R}^m$ is our desired solution. 4 | 5 | The linear operators described on this page are ways of describing the matrix $A$. The simplest is [`lineax.MatrixLinearOperator`][], which simply holds the matrix $A$ directly. 
6 | 7 | Meanwhile if $A$ is diagonal, then there is also [`lineax.DiagonalLinearOperator`][]: for efficiency this only stores the diagonal of $A$. 8 | 9 | Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such that $F(x) = Ax$. Whilst we could use $F$ to materialise the whole matrix $A$ and then store it in a [`lineax.MatrixLinearOperator`][], that may be very memory intensive. Instead, we may prefer to use [`lineax.FunctionLinearOperator`][]. Many linear solvers (e.g. [`lineax.CG`][]) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $A$. 10 | 11 | ??? abstract "`lineax.AbstractLinearOperator`" 12 | 13 | ::: lineax.AbstractLinearOperator 14 | options: 15 | members: 16 | - mv 17 | - as_matrix 18 | - transpose 19 | - in_structure 20 | - out_structure 21 | - in_size 22 | - out_size 23 | 24 | ::: lineax.MatrixLinearOperator 25 | options: 26 | members: 27 | - __init__ 28 | 29 | --- 30 | 31 | ::: lineax.DiagonalLinearOperator 32 | options: 33 | members: 34 | - __init__ 35 | 36 | --- 37 | 38 | ::: lineax.TridiagonalLinearOperator 39 | options: 40 | members: 41 | - __init__ 42 | 43 | --- 44 | 45 | ::: lineax.PyTreeLinearOperator 46 | options: 47 | members: 48 | - __init__ 49 | 50 | --- 51 | 52 | ::: lineax.JacobianLinearOperator 53 | options: 54 | members: 55 | - __init__ 56 | 57 | --- 58 | 59 | ::: lineax.FunctionLinearOperator 60 | options: 61 | members: 62 | - __init__ 63 | 64 | --- 65 | 66 | ::: lineax.IdentityLinearOperator 67 | options: 68 | members: 69 | - __init__ 70 | 71 | --- 72 | 73 | ::: lineax.TaggedLinearOperator 74 | options: 75 | members: 76 | - __init__ 77 | -------------------------------------------------------------------------------- /docs/examples/complex_solve.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1", 6 | "metadata": {}, 7 
| "source": [ 8 | "# Complex solve\n", 9 | "\n", 10 | "We can also solve a system with complex entries. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "cb3a7781-2358-40c4-82f3-e908bddeb578", 17 | "metadata": { 18 | "tags": [], 19 | "ExecuteTime": { 20 | "end_time": "2024-04-02T05:29:04.909894Z", 21 | "start_time": "2024-04-02T05:29:04.103141Z" 22 | } 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "A=\n", 30 | "[[-1.8459436 -0.2744466j 0.02393756-0.03172905j 0.76815367-1.4444253j ]\n", 31 | " [-1.0467293 +0.05608991j 1.0891742 -0.03264743j 0.7513123 +0.56285536j]\n", 32 | " [ 0.38307396-1.0190808j 0.01203694-1.1971304j 0.19252291-0.26424018j]]\n", 33 | "b=[0.23162952+0.3614433j 0.05800135+1.6094692j 0.8979094 +0.16941352j]\n", 34 | "x=[-0.07652722-0.34397143j -0.22629777+1.0359733j 0.22135164-0.00880566j]\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "import jax.numpy as jnp\n", 40 | "import jax.random as jr\n", 41 | "import lineax as lx\n", 42 | "\n", 43 | "\n", 44 | "matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\n", 45 | "vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\n", 46 | "operator = lx.MatrixLinearOperator(matrix)\n", 47 | "solution = lx.linear_solve(operator, vector)\n", 48 | "print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")" 49 | ] 50 | } 51 | ], 52 | "metadata": { 53 | "kernelspec": { 54 | "display_name": "Python 3 (ipykernel)", 55 | "language": "python", 56 | "name": "python3" 57 | }, 58 | "language_info": { 59 | "codemirror_mode": { 60 | "name": "ipython", 61 | "version": 3 62 | }, 63 | "file_extension": ".py", 64 | "mimetype": "text/x-python", 65 | "name": "python", 66 | "nbconvert_exporter": "python", 67 | "pygments_lexer": "ipython3", 68 | "version": "3.9.16" 69 | } 70 | }, 71 | "nbformat": 4, 72 
| "nbformat_minor": 5 73 | } 74 | -------------------------------------------------------------------------------- /lineax/_tags.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class _HasRepr: 17 | def __init__(self, string: str): 18 | self.string = string 19 | 20 | def __repr__(self): 21 | return self.string 22 | 23 | 24 | symmetric_tag = _HasRepr("symmetric_tag") 25 | diagonal_tag = _HasRepr("diagonal_tag") 26 | tridiagonal_tag = _HasRepr("tridiagonal_tag") 27 | unit_diagonal_tag = _HasRepr("unit_diagonal_tag") 28 | lower_triangular_tag = _HasRepr("lower_triangular_tag") 29 | upper_triangular_tag = _HasRepr("upper_triangular_tag") 30 | positive_semidefinite_tag = _HasRepr("positive_semidefinite_tag") 31 | negative_semidefinite_tag = _HasRepr("negative_semidefinite_tag") 32 | 33 | 34 | transpose_tags_rules = [] 35 | 36 | 37 | for tag in ( 38 | symmetric_tag, 39 | unit_diagonal_tag, 40 | diagonal_tag, 41 | positive_semidefinite_tag, 42 | negative_semidefinite_tag, 43 | tridiagonal_tag, 44 | ): 45 | 46 | @transpose_tags_rules.append 47 | def _(tags: frozenset[object], tag=tag): 48 | if tag in tags: 49 | return tag 50 | 51 | 52 | @transpose_tags_rules.append 53 | def _(tags: frozenset[object]): 54 | if lower_triangular_tag in tags: 55 | return upper_triangular_tag 56 | 57 | 58 | 
@transpose_tags_rules.append 59 | def _(tags: frozenset[object]): 60 | if upper_triangular_tag in tags: 61 | return lower_triangular_tag 62 | 63 | 64 | def transpose_tags(tags: frozenset[object]): 65 | """Lineax uses "tags" to declare that a particular linear operator exhibits some 66 | property, e.g. symmetry. 67 | 68 | This function takes in a collection of tags representing a linear operator, and 69 | returns a collection of tags that should be associated with the transpose of that 70 | linear operator. 71 | 72 | **Arguments:** 73 | 74 | - `tags`: a `frozenset` of tags. 75 | 76 | **Returns:** 77 | 78 | A `frozenset` of tags. 79 | """ 80 | if symmetric_tag in tags: 81 | return tags 82 | new_tags = [] 83 | for rule in transpose_tags_rules: 84 | out = rule(tags) 85 | if out is not None: 86 | new_tags.append(out) 87 | return frozenset(new_tags) 88 | -------------------------------------------------------------------------------- /tests/test_transpose.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import equinox as eqx 16 | import jax 17 | import jax.numpy as jnp 18 | import jax.random as jr 19 | import lineax as lx 20 | import pytest 21 | 22 | from .helpers import construct_matrix, params, tree_allclose 23 | 24 | 25 | class TestTranspose: 26 | @pytest.fixture(scope="class") 27 | def assert_transpose_fixture(_): 28 | @eqx.filter_jit 29 | def solve_transpose(operator, out_vec, in_vec, solver): 30 | return jax.linear_transpose( 31 | lambda v: lx.linear_solve(operator, v, solver).value, out_vec 32 | )(in_vec) 33 | 34 | def assert_transpose(operator, out_vec, in_vec, solver): 35 | (out,) = solve_transpose(operator, out_vec, in_vec, solver) 36 | true_out = lx.linear_solve(operator.T, in_vec, solver).value 37 | assert tree_allclose(out, true_out) 38 | 39 | return assert_transpose 40 | 41 | @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) 42 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 43 | def test_transpose( 44 | _, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey 45 | ): 46 | (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) 47 | operator = make_operator(getkey, matrix, tags) 48 | out_size, in_size = matrix.shape 49 | out_vec = jr.normal(getkey(), (out_size,), dtype=dtype) 50 | in_vec = jr.normal(getkey(), (in_size,), dtype=dtype) 51 | solver = lx.AutoLinearSolver(well_posed=True) 52 | assert_transpose_fixture(operator, out_vec, in_vec, solver) 53 | 54 | def test_pytree_transpose(_, assert_transpose_fixture): # pyright: ignore 55 | a = jnp.array 56 | pytree = [[a(1), a(2), a(3)], [a(4), a(5), a(6)]] 57 | output_structure = jax.eval_shape(lambda: [1, 2]) 58 | operator = lx.PyTreeLinearOperator(pytree, output_structure) 59 | out_vec = [a(1.0), a(2.0)] 60 | in_vec = [a(1.0), 2.0, 3.0] 61 | solver = lx.AutoLinearSolver(well_posed=False) 62 | assert_transpose_fixture(operator, out_vec, in_vec, solver) 63 | 
-------------------------------------------------------------------------------- /benchmarks/gmres_fails_safely.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools as ft 16 | 17 | import equinox as eqx 18 | import equinox.internal as eqxi 19 | import jax 20 | import jax.numpy as jnp 21 | import jax.random as jr 22 | import jax.scipy as jsp 23 | import lineax as lx 24 | 25 | 26 | getkey = eqxi.GetKey() 27 | 28 | 29 | def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8): 30 | return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) 31 | 32 | 33 | jax.config.update("jax_enable_x64", True) 34 | 35 | 36 | def make_problem(mat_size: int, *, key): 37 | mat = jr.normal(key, (mat_size, mat_size)) 38 | true_x = jr.normal(key, (mat_size,)) 39 | b = mat @ true_x 40 | op = lx.MatrixLinearOperator(mat) 41 | return mat, op, b, true_x 42 | 43 | 44 | def benchmark_jax(mat_size: int, *, key): 45 | mat, _, b, true_x = make_problem(mat_size, key=key) 46 | 47 | solve_with_jax = ft.partial( 48 | jsp.sparse.linalg.gmres, tol=1e-5, solve_method="batched" 49 | ) 50 | gmres_jit = jax.jit(solve_with_jax) 51 | jax_soln, info = gmres_jit(mat, b) 52 | 53 | # info == 0.0 implies that the solve has succeeded. 
54 | returned_failed = jnp.all(info != 0.0) 55 | actually_failed = not tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4) 56 | 57 | assert actually_failed 58 | 59 | captured_failure = returned_failed & actually_failed 60 | return captured_failure 61 | 62 | 63 | def benchmark_lx(mat_size: int, *, key): 64 | _, op, b, true_x = make_problem(mat_size, key=key) 65 | 66 | lx_soln = lx.linear_solve(op, b, lx.GMRES(atol=1e-5, rtol=1e-5), throw=False) 67 | 68 | returned_failed = jnp.all(lx_soln.result != lx.RESULTS.successful) 69 | actually_failed = not tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4) 70 | 71 | assert actually_failed 72 | 73 | captured_failure = returned_failed & actually_failed 74 | return captured_failure 75 | 76 | 77 | lx_failed_safely = 0 78 | jax_failed_safely = 0 79 | 80 | for _ in range(100): 81 | key = getkey() 82 | jax_captured_failure = benchmark_jax(100, key=key) 83 | lx_captured_failure = benchmark_lx(100, key=key) 84 | 85 | jax_failed_safely = jax_failed_safely + jax_captured_failure 86 | lx_failed_safely = lx_failed_safely + lx_captured_failure 87 | 88 | print(f"JAX failed safely {jax_failed_safely} out of 100 times") 89 | print(f"Lineax failed safely {lx_failed_safely} out of 100 times") 90 | -------------------------------------------------------------------------------- /docs/api/solvers.md: -------------------------------------------------------------------------------- 1 | # Solvers 2 | 3 | If you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the [tags](./tags.md) page.) 4 | 5 | ??? 
abstract "`lineax.AbstractLinearSolver`" 6 | 7 | ::: lineax.AbstractLinearSolver 8 | options: 9 | members: 10 | - init 11 | - compute 12 | - allow_dependent_columns 13 | - allow_dependent_rows 14 | - transpose 15 | 16 | ::: lineax.AutoLinearSolver 17 | options: 18 | members: 19 | - __init__ 20 | - select_solver 21 | 22 | --- 23 | 24 | ::: lineax.LU 25 | options: 26 | members: 27 | - __init__ 28 | 29 | ## Least squares solvers 30 | 31 | These are capable of solving ill-posed linear problems. 32 | 33 | ::: lineax.QR 34 | options: 35 | members: 36 | - __init__ 37 | 38 | --- 39 | 40 | ::: lineax.SVD 41 | options: 42 | members: 43 | - __init__ 44 | 45 | !!! info 46 | 47 | In addition to these, `lineax.Diagonal(well_posed=False)` and [`lineax.NormalCG`][] (below) also support ill-posed problems. 48 | 49 | ## Structure-exploiting solvers 50 | 51 | These require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently. 52 | 53 | ::: lineax.Cholesky 54 | options: 55 | members: 56 | - __init__ 57 | 58 | --- 59 | 60 | ::: lineax.Diagonal 61 | options: 62 | members: 63 | - __init__ 64 | 65 | --- 66 | 67 | ::: lineax.Triangular 68 | options: 69 | members: 70 | - __init__ 71 | 72 | --- 73 | 74 | ::: lineax.Tridiagonal 75 | options: 76 | members: 77 | - __init__ 78 | 79 | !!! info 80 | 81 | In addition to these, [`lineax.CG`][] also requires special structure (positive or negative definiteness). 82 | 83 | ## Iterative solvers 84 | 85 | These solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. [`lineax.JacobianLinearOperator`][] or [`lineax.FunctionLinearOperator`][], which only provide matrix-vector products. 86 | 87 | !!! warning 88 | 89 | Note that [`lineax.BiCGStab`][] and [`lineax.GMRES`][] may fail to converge on some (typically non-sparse) problems. 
90 | 91 | ::: lineax.CG 92 | options: 93 | members: 94 | - __init__ 95 | 96 | --- 97 | 98 | ::: lineax.NormalCG 99 | options: 100 | members: 101 | - __init__ 102 | 103 | --- 104 | 105 | ::: lineax.BiCGStab 106 | options: 107 | members: 108 | - __init__ 109 | 110 | --- 111 | 112 | ::: lineax.GMRES 113 | options: 114 | members: 115 | - __init__ 116 | -------------------------------------------------------------------------------- /tests/test_well_posed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import jax.random as jr 18 | import lineax as lx 19 | import pytest 20 | 21 | from .helpers import ( 22 | construct_matrix, 23 | ops, 24 | params, 25 | solvers, 26 | tree_allclose, 27 | ) 28 | 29 | 30 | @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) 31 | @pytest.mark.parametrize("ops", ops) 32 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 33 | def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): 34 | if jax.config.jax_enable_x64: # pyright: ignore 35 | tol = 1e-10 36 | else: 37 | tol = 1e-4 38 | (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) 39 | operator = make_operator(getkey, matrix, tags) 40 | operator, matrix = ops(operator, matrix) 41 | assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) 42 | out_size, _ = matrix.shape 43 | true_x = jr.normal(getkey(), (out_size,), dtype=dtype) 44 | b = matrix @ true_x 45 | x = lx.linear_solve(operator, b, solver=solver).value 46 | jax_x = jnp.linalg.solve(matrix, b) # pyright: ignore 47 | assert tree_allclose(x, true_x, atol=tol, rtol=tol) 48 | assert tree_allclose(x, jax_x, atol=tol, rtol=tol) 49 | 50 | 51 | @pytest.mark.parametrize("solver", solvers) 52 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 53 | def test_pytree_wellposed(solver, getkey, dtype): 54 | if not isinstance( 55 | solver, 56 | (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG, lx.NormalCG), 57 | ): 58 | if jax.config.jax_enable_x64: # pyright: ignore 59 | tol = 1e-10 60 | else: 61 | tol = 1e-4 62 | 63 | true_x = [ 64 | jr.normal(getkey(), shape=(2, 4), dtype=dtype), 65 | jr.normal(getkey(), (3,), dtype=dtype), 66 | ] 67 | pytree = [ 68 | [ 69 | jr.normal(getkey(), shape=(2, 4, 2, 4), dtype=dtype), 70 | jr.normal(getkey(), shape=(2, 4, 3), dtype=dtype), 71 | ], 72 | [ 73 | jr.normal(getkey(), shape=(3, 2, 4), dtype=dtype), 74 | jr.normal(getkey(), shape=(3, 
3), dtype=dtype), 75 | ], 76 | ] 77 | out_structure = jax.eval_shape(lambda: true_x) 78 | 79 | operator = lx.PyTreeLinearOperator(pytree, out_structure) 80 | b = operator.mv(true_x) 81 | lx_x = lx.linear_solve(operator, b, solver, throw=False) 82 | assert tree_allclose(lx_x.value, true_x, atol=tol, rtol=tol) 83 | -------------------------------------------------------------------------------- /lineax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib.metadata 16 | 17 | from . 
import internal as internal 18 | from ._operator import ( 19 | AbstractLinearOperator as AbstractLinearOperator, 20 | AddLinearOperator as AddLinearOperator, 21 | AuxLinearOperator as AuxLinearOperator, 22 | ComposedLinearOperator as ComposedLinearOperator, 23 | conj as conj, 24 | diagonal as diagonal, 25 | DiagonalLinearOperator as DiagonalLinearOperator, 26 | DivLinearOperator as DivLinearOperator, 27 | FunctionLinearOperator as FunctionLinearOperator, 28 | has_unit_diagonal as has_unit_diagonal, 29 | IdentityLinearOperator as IdentityLinearOperator, 30 | is_diagonal as is_diagonal, 31 | is_lower_triangular as is_lower_triangular, 32 | is_negative_semidefinite as is_negative_semidefinite, 33 | is_positive_semidefinite as is_positive_semidefinite, 34 | is_symmetric as is_symmetric, 35 | is_tridiagonal as is_tridiagonal, 36 | is_upper_triangular as is_upper_triangular, 37 | JacobianLinearOperator as JacobianLinearOperator, 38 | linearise as linearise, 39 | materialise as materialise, 40 | MatrixLinearOperator as MatrixLinearOperator, 41 | MulLinearOperator as MulLinearOperator, 42 | NegLinearOperator as NegLinearOperator, 43 | PyTreeLinearOperator as PyTreeLinearOperator, 44 | TaggedLinearOperator as TaggedLinearOperator, 45 | TangentLinearOperator as TangentLinearOperator, 46 | tridiagonal as tridiagonal, 47 | TridiagonalLinearOperator as TridiagonalLinearOperator, 48 | ) 49 | from ._solution import RESULTS as RESULTS, Solution as Solution 50 | from ._solve import ( 51 | AbstractLinearSolver as AbstractLinearSolver, 52 | AutoLinearSolver as AutoLinearSolver, 53 | linear_solve as linear_solve, 54 | ) 55 | from ._solver import ( 56 | BiCGStab as BiCGStab, 57 | CG as CG, 58 | Cholesky as Cholesky, 59 | Diagonal as Diagonal, 60 | GMRES as GMRES, 61 | LSMR as LSMR, 62 | LU as LU, 63 | NormalCG as NormalCG, 64 | QR as QR, 65 | SVD as SVD, 66 | Triangular as Triangular, 67 | Tridiagonal as Tridiagonal, 68 | ) 69 | from ._tags import ( 70 | diagonal_tag as diagonal_tag, 
71 | lower_triangular_tag as lower_triangular_tag, 72 | negative_semidefinite_tag as negative_semidefinite_tag, 73 | positive_semidefinite_tag as positive_semidefinite_tag, 74 | symmetric_tag as symmetric_tag, 75 | transpose_tags as transpose_tags, 76 | transpose_tags_rules as transpose_tags_rules, 77 | tridiagonal_tag as tridiagonal_tag, 78 | unit_diagonal_tag as unit_diagonal_tag, 79 | upper_triangular_tag as upper_triangular_tag, 80 | ) 81 | 82 | 83 | __version__ = importlib.metadata.version("lineax") 84 | -------------------------------------------------------------------------------- /lineax/_solution.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any 16 | 17 | import equinox as eqx 18 | import equinox.internal as eqxi 19 | from jaxtyping import Array, ArrayLike, PyTree 20 | 21 | 22 | _singular_msg = """ 23 | A linear solver returned non-finite (NaN or inf) output. This usually means that an 24 | operator was not well-posed, and that its solver does not support this. 25 | 26 | If you are trying solve a linear least-squares problem then you should pass 27 | `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` 28 | assumes that the operator is square and nonsingular. 
29 | 30 | If you *were* expecting this solver to work with this operator, then it may be because: 31 | 32 | (a) the operator is singular, and your code has a bug; or 33 | 34 | (b) the operator was nearly singular (i.e. it had a high condition number: 35 | `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from 36 | numerical instability issues; or 37 | 38 | (c) the operator is declared to exhibit a certain property (e.g. positive definiteness) 39 | that is does not actually satisfy. 40 | """.strip() 41 | 42 | 43 | _nonfinite_msg = """ 44 | A linear solver received non-finite (NaN or inf) input and cannot determine a 45 | solution. 46 | 47 | This means that you have a bug upstream of Lineax and should check the inputs to 48 | `lineax.linear_solve` for non-finite values. 49 | """.strip() 50 | 51 | 52 | class RESULTS(eqxi.Enumeration): 53 | successful = "" 54 | max_steps_reached = ( 55 | "The maximum number of solver steps was reached. Try increasing `max_steps`." 56 | ) 57 | singular = _singular_msg 58 | breakdown = ( 59 | "A form of iterative breakdown has occured in a linear solve. " 60 | "Try using a different solver for this problem or increase `restart` " 61 | "if using GMRES." 62 | ) 63 | stagnation = ( 64 | "A stagnation in an iterative linear solve has occurred. Try increasing " 65 | "`stagnation_iters` or `restart`." 66 | ) 67 | conlim = "Condition number of A seems to be larger than `conlim`." 68 | nonfinite_input = _nonfinite_msg 69 | 70 | 71 | class Solution(eqx.Module): 72 | """The solution to a linear solve. 73 | 74 | **Attributes:** 75 | 76 | - `value`: The solution to the solve. 77 | - `result`: An integer representing whether the solve was successful or not. This 78 | can be converted into a human-readable error message via 79 | `lineax.RESULTS[result]`. 80 | - `stats`: Statistics about the solver, e.g. the number of steps that were required. 81 | - `state`: The internal state of the solver. 
The meaning of this is specific to each 82 | solver. 83 | """ 84 | 85 | value: PyTree[Array] 86 | result: RESULTS 87 | stats: dict[str, PyTree[ArrayLike]] 88 | state: PyTree[Any] 89 | -------------------------------------------------------------------------------- /lineax/_solver/lu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, TypeAlias 16 | 17 | import jax.numpy as jnp 18 | import jax.scipy as jsp 19 | from jaxtyping import Array, PyTree 20 | 21 | from .._operator import AbstractLinearOperator, is_diagonal 22 | from .._solution import RESULTS 23 | from .._solve import AbstractLinearSolver 24 | from .misc import ( 25 | pack_structures, 26 | PackedStructures, 27 | ravel_vector, 28 | transpose_packed_structures, 29 | unravel_solution, 30 | ) 31 | 32 | 33 | _LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, bool] 34 | 35 | 36 | class LU(AbstractLinearSolver[_LUState]): 37 | """LU solver for linear systems. 38 | 39 | This solver can only handle square nonsingular operators. 
40 | """ 41 | 42 | def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): 43 | del options 44 | if operator.in_size() != operator.out_size(): 45 | raise ValueError( 46 | "`LU` may only be used for linear solves with square matrices" 47 | ) 48 | packed_structures = pack_structures(operator) 49 | if is_diagonal(operator): 50 | lu = operator.as_matrix(), jnp.arange(operator.in_size(), dtype=jnp.int32) 51 | else: 52 | lu = jsp.linalg.lu_factor(operator.as_matrix()) 53 | return lu, packed_structures, False 54 | 55 | def compute( 56 | self, state: _LUState, vector: PyTree[Array], options: dict[str, Any] 57 | ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: 58 | del options 59 | lu_and_piv, packed_structures, transpose = state 60 | trans = 1 if transpose else 0 61 | vector = ravel_vector(vector, packed_structures) 62 | solution = jsp.linalg.lu_solve(lu_and_piv, vector, trans=trans) 63 | solution = unravel_solution(solution, packed_structures) 64 | return solution, RESULTS.successful, {} 65 | 66 | def transpose( 67 | self, 68 | state: _LUState, 69 | options: dict[str, Any], 70 | ): 71 | lu_and_piv, packed_structures, transpose = state 72 | transposed_packed_structures = transpose_packed_structures(packed_structures) 73 | transpose_state = lu_and_piv, transposed_packed_structures, not transpose 74 | transpose_options = {} 75 | return transpose_state, transpose_options 76 | 77 | def conj( 78 | self, 79 | state: _LUState, 80 | options: dict[str, Any], 81 | ): 82 | (lu, piv), packed_structures, transpose = state 83 | conj_state = (lu.conj(), piv), packed_structures, not transpose 84 | conj_options = {} 85 | return conj_state, conj_options 86 | 87 | def allow_dependent_columns(self, operator): 88 | return False 89 | 90 | def allow_dependent_rows(self, operator): 91 | return False 92 | 93 | 94 | LU.__init__.__doc__ = """**Arguments:** 95 | 96 | Nothing. 
97 | """ 98 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.) 4 | 5 | Features include: 6 | 7 | - PyTree-valued matrices and vectors; 8 | - General linear operators for Jacobians, transposes, etc.; 9 | - Efficient linear least squares (e.g. QR solvers); 10 | - Numerically stable gradients through linear least squares; 11 | - Support for structured (e.g. symmetric) matrices; 12 | - Improved compilation times; 13 | - Improved runtime of some algorithms; 14 | - Support for both real-valued and complex-valued inputs; 15 | - All the benefits of working with JAX: autodiff, autoparallism, GPU/TPU support, etc. 16 | 17 | ## Installation 18 | 19 | ```bash 20 | pip install lineax 21 | ``` 22 | 23 | Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+. 
24 | 25 | ## Quick example 26 | 27 | Lineax can solve a least squares problem with an explicit matrix operator: 28 | 29 | ```python 30 | import jax.random as jr 31 | import lineax as lx 32 | 33 | matrix_key, vector_key = jr.split(jr.PRNGKey(0)) 34 | matrix = jr.normal(matrix_key, (10, 8)) 35 | vector = jr.normal(vector_key, (10,)) 36 | operator = lx.MatrixLinearOperator(matrix) 37 | solution = lx.linear_solve(operator, vector, solver=lx.QR()) 38 | ``` 39 | 40 | or Lineax can solve a problem without ever materializing a matrix, as done in this 41 | quadratic solve: 42 | 43 | ```python 44 | import jax 45 | import lineax as lx 46 | 47 | key = jax.random.PRNGKey(0) 48 | y = jax.random.normal(key, (10,)) 49 | 50 | def quadratic_fn(y, args): 51 | return jax.numpy.sum((y - 1)**2) 52 | 53 | gradient_fn = jax.grad(quadratic_fn) 54 | hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag) 55 | solver = lx.CG(rtol=1e-6, atol=1e-6) 56 | out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver) 57 | minimum = y - out.value 58 | ``` 59 | 60 | ## Next steps 61 | 62 | Check out the examples or the API reference on the left-hand bar. 63 | 64 | ## See also: other libraries in the JAX ecosystem 65 | 66 | **Always useful** 67 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! 68 | [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. 69 | 70 | **Deep learning** 71 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. 72 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). 73 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). 74 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. 
75 | 76 | **Scientific computing** 77 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. 78 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. 79 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. 80 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. 81 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) 82 | 83 | **Awesome JAX** 84 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. 85 | -------------------------------------------------------------------------------- /lineax/_solver/cholesky.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Any, TypeAlias

import jax.flatten_util as jfu
import jax.scipy as jsp
from jaxtyping import Array, PyTree

from .._operator import (
    AbstractLinearOperator,
    is_negative_semidefinite,
    is_positive_semidefinite,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver


# (Cholesky factor of A or -A, "operator was negative definite" flag)
_CholeskyState: TypeAlias = tuple[Array, bool]


class Cholesky(AbstractLinearSolver[_CholeskyState]):
    """Cholesky solver for linear systems. This is generally the preferred solver for
    positive or negative definite systems.

    Equivalent to `scipy.linalg.solve(..., assume_a="pos")`.

    The operator must be square, nonsingular, and either positive or negative definite.
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        del options
        negative_definite = is_negative_semidefinite(operator)
        if not (is_positive_semidefinite(operator) | negative_definite):
            raise ValueError(
                "`Cholesky(..., normal=False)` may only be used for positive "
                "or negative definite linear operators"
            )
        matrix = operator.as_matrix()
        n_rows, n_cols = matrix.shape
        if n_rows != n_cols:
            raise ValueError(
                "`Cholesky(..., normal=False)` may only be used for linear solves "
                "with square matrices"
            )
        if negative_definite:
            # Factorise -A, which is positive definite; the sign is reapplied
            # to the solution in `compute`.
            matrix = -matrix
        factor, lower = jsp.linalg.cho_factor(matrix)
        # Fix lower triangular for simplicity.
        assert lower is False
        return factor, negative_definite

    def compute(
        self, state: _CholeskyState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        del options
        factor, negative_definite = state
        # Cholesky => PSD => symmetric => (in_structure == out_structure) =>
        # we don't need to use packed structures.
        flat_vector, unflatten = jfu.ravel_pytree(vector)
        flat_solution = jsp.linalg.cho_solve((factor, False), flat_vector)
        if negative_definite:
            flat_solution = -flat_solution
        return unflatten(flat_solution), RESULTS.successful, {}

    def transpose(self, state: _CholeskyState, options: dict[str, Any]):
        # The operator is self-adjoint, so Aᵀ = conj(A); conjugate the factor.
        factor, negative_definite = state
        return (factor.conj(), negative_definite), options

    def conj(self, state: _CholeskyState, options: dict[str, Any]):
        # Matrix is self-adjoint.
        factor, negative_definite = state
        return (factor.conj(), negative_definite), options

    def allow_dependent_columns(self, operator):
        return False

    def allow_dependent_rows(self, operator):
        return False


Cholesky.__init__.__doc__ = """**Arguments:**

Nothing.
"""
from typing import Any, TypeAlias

import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

from .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
    pack_structures,
    PackedStructures,
    ravel_vector,
    transpose_packed_structures,
    unravel_solution,
)


# ((main, lower, upper) diagonals, packed in/out structures)
_TridiagonalState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]


class Tridiagonal(AbstractLinearSolver[_TridiagonalState]):
    """Tridiagonal solver for linear systems.

    Uses the LAPACK/cusparse implementation of Gaussian elimination with
    partial pivoting (which increases stability).
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        # The operator must be square and carry tridiagonal structure.
        del options
        if operator.in_size() != operator.out_size():
            raise ValueError(
                "`Tridiagonal` may only be used for linear solves with square matrices"
            )
        if not is_tridiagonal(operator):
            raise ValueError(
                "`Tridiagonal` may only be used for linear solves with tridiagonal "
                "matrices"
            )
        return tridiagonal(operator), pack_structures(operator)

    def compute(
        self,
        state: _TridiagonalState,
        vector,
        options,
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state
        del state, options
        vector = ravel_vector(vector, packed_structures)

        # `tridiagonal_solve` expects all three diagonals of length n: the
        # sub-diagonal is zero-padded at the front, the super-diagonal at the
        # back. The rhs must be a column matrix, hence `[:, None]`/`flatten`.
        solution = lax.linalg.tridiagonal_solve(
            jnp.append(0.0, lower_diagonal),
            diagonal,
            jnp.append(upper_diagonal, 0.0),
            vector[:, None],
        ).flatten()

        solution = unravel_solution(solution, packed_structures)
        return solution, RESULTS.successful, {}

    def transpose(self, state: _TridiagonalState, options: dict[str, Any]):
        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state
        transposed_packed_structures = transpose_packed_structures(packed_structures)
        # Transposing a tridiagonal matrix swaps its sub- and super-diagonals.
        transpose_diagonals = (diagonal, upper_diagonal, lower_diagonal)
        transpose_state = (transpose_diagonals, transposed_packed_structures)
        return transpose_state, options

    def conj(self, state: _TridiagonalState, options: dict[str, Any]):
        # Elementwise conjugation of each diagonal conjugates the operator.
        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state
        conj_diagonals = (diagonal.conj(), lower_diagonal.conj(), upper_diagonal.conj())
        conj_state = (conj_diagonals, packed_structures)
        return conj_state, options

    def allow_dependent_columns(self, operator):
        # Requires a nonsingular matrix.
        return False

    def allow_dependent_rows(self, operator):
        return False


Tridiagonal.__init__.__doc__ = """**Arguments:**

Nothing.
"""
from collections.abc import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, Bool, PyTree  # pyright:ignore


def tree_where(
    pred: Bool[ArrayLike, ""], true: PyTree[ArrayLike], false: PyTree[ArrayLike]
) -> PyTree[Array]:
    """Leafwise `jnp.where(pred, leaf_true, leaf_false)` over a pair of PyTrees."""
    keep = lambda a, b: jnp.where(pred, a, b)
    return jtu.tree_map(keep, true, false)


def resolve_rcond(rcond, n, m, dtype):
    """Resolve a user-supplied `rcond` into a concrete singular-value cutoff.

    `None` selects an epsilon-based default scaled by the larger dimension;
    a negative `rcond` falls back to machine epsilon alone.
    """
    if rcond is None:
        # This `2 *` is a heuristic: I have seen very rare failures without it, in ways
        # that seem to depend on JAX compilation state. (E.g. running unrelated JAX
        # computations beforehand, in a completely different JIT-compiled region, can
        # result in differences in the success/failure of the solve.)
        return 2 * jnp.finfo(dtype).eps * max(n, m)
    else:
        return jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)


class NoneAux(eqx.Module):
    """Wraps a callable so that it additionally returns `None` as aux output."""

    fn: Callable

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs), None


def jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, jac=None):
    """Return `jax.jacfwd` or `jax.jacrev` applied to `fn`.

    `jac` may force the mode (`"fwd"`/`"bwd"`); `None` picks heuristically
    from the input/output sizes.
    """
    if jac is None:
        # Heuristic for which is better in each case
        # These could probably be tuned a lot more.
        jac_fwd = (in_size < 100) or (in_size <= 1.5 * out_size)
    elif jac == "fwd":
        jac_fwd = True
    elif jac == "bwd":
        jac_fwd = False
    else:
        raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.")
    if jac_fwd:
        return jax.jacfwd(fn, holomorphic=holomorphic, has_aux=has_aux)
    else:
        return jax.jacrev(fn, holomorphic=holomorphic, has_aux=has_aux)


def _asarray(dtype, x):
    # Thin wrapper so the conversion can be given a custom JVP below.
    return jnp.asarray(x, dtype=dtype)


# Work around JAX issue #15676
_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))


@_asarray.defjvp
def _asarray_jvp(dtype, x, tx):
    # Cast the tangent to the same dtype as the primal.
    (x,) = x
    (tx,) = tx
    return _asarray(dtype, x), _asarray(dtype, tx)


def default_floating_dtype():
    """The default floating dtype: float64 under x64 mode, else float32."""
    if jax.config.jax_enable_x64:  # pyright: ignore
        return jnp.float64
    else:
        return jnp.float32


def inexact_asarray(x):
    """Convert `x` to an array, promoting non-inexact dtypes to floating point."""
    dtype = jnp.result_type(x)
    if not jnp.issubdtype(jnp.result_type(x), jnp.inexact):
        dtype = default_floating_dtype()
    return _asarray(dtype, x)


def complex_to_real_dtype(dtype):
    """The real dtype with the same precision as the (possibly complex) `dtype`."""
    return jnp.finfo(dtype).dtype


def strip_weak_dtype(tree: PyTree) -> PyTree:
    """Rebuild any `ShapeDtypeStruct` leaves, dropping their `weak_type` flag."""
    return jtu.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding)
        if type(x) is jax.ShapeDtypeStruct
        else x,
        tree,
    )


def structure_equal(x, y) -> bool:
    """Whether `x` and `y` match in PyTree structure, shapes, and dtypes."""
    x = strip_weak_dtype(jax.eval_shape(lambda: x))
    y = strip_weak_dtype(jax.eval_shape(lambda: y))
    return eqx.tree_equal(x, y) is True
/* Fit the Twitter handle alongside the GitHub one in the top right. */

div.md-header__source {
    width: revert;
    max-width: revert;
}

a.md-source {
    display: inline-block;
}

.md-source__repository {
    max-width: 100%;
}

/* Emphasise sections of nav on left hand side */

nav.md-nav {
    padding-left: 5px;
}

nav.md-nav--secondary {
    border-left: revert !important;
}

.md-nav__title {
    font-size: 0.9rem;
}

.md-nav__item--section > .md-nav__link {
    font-size: 0.9rem;
}

/* Indent autogenerated documentation */

div.doc-contents {
    padding-left: 25px;
    /* NOTE(review): `rgba()` with only three channels relies on CSS Color 4
       treating rgb/rgba as aliases; `rgb(230, 230, 230)` is the conventional
       spelling — confirm target browser support. */
    border-left: 4px solid rgba(230, 230, 230);
}

/* Increase visibility of splitters "---" */

[data-md-color-scheme="default"] .md-typeset hr {
    border-bottom-color: rgb(0, 0, 0);
    border-bottom-width: 1pt;
}

[data-md-color-scheme="slate"] .md-typeset hr {
    border-bottom-color: rgb(230, 230, 230);
}

/* More space at the bottom of the page */

.md-main__inner {
    margin-bottom: 1.5rem;
}

/* Remove prev/next footer buttons */

.md-footer__inner {
    display: none;
}

/* Change font sizes */

html {
    /* Decrease font size for overall webpage
       Down from 137.5% which is the Material default */
    font-size: 110%;
}

.md-typeset .admonition {
    /* Increase font size in admonitions */
    font-size: 100% !important;
}

.md-typeset details {
    /* Increase font size in details */
    font-size: 100% !important;
}

.md-typeset h1 {
    font-size: 1.6rem;
}

.md-typeset h2 {
    font-size: 1.5rem;
}

.md-typeset h3 {
    font-size: 1.3rem;
}

.md-typeset h4 {
    font-size: 1.1rem;
}

.md-typeset h5 {
    font-size: 0.9rem;
}

.md-typeset h6 {
    font-size: 0.8rem;
}

/* Bugfix: remove the superfluous parts generated when doing:

   ??? Blah

       ::: library.something
*/

.md-typeset details .mkdocstrings > h4 {
    display: none;
}

.md-typeset details .mkdocstrings > h5 {
    display: none;
}

/* Change default colours for tags */

[data-md-color-scheme="default"] {
    --md-typeset-a-color: rgb(0, 189, 164) !important;
}
[data-md-color-scheme="slate"] {
    --md-typeset-a-color: rgb(0, 189, 164) !important;
}

/* Highlight functions, classes etc. type signatures. Really helps to make clear where
   one item ends and another begins. */

[data-md-color-scheme="default"] {
    --doc-heading-color: #DDD;
    --doc-heading-border-color: #CCC;
    --doc-heading-color-alt: #F0F0F0;
}
[data-md-color-scheme="slate"] {
    --doc-heading-color: rgb(25,25,33);
    --doc-heading-border-color: rgb(25,25,33);
    --doc-heading-color-alt: rgb(33,33,44);
    --md-code-bg-color: rgb(38,38,50);
}

h4.doc-heading {
    /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
    background-color: var(--doc-heading-color);
    border: solid var(--doc-heading-border-color);
    border-width: 1.5pt;
    border-radius: 2pt;
    padding: 0pt 5pt 2pt 5pt;
}
/* NOTE(review): `h6.heading` (not `h6.doc-heading`) — confirm this matches the
   markup actually generated by mkdocstrings. */
h5.doc-heading, h6.heading {
    background-color: var(--doc-heading-color-alt);
    border-radius: 2pt;
    padding: 0pt 5pt 2pt 5pt;
}
from typing import Any, TypeAlias

import jax.lax as lax
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, PyTree

from .._misc import resolve_rcond
from .._operator import AbstractLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
    pack_structures,
    PackedStructures,
    ravel_vector,
    transpose_packed_structures,
    unravel_solution,
)


# ((U, S, Vt) factors, packed in/out structures)
_SVDState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]


class SVD(AbstractLinearSolver[_SVDState]):
    """SVD solver for linear systems.

    This solver can handle any operator, even nonsquare or singular ones. In these
    cases it will return the pseudoinverse solution to the linear system.

    Equivalent to `scipy.linalg.lstsq`.
    """

    # Cutoff for treating singular values as zero; `None` selects a
    # machine-epsilon-based default (see `resolve_rcond`).
    rcond: float | None = None

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        # Factorise once up front; `compute` then only needs matmuls.
        del options
        svd = jsp.linalg.svd(operator.as_matrix(), full_matrices=False)
        packed_structures = pack_structures(operator)
        return svd, packed_structures

    def compute(
        self,
        state: _SVDState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        # Pseudoinverse solve: x = V diag(1/s) Uᴴ b, with small s zeroed out.
        del options
        (u, s, vt), packed_structures = state
        vector = ravel_vector(vector, packed_structures)
        m, _ = u.shape
        _, n = vt.shape
        rcond = resolve_rcond(self.rcond, n, m, s.dtype)
        rcond = jnp.array(rcond, dtype=s.dtype)
        if s.size > 0:
            # Make the cutoff relative to the largest singular value.
            rcond = rcond * s[0]
        # Not >=, or this fails with a matrix of all-zeros.
        mask = s > rcond
        rank = mask.sum()
        safe_s = jnp.where(mask, s, 1)
        s_inv = jnp.where(mask, jnp.array(1.0) / safe_s, 0).astype(u.dtype)
        uTb = jnp.matmul(u.conj().T, vector, precision=lax.Precision.HIGHEST)
        solution = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
        solution = unravel_solution(solution, packed_structures)
        return solution, RESULTS.successful, {"rank": rank}

    def transpose(self, state: _SVDState, options: dict[str, Any]):
        # A = U S Vt  =>  Aᵀ = (Vt)ᵀ S Uᵀ: swap and transpose the outer factors.
        del options
        (u, s, vt), packed_structures = state
        transposed_packed_structures = transpose_packed_structures(packed_structures)
        transpose_state = (vt.T, s, u.T), transposed_packed_structures
        transpose_options = {}
        return transpose_state, transpose_options

    def conj(self, state: _SVDState, options: dict[str, Any]):
        # conj(A) = conj(U) S conj(Vt); singular values are real already.
        del options
        (u, s, vt), packed_structures = state
        conj_state = (u.conj(), s, vt.conj()), packed_structures
        conj_options = {}
        return conj_state, conj_options

    def allow_dependent_columns(self, operator):
        # The pseudoinverse is well-defined even for rank-deficient operators.
        return True

    def allow_dependent_rows(self, operator):
        return True


SVD.__init__.__doc__ = """**Arguments**:

- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine
    precision times `max(N, M)`, where `(N, M)` is the shape of the operator. (I.e.
    `N` is the output size and `M` is the input size.)
"""
27 | 28 | ## Quick examples 29 | 30 | Lineax can solve a least squares problem with an explicit matrix operator: 31 | 32 | ```python 33 | import jax.random as jr 34 | import lineax as lx 35 | 36 | matrix_key, vector_key = jr.split(jr.PRNGKey(0)) 37 | matrix = jr.normal(matrix_key, (10, 8)) 38 | vector = jr.normal(vector_key, (10,)) 39 | operator = lx.MatrixLinearOperator(matrix) 40 | solution = lx.linear_solve(operator, vector, solver=lx.QR()) 41 | ``` 42 | 43 | or Lineax can solve a problem without ever materializing a matrix, as done in this 44 | quadratic solve: 45 | 46 | ```python 47 | import jax 48 | import lineax as lx 49 | 50 | key = jax.random.PRNGKey(0) 51 | y = jax.random.normal(key, (10,)) 52 | 53 | def quadratic_fn(y, args): 54 | return jax.numpy.sum((y - 1)**2) 55 | 56 | gradient_fn = jax.grad(quadratic_fn) 57 | hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag) 58 | solver = lx.CG(rtol=1e-6, atol=1e-6) 59 | out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver) 60 | minimum = y - out.value 61 | ``` 62 | 63 | ## Citation 64 | 65 | If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2311.17283)) 66 | 67 | ```bibtex 68 | @article{lineax2023, 69 | title={Lineax: unified linear solves and linear least-squares in JAX and Equinox}, 70 | author={Jason Rader and Terry Lyons and Patrick Kidger}, 71 | journal={ 72 | AI for science workshop at Neural Information Processing Systems 2023, 73 | arXiv:2311.17283 74 | }, 75 | year={2023}, 76 | } 77 | ``` 78 | 79 | (Also consider starring the project on GitHub.) 80 | 81 | ## See also: other libraries in the JAX ecosystem 82 | 83 | **Always useful** 84 | [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! 85 | [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. 
86 | 87 | **Deep learning** 88 | [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. 89 | [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). 90 | [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). 91 | [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. 92 | 93 | **Scientific computing** 94 | [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. 95 | [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. 96 | [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. 97 | [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. 98 | [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) 99 | 100 | **Awesome JAX** 101 | [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. 102 | -------------------------------------------------------------------------------- /lineax/_solver/triangular.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Any, TypeAlias

import jax.scipy as jsp
from jaxtyping import Array, PyTree

from .._operator import (
    AbstractLinearOperator,
    has_unit_diagonal,
    is_lower_triangular,
    is_upper_triangular,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
    pack_structures,
    PackedStructures,
    ravel_vector,
    transpose_packed_structures,
    unravel_solution,
)


# (matrix, lower?, unit diagonal?, packed in/out structures, transposed?)
_TriangularState: TypeAlias = tuple[Array, bool, bool, PackedStructures, bool]


class Triangular(AbstractLinearSolver[_TriangularState]):
    """Triangular solver for linear systems.

    The operator should either be lower triangular or upper triangular.
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        del options
        if operator.in_size() != operator.out_size():
            raise ValueError(
                "`Triangular` may only be used for linear solves with square matrices"
            )
        lower = is_lower_triangular(operator)
        if not (lower or is_upper_triangular(operator)):
            raise ValueError(
                "`Triangular` may only be used for linear solves with triangular "
                "matrices"
            )
        return (
            operator.as_matrix(),
            lower,
            has_unit_diagonal(operator),
            pack_structures(operator),
            False,  # not yet transposed
        )

    def compute(
        self, state: _TriangularState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        del state, options
        rhs = ravel_vector(vector, packed_structures)
        # "T" = solve with the (non-conjugated) transpose of the stored matrix.
        trans = "T" if transposed else "N"
        flat_solution = jsp.linalg.solve_triangular(
            matrix, rhs, trans=trans, lower=lower, unit_diagonal=unit_diagonal
        )
        return (
            unravel_solution(flat_solution, packed_structures),
            RESULTS.successful,
            {},
        )

    def transpose(self, state: _TriangularState, options: dict[str, Any]):
        # Keep the same matrix; just record that the transpose should be solved.
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        new_state = (
            matrix,
            lower,
            unit_diagonal,
            transpose_packed_structures(packed_structures),
            not transposed,
        )
        return new_state, {}

    def conj(self, state: _TriangularState, options: dict[str, Any]):
        # Elementwise conjugation of the matrix conjugates the operator.
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        new_state = (
            matrix.conj(),
            lower,
            unit_diagonal,
            packed_structures,
            transposed,
        )
        return new_state, {}

    def allow_dependent_columns(self, operator):
        return False

    def allow_dependent_rows(self, operator):
        return False


Triangular.__init__.__doc__ = """**Arguments:**

Nothing.
"""
import jax
import jax.flatten_util as jfu
import jax.numpy as jnp
import lineax.internal as lxi

from .helpers import tree_allclose


def _square(x):
    # |x|^2, valid for complex inputs as well.
    return x * jnp.conj(x)


def _two_norm(x):
    # Reference implementation: Euclidean norm of the flattened PyTree.
    return jnp.sqrt(jnp.sum(_square(jfu.ravel_pytree(x)[0]))).real


def _rms_norm(x):
    # Reference implementation: root-mean-square norm of the flattened PyTree.
    return jnp.sqrt(jnp.mean(_square(jfu.ravel_pytree(x)[0]))).real


def _max_norm(x):
    # Reference implementation: max-abs (infinity) norm of the flattened PyTree.
    return jnp.max(jnp.abs(jfu.ravel_pytree(x)[0]))


def test_nonzero():
    """Values and JVPs agree with the references away from the origin."""
    zero = [jnp.array(0.0), jnp.zeros((2, 2))]
    x = [jnp.array(1.0), jnp.arange(4.0).reshape(2, 2)]
    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]

    two = lxi.two_norm(x)
    rms = lxi.rms_norm(x)
    max = lxi.max_norm(x)
    true_two = _two_norm(x)
    true_rms = _rms_norm(x)
    true_max = _max_norm(x)
    assert jnp.allclose(two, true_two)
    assert jnp.allclose(rms, true_rms)
    assert jnp.allclose(max, true_max)

    two_jvp = jax.jvp(lxi.two_norm, (x,), (tx,))
    true_two_jvp = jax.jvp(_two_norm, (x,), (tx,))
    rms_jvp = jax.jvp(lxi.rms_norm, (x,), (tx,))
    true_rms_jvp = jax.jvp(_rms_norm, (x,), (tx,))
    max_jvp = jax.jvp(lxi.max_norm, (x,), (tx,))
    true_max_jvp = jax.jvp(_max_norm, (x,), (tx,))
    assert tree_allclose(two_jvp, true_two_jvp)
    assert tree_allclose(rms_jvp, true_rms_jvp)
    assert tree_allclose(max_jvp, true_max_jvp)

    # A zero tangent must give a zero JVP.
    two0_jvp = jax.jvp(lxi.two_norm, (x,), (zero,))
    rms0_jvp = jax.jvp(lxi.rms_norm, (x,), (zero,))
    max0_jvp = jax.jvp(lxi.max_norm, (x,), (zero,))
    assert tree_allclose(two0_jvp, (true_two, jnp.array(0.0)))
    assert tree_allclose(rms0_jvp, (true_rms, jnp.array(0.0)))
    assert tree_allclose(max0_jvp, (true_max, jnp.array(0.0)))


def test_zero():
    """At the origin the norms are zero and their JVPs are (by convention) zero,
    for both zero and nonzero tangents."""
    zero = [jnp.array(0.0), jnp.zeros((2, 2))]
    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]
    for t in (zero, tx):
        two0 = jax.jvp(lxi.two_norm, (zero,), (t,))
        rms0 = jax.jvp(lxi.rms_norm, (zero,), (t,))
        max0 = jax.jvp(lxi.max_norm, (zero,), (t,))
        true0 = (jnp.array(0.0), jnp.array(0.0))
        assert tree_allclose(two0, true0)
        assert tree_allclose(rms0, true0)
        assert tree_allclose(max0, true0)


def test_complex():
    """Complex inputs: the norm and its JVP primal must be real-valued."""
    x = jnp.array([3 + 1.2j, -0.5 + 4.9j])
    tx = jnp.array([2 - 0.3j, -0.7j])
    two = jax.jvp(lxi.two_norm, (x,), (tx,))
    true_two = jax.jvp(_two_norm, (x,), (tx,))
    rms = jax.jvp(lxi.rms_norm, (x,), (tx,))
    true_rms = jax.jvp(_rms_norm, (x,), (tx,))
    max = jax.jvp(lxi.max_norm, (x,), (tx,))
    true_max = jax.jvp(_max_norm, (x,), (tx,))
    assert two[0].imag == 0
    assert tree_allclose(two, true_two)
    assert rms[0].imag == 0
    assert tree_allclose(rms, true_rms)
    assert max[0].imag == 0
    assert tree_allclose(max, true_max)


def test_size_zero():
    """Empty PyTrees have norm zero and zero derivative."""
    zero = jnp.array(0.0)
    for x in (jnp.array([]), [jnp.array([]), jnp.array([])]):
        assert tree_allclose(lxi.two_norm(x), zero)
        assert tree_allclose(lxi.rms_norm(x), zero)
        assert tree_allclose(lxi.max_norm(x), zero)
        assert tree_allclose(jax.jvp(lxi.two_norm, (x,), (x,)), (zero, zero))
        assert tree_allclose(jax.jvp(lxi.rms_norm, (x,), (x,)), (zero, zero))
        assert tree_allclose(jax.jvp(lxi.max_norm, (x,), (x,)), (zero, zero))
# Tests that `lx.conj` and `.transpose()` compose into a correct adjoint.
@pytest.mark.parametrize("make_operator", make_operators)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_adjoint(make_operator, dtype, getkey):
    """Check the adjoint identity <A v1, v2> = <v1, A^H v2>.

    The adjoint is formed both as conj-then-transpose and as
    transpose-then-conj; both orderings must satisfy the identity.
    """
    # Structured constructors need a square matrix carrying the matching tag;
    # every other constructor is exercised with a rectangular 3x5 matrix so
    # that the transpose is nontrivial (in_size != out_size).
    if (
        make_operator is make_trivial_diagonal_operator
        or make_operator is make_identity_operator
    ):
        matrix = jnp.eye(4, dtype=dtype)
        tags = lx.diagonal_tag
        in_size = out_size = 4
    elif make_operator is make_tridiagonal_operator:
        matrix = jnp.eye(4, dtype=dtype)
        tags = lx.tridiagonal_tag
        in_size = out_size = 4
    else:
        matrix = jr.normal(getkey(), (3, 5), dtype=dtype)
        tags = ()
        in_size = 5
        out_size = 3
    operator = make_operator(getkey, matrix, tags)
    v1, v2 = (
        jr.normal(getkey(), (in_size,), dtype=dtype),
        jr.normal(getkey(), (out_size,), dtype=dtype),
    )

    # <A v1, v2> under the complex inner product (conjugate the second arg).
    inner1 = operator.mv(v1) @ v2.conj()
    # Adjoint formed as conj-then-transpose.
    adjoint_op1 = lx.conj(operator).transpose()
    ov2 = adjoint_op1.mv(v2)
    inner2 = v1 @ ov2.conj()
    assert tree_allclose(inner1, inner2)

    # Adjoint formed as transpose-then-conj must give the same inner product.
    adjoint_op2 = lx.conj(operator.transpose())
    ov2 = adjoint_op2.mv(v2)
    inner2 = v1 @ ov2.conj()
    assert tree_allclose(inner1, inner2)


def test_functional_pytree_adjoint():
    """`lx.conj` of a real pytree-valued FunctionLinearOperator materialises
    to the same matrix as the original operator."""

    def fn(y):
        return {"b": y["a"]}

    y_struct = jax.eval_shape(lambda: {"a": 0.0})
    operator = FunctionLinearOperator(fn, y_struct)
    conj_operator = lx.conj(operator)
    assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator))


def test_functional_pytree_adjoint_complex():
    """As `test_functional_pytree_adjoint`, but with a complex input structure.
    The function itself has real (identity) coefficients, so conjugation still
    leaves the materialised operator unchanged."""

    def fn(y):
        return {"b": y["a"]}

    y_struct = jax.eval_shape(lambda: {"a": 0.0j})
    operator = FunctionLinearOperator(fn, y_struct)
    conj_operator = lx.conj(operator)
    assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator))


# Solver tolerance for the tests below: tight under x64, looser under single
# precision.
if jax.config.jax_enable_x64:  # pyright: ignore
    tol = 1e-12
else:
    tol = 1e-6
@pytest.mark.parametrize( 81 | "solver", 82 | [ 83 | # in theory only 1 iteration is needed, but stopping criteria are 84 | # complicated, see gh #160 85 | lx.GMRES(tol, tol, max_steps=4, restart=1), 86 | lx.BiCGStab(tol, tol, max_steps=3), 87 | lx.NormalCG(tol, tol, max_steps=4), 88 | lx.CG(tol, tol, max_steps=3), 89 | ], 90 | ) 91 | def test_preconditioner_adjoint(solver): 92 | """Test for fix to gh #160""" 93 | # Nonsymmetric poorly conditioned matrix. Without preconditioning, 94 | # this would take 20+ iterations (100s for GMRES) 95 | key = jax.random.key(123) 96 | key, subkey = jax.random.split(key) 97 | A = jax.random.uniform(key, (10, 10)) 98 | A += jnp.diag(jnp.arange(A.shape[0]) ** 6).astype(A.dtype) 99 | b = jax.random.uniform(subkey, (A.shape[0],)) 100 | if isinstance(solver, lx.CG): 101 | A = A.T @ A 102 | tags = (lx.positive_semidefinite_tag,) 103 | else: 104 | tags = () 105 | 106 | A = lx.MatrixLinearOperator(A, tags=tags) 107 | # exact inverse, should only take ~1 iteration 108 | M = lx.MatrixLinearOperator( 109 | jnp.linalg.inv(A.matrix), 110 | tags=tags, 111 | ) 112 | 113 | def solve(b): 114 | out = lx.linear_solve( 115 | A, b, solver=solver, options={"preconditioner": M}, throw=True 116 | ) 117 | return out.value 118 | 119 | # if they don't converge then this will throw an error 120 | _ = solve(b) 121 | A1 = jax.jacfwd(solve)(b) 122 | A2 = jax.jacrev(solve)(b) 123 | 124 | # we also do a sanity check, dx/db should give A^{-1} 125 | assert tree_allclose(A1, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol) 126 | assert tree_allclose(A2, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol) 127 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | theme: 2 | name: material 3 | features: 4 | - navigation.sections # Sections are included in the navigation on the left. 
5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. 6 | - header.autohide # header disappears as you scroll 7 | palette: 8 | # Light mode / dark mode 9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as 10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. 11 | - scheme: default 12 | primary: white 13 | accent: amber 14 | toggle: 15 | icon: material/weather-night 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: amber 20 | toggle: 21 | icon: material/weather-sunny 22 | name: Switch to light mode 23 | icon: 24 | repo: fontawesome/brands/github # GitHub logo in top right 25 | logo: "material/matrix" # lineax logo in top left 26 | favicon: "_static/favicon.png" 27 | custom_dir: "docs/_overrides" # Overriding part of the HTML 28 | 29 | # These additions are my own custom ones, having overridden a partial. 30 | twitter_bluesky_name: "@PatrickKidger" 31 | twitter_url: "https://twitter.com/PatrickKidger" 32 | bluesky_url: "https://PatrickKidger.bsky.social" 33 | 34 | site_name: lineax 35 | site_description: The documentation for the Lineax software library. 
36 | site_author: Patrick Kidger 37 | site_url: https://docs.kidger.site/lineax 38 | 39 | repo_url: https://github.com/patrick-kidger/lineax 40 | repo_name: patrick-kidger/lineax 41 | edit_uri: "" 42 | 43 | strict: true # Don't allow warnings during the build process 44 | 45 | extra_javascript: 46 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ 47 | - _static/mathjax.js 48 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 49 | 50 | extra_css: 51 | - _static/custom_css.css 52 | 53 | markdown_extensions: 54 | - pymdownx.arithmatex: # Render LaTeX via MathJax 55 | generic: true 56 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. 57 | - pymdownx.details # Allowing hidden expandable regions denoted by ??? 58 | - pymdownx.snippets: # Include one Markdown file into another 59 | base_path: docs 60 | - admonition 61 | - toc: 62 | permalink: "¤" # Adds a clickable permalink to each section heading 63 | toc_depth: 4 64 | 65 | plugins: 66 | - search: 67 | separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;' 68 | - include_exclude_files: 69 | include: 70 | - ".htaccess" 71 | exclude: 72 | - "_overrides" 73 | - "examples/.ipynb_checkpoints/" 74 | - ipynb 75 | - hippogriffe: 76 | extra_public_objects: 77 | - jax.ShapeDtypeStruct 78 | - mkdocstrings: 79 | handlers: 80 | python: 81 | options: 82 | force_inspection: true 83 | heading_level: 4 84 | inherited_members: true 85 | members_order: source 86 | show_bases: false 87 | show_if_no_docstring: true 88 | show_overloads: false 89 | show_root_heading: true 90 | show_signature_annotations: true 91 | show_source: false 92 | show_symbol_type_heading: true 93 | show_symbol_type_toc: true 94 | 95 | nav: 96 | - 'index.md' 97 | - Examples: 98 | - 'examples/classical_solve.ipynb' 99 | - 'examples/least_squares.ipynb' 100 | - 'examples/structured_matrices.ipynb' 101 | - 'examples/no_materialisation.ipynb' 
102 | - 'examples/operators.ipynb' 103 | - 'examples/complex_solve.ipynb' 104 | - API: 105 | - 'api/linear_solve.md' 106 | - 'api/solvers.md' 107 | - 'api/operators.md' 108 | - 'api/tags.md' 109 | - 'api/solution.md' 110 | - 'api/functions.md' 111 | - 'faq.md' 112 | -------------------------------------------------------------------------------- /docs/examples/no_materialisation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a7299095-8906-4867-82ef-d6b84b161366", 6 | "metadata": {}, 7 | "source": [ 8 | "# Using only matrix-vector operations\n", 9 | "\n", 10 | "When solving a linear system $Ax = b$, it is relatively common not to have immediate access to the full matrix $A$, but only to a function $F(x) = Ax$ computing the matrix-vector product. (We could compute $A$ from $F$, but if the matrix is large then this may be very inefficient.)\n", 11 | "\n", 12 | "**Example: Newton's method**\n", 13 | "\n", 14 | "For example, this comes up when using [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#k_variables,_k_functions). In this case, we have a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^n$, and wish to find the $\\delta \\in \\mathbb{R}^n$ for which $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; \\delta = -f(y)$. (Where $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\in \\mathbb{R}^{n \\times n}$ is a matrix: it is the Jacobian of $f$.)\n", 15 | "\n", 16 | "In this case it is possible to use forward-mode autodifferentiation to evaluate $F(x) = \\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; x$, without ever instantiating the whole Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y)$. 
Indeed, JAX has a [Jacobian-vector product function](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) for exactly this purpose.\n", 17 | "```python\n", 18 | "f = ...\n", 19 | "y = ...\n", 20 | "\n", 21 | "def F(x):\n", 22 | " \"\"\"Computes (df/dy) @ x.\"\"\"\n", 23 | " _, out = jax.jvp(f, (y,), (x,))\n", 24 | " return out\n", 25 | "```\n", 26 | "\n", 27 | "**Solving a linear system using only matrix-vector operations**\n", 28 | "\n", 29 | "Lineax offers [iterative solvers](../api/solvers.md#iterative-solvers), which are capable of solving a linear system knowing only its matrix-vector products." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "id": "b221ee1f-bd6b-4cbf-b69b-ed2e388602e1", 36 | "metadata": { 37 | "tags": [] 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "import jax.numpy as jnp\n", 42 | "import lineax as lx\n", 43 | "from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n", 44 | "\n", 45 | "\n", 46 | "def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n", 47 | " y0, y1, y2 = y\n", 48 | " f0 = 5 * y0 + y1**2\n", 49 | " f1 = y1 - y2 + 5\n", 50 | " f2 = y0 / (1 + 5 * y2**2)\n", 51 | " return jnp.stack([f0, f1, f2])\n", 52 | "\n", 53 | "\n", 54 | "y = jnp.array([1.0, 2.0, 3.0])\n", 55 | "operator = lx.JacobianLinearOperator(f, y, args=None)\n", 56 | "vector = f(y, args=None)\n", 57 | "solver = lx.NormalCG(rtol=1e-6, atol=1e-6)\n", 58 | "solution = lx.linear_solve(operator, vector, solver)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "87568426-35ed-404b-bf78-425a6f519218", 64 | "metadata": {}, 65 | "source": [ 66 | "!!! 
warning\n", 67 | "\n", 68 | " Note that iterative solvers are something of a \"last resort\", and they are not suitable for all problems.\n", 69 | "\n", 70 | " - [CG](https://en.wikipedia.org/wiki/Conjugate_gradient_method) requires that the problem be positive or negative semidefinite.\n", 71 | " - Normalised CG (this is CG applied to the \"normal equations\" $(A^\\top A) x = (A^\\top b)$; note that $A^\\top A$ is always positive semidefinite) squares the condition number of $A$. In practice this means it may produce low-accuracy results if used with matrices with high condition number.\n", 72 | " - [BiCGStab](https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method) and [GMRES](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems." 73 | ] 74 | } 75 | ], 76 | "metadata": { 77 | "kernelspec": { 78 | "display_name": "py39", 79 | "language": "python", 80 | "name": "py39" 81 | }, 82 | "language_info": { 83 | "codemirror_mode": { 84 | "name": "ipython", 85 | "version": 3 86 | }, 87 | "file_extension": ".py", 88 | "mimetype": "text/x-python", 89 | "name": "python", 90 | "nbconvert_exporter": "python", 91 | "pygments_lexer": "ipython3", 92 | "version": "3.9.16" 93 | } 94 | }, 95 | "nbformat": 4, 96 | "nbformat_minor": 5 97 | } 98 | -------------------------------------------------------------------------------- /lineax/_solver/qr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, TypeAlias 16 | 17 | import jax.numpy as jnp 18 | import jax.scipy as jsp 19 | from jaxtyping import Array, PyTree 20 | 21 | from .._solution import RESULTS 22 | from .._solve import AbstractLinearSolver 23 | from .misc import ( 24 | pack_structures, 25 | PackedStructures, 26 | ravel_vector, 27 | transpose_packed_structures, 28 | unravel_solution, 29 | ) 30 | 31 | 32 | _QRState: TypeAlias = tuple[tuple[Array, Array], bool, PackedStructures] 33 | 34 | 35 | class QR(AbstractLinearSolver): 36 | """QR solver for linear systems. 37 | 38 | This solver can handle non-square operators. 39 | 40 | This is usually the preferred solver when dealing with non-square operators. 41 | 42 | !!! info 43 | 44 | Note that whilst this does handle non-square operators, it still can only 45 | handle full-rank operators. 46 | 47 | This is because JAX does not currently support a rank-revealing/pivoted QR 48 | decomposition, see [issue #12897](https://github.com/google/jax/issues/12897). 49 | 50 | For such use cases, switch to [`lineax.SVD`][] instead. 
51 | """ 52 | 53 | def init(self, operator, options): 54 | del options 55 | matrix = operator.as_matrix() 56 | m, n = matrix.shape 57 | transpose = n > m 58 | if transpose: 59 | matrix = matrix.T 60 | qr = jnp.linalg.qr(matrix, mode="reduced") # pyright: ignore 61 | packed_structures = pack_structures(operator) 62 | return qr, transpose, packed_structures 63 | 64 | def compute( 65 | self, 66 | state: _QRState, 67 | vector: PyTree[Array], 68 | options: dict[str, Any], 69 | ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: 70 | (q, r), transpose, packed_structures = state 71 | del state, options 72 | vector = ravel_vector(vector, packed_structures) 73 | if transpose: 74 | # Minimal norm solution if underdetermined. 75 | solution = q.conj() @ jsp.linalg.solve_triangular( 76 | r, vector, trans="T", unit_diagonal=False 77 | ) 78 | else: 79 | # Least squares solution if overdetermined. 80 | solution = jsp.linalg.solve_triangular( 81 | r, q.T.conj() @ vector, trans="N", unit_diagonal=False 82 | ) 83 | solution = unravel_solution(solution, packed_structures) 84 | return solution, RESULTS.successful, {} 85 | 86 | def transpose(self, state: _QRState, options: dict[str, Any]): 87 | (q, r), transpose, structures = state 88 | transposed_packed_structures = transpose_packed_structures(structures) 89 | transpose_state = (q, r), not transpose, transposed_packed_structures 90 | transpose_options = {} 91 | return transpose_state, transpose_options 92 | 93 | def conj(self, state: _QRState, options: dict[str, Any]): 94 | (q, r), transpose, structures = state 95 | conj_state = ( 96 | (q.conj(), r.conj()), 97 | transpose, 98 | structures, 99 | ) 100 | conj_options = {} 101 | return conj_state, conj_options 102 | 103 | def allow_dependent_columns(self, operator): 104 | rows = operator.out_size() 105 | columns = operator.in_size() 106 | # We're able to pull an efficiency trick here. 
107 | # 108 | # As we don't use a rank-revealing implementation, then we always require that 109 | # the operator have full rank. 110 | # 111 | # So if we have columns <= rows, then we know that all our columns are linearly 112 | # independent. We can return `False` and get a computationally cheaper jvp rule. 113 | return columns > rows 114 | 115 | def allow_dependent_rows(self, operator): 116 | rows = operator.out_size() 117 | columns = operator.in_size() 118 | return rows > columns 119 | 120 | 121 | QR.__init__.__doc__ = """**Arguments:** 122 | 123 | Nothing. 124 | """ 125 | -------------------------------------------------------------------------------- /tests/test_vmap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import equinox as eqx 16 | import jax 17 | import jax.numpy as jnp 18 | import jax.random as jr 19 | import lineax as lx 20 | import pytest 21 | 22 | from .helpers import ( 23 | construct_matrix, 24 | construct_singular_matrix, 25 | make_jac_operator, 26 | make_matrix_operator, 27 | solvers_tags_pseudoinverse, 28 | tree_allclose, 29 | ) 30 | 31 | 32 | @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) 33 | @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) 34 | @pytest.mark.parametrize("use_state", (True, False)) 35 | @pytest.mark.parametrize( 36 | "make_matrix", 37 | ( 38 | construct_matrix, 39 | construct_singular_matrix, 40 | ), 41 | ) 42 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 43 | def test_vmap( 44 | getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype 45 | ): 46 | if (make_matrix is construct_matrix) or pseudoinverse: 47 | 48 | def wrap_solve(matrix, vector): 49 | operator = make_operator(getkey, matrix, tags) 50 | if use_state: 51 | state = solver.init(operator, options={}) 52 | return lx.linear_solve(operator, vector, solver, state=state).value 53 | else: 54 | return lx.linear_solve(operator, vector, solver).value 55 | 56 | for op_axis, vec_axis in ( 57 | (None, 0), 58 | (eqx.if_array(0), None), 59 | (eqx.if_array(0), 0), 60 | ): 61 | if op_axis is None: 62 | axis_size = None 63 | out_axes = None 64 | else: 65 | axis_size = 10 66 | out_axes = eqx.if_array(0) 67 | 68 | (matrix,) = eqx.filter_vmap( 69 | lambda getkey, solver, tags: make_matrix( 70 | getkey, solver, tags, dtype=dtype 71 | ), 72 | axis_size=axis_size, 73 | out_axes=out_axes, 74 | )(getkey, solver, tags) 75 | out_dim = matrix.shape[-2] 76 | 77 | if vec_axis is None: 78 | vec = jr.normal(getkey(), (out_dim,), dtype=dtype) 79 | else: 80 | vec = jr.normal(getkey(), (10, out_dim), dtype=dtype) 81 | 82 | jax_result, _, _, _ = eqx.filter_vmap( 83 | jnp.linalg.lstsq, 84 | 
in_axes=(op_axis, vec_axis), # pyright: ignore 85 | )(matrix, vec) 86 | lx_result = eqx.filter_vmap(wrap_solve, in_axes=(op_axis, vec_axis))( 87 | matrix, vec 88 | ) 89 | assert tree_allclose(lx_result, jax_result) 90 | 91 | 92 | # https://github.com/patrick-kidger/lineax/issues/101 93 | def test_grad_vmap_basic(getkey): 94 | A = jr.normal(getkey(), (16, 8)) 95 | B = jr.normal(getkey(), (128, 16)) 96 | 97 | @jax.jit 98 | @jax.grad 99 | def fn(A): 100 | op = lx.MatrixLinearOperator(A) 101 | return jax.vmap( 102 | lambda b: lx.linear_solve( 103 | op, b, lx.AutoLinearSolver(well_posed=False) 104 | ).value 105 | )(B).mean() 106 | 107 | fn(A) 108 | 109 | 110 | def test_grad_vmap_advanced(getkey): 111 | # this is a more complicated version of the above test, in which the batch axes and 112 | # the undefinedprimals do not necessarily line up in the same arguments. 113 | A = jr.normal(getkey(), (2, 8)), jr.normal(getkey(), (3, 8, 128)) 114 | B = jr.normal(getkey(), (2, 128)), jr.normal(getkey(), (3,)) 115 | 116 | output_structure = ( 117 | jax.ShapeDtypeStruct((2,), jnp.float64), 118 | jax.ShapeDtypeStruct((3,), jnp.float64), 119 | ) 120 | 121 | def to_vmap(A, B): 122 | op = lx.PyTreeLinearOperator(A, output_structure) 123 | return lx.linear_solve(op, B, lx.AutoLinearSolver(well_posed=False)).value 124 | 125 | @jax.jit 126 | @jax.grad 127 | def fn(A): 128 | return jax.vmap(to_vmap, in_axes=((None, 2), (1, None)))(A, B).mean() 129 | 130 | fn(A) 131 | -------------------------------------------------------------------------------- /lineax/_solver/diagonal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, TypeAlias 16 | 17 | import jax.numpy as jnp 18 | from jaxtyping import Array, PyTree 19 | 20 | from .._misc import resolve_rcond 21 | from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal 22 | from .._solution import RESULTS 23 | from .._solve import AbstractLinearSolver 24 | from .misc import ( 25 | pack_structures, 26 | PackedStructures, 27 | ravel_vector, 28 | transpose_packed_structures, 29 | unravel_solution, 30 | ) 31 | 32 | 33 | _DiagonalState: TypeAlias = tuple[Array | None, PackedStructures] 34 | 35 | 36 | class Diagonal(AbstractLinearSolver[_DiagonalState]): 37 | """Diagonal solver for linear systems. 38 | 39 | Requires that the operator be diagonal. Then $Ax = b$, with $A = diag[a]$, is 40 | solved simply by doing an elementwise division $x = b / a$. 41 | 42 | This solver can handle singular operators (i.e. diagonal entries with value 0). 
43 | """ 44 | 45 | well_posed: bool = False 46 | rcond: float | None = None 47 | 48 | def init( 49 | self, operator: AbstractLinearOperator, options: dict[str, Any] 50 | ) -> _DiagonalState: 51 | del options 52 | if operator.in_size() != operator.out_size(): 53 | raise ValueError( 54 | "`Diagonal` may only be used for linear solves with square matrices" 55 | ) 56 | if not is_diagonal(operator): 57 | raise ValueError( 58 | "`Diagonal` may only be used for linear solves with diagonal matrices" 59 | ) 60 | packed_structures = pack_structures(operator) 61 | if has_unit_diagonal(operator): 62 | return None, packed_structures 63 | else: 64 | return diagonal(operator), packed_structures 65 | 66 | def compute( 67 | self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any] 68 | ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: 69 | diag, packed_structures = state 70 | del state, options 71 | unit_diagonal = diag is None 72 | vector = ravel_vector(vector, packed_structures) 73 | if unit_diagonal: 74 | solution = vector 75 | else: 76 | if not self.well_posed: 77 | (size,) = diag.shape 78 | rcond = resolve_rcond(self.rcond, size, size, diag.dtype) 79 | abs_diag = jnp.abs(diag) 80 | diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) 81 | solution = vector / diag 82 | solution = unravel_solution(solution, packed_structures) 83 | return solution, RESULTS.successful, {} 84 | 85 | def transpose(self, state: _DiagonalState, options: dict[str, Any]): 86 | del options 87 | diag, packed_structures = state 88 | transposed_packed_structures = transpose_packed_structures(packed_structures) 89 | transpose_state = diag, transposed_packed_structures 90 | transpose_options = {} 91 | return transpose_state, transpose_options 92 | 93 | def conj(self, state: _DiagonalState, options: dict[str, Any]): 94 | del options 95 | diag, packed_structures = state 96 | if diag is None: 97 | conj_diag = None 98 | else: 99 | conj_diag = diag.conj() 100 | conj_options = 
{} 101 | conj_state = conj_diag, packed_structures 102 | return conj_state, conj_options 103 | 104 | def allow_dependent_columns(self, operator): 105 | return not self.well_posed 106 | 107 | def allow_dependent_rows(self, operator): 108 | return not self.well_posed 109 | 110 | 111 | Diagonal.__init__.__doc__ = """**Arguments**: 112 | 113 | - `well_posed`: if `False`, then singular operators are accepted, and the pseudoinverse 114 | solution is returned. If `True` then passing a singular operator will cause an error 115 | to be raised instead. 116 | - `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine 117 | precision times `N`, where `N` is the input (or output) size of the operator. 118 | Only used if `well_posed=False` 119 | """ 120 | -------------------------------------------------------------------------------- /lineax/_solver/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import math
import typing
import warnings
from typing import Any, NewType, TYPE_CHECKING

import equinox.internal as eqxi
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Array, PyTree, Shaped

from .._misc import strip_weak_dtype, structure_equal
from .._operator import (
    AbstractLinearOperator,
    IdentityLinearOperator,
)


def preconditioner_and_y0(
    operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any]
):
    """Extract the preconditioner and initial guess for an iterative solver.

    **Arguments:**

    - `operator`: the linear operator being solved against; only its
        `in_structure` is consulted here.
    - `vector`: the right-hand-side pytree; used as the template for the
        default (all-zeros) initial guess.
    - `options`: solver options; `options["preconditioner"]` and
        `options["y0"]` are read if present.

    **Returns:**

    A 2-tuple `(preconditioner, y0)`. The preconditioner defaults to the
    identity operator over `operator.in_structure()`, and `y0` defaults to
    zeros with the same structure as `vector`.

    **Raises:**

    A `ValueError` if the supplied preconditioner is not a linear operator
    whose input and output structures both equal the operator's input
    structure (i.e. it must be "square" over that structure), or if a
    supplied `y0` does not match `vector`'s structure.
    """
    structure = operator.in_structure()
    try:
        preconditioner = options["preconditioner"]
    except KeyError:
        # No preconditioner supplied: use the identity, i.e. no preconditioning.
        preconditioner = IdentityLinearOperator(structure)
    else:
        if not isinstance(preconditioner, AbstractLinearOperator):
            raise ValueError("The preconditioner must be a linear operator.")
        if not structure_equal(preconditioner.in_structure(), structure):
            # Fixed typo in the user-facing message: `in_strucure` -> `in_structure`.
            raise ValueError(
                "The preconditioner must have `in_structure` that matches the "
                "operator's `in_structure`."
            )
        if not structure_equal(preconditioner.out_structure(), structure):
            raise ValueError(
                "The preconditioner must have `out_structure` that matches the "
                "operator's `in_structure`."
            )
    try:
        y0 = options["y0"]
    except KeyError:
        # Default initial guess: the zero vector, matching `vector`'s structure.
        y0 = jtu.tree_map(jnp.zeros_like, vector)
    else:
        if not structure_equal(y0, vector):
            raise ValueError(
                "`y0` must have the same structure, shape, and dtype as `vector`"
            )
    return preconditioner, y0


# This seems to introduce some spurious failure at docgen time.
67 | if hasattr(typing, "GENERATING_DOCUMENTATION") and not TYPE_CHECKING: 68 | PackedStructures = lambda x: x 69 | else: 70 | PackedStructures = NewType("PackedStructures", eqxi.Static) 71 | 72 | 73 | def pack_structures(operator: AbstractLinearOperator) -> PackedStructures: 74 | structures = ( 75 | strip_weak_dtype(operator.out_structure()), 76 | strip_weak_dtype(operator.in_structure()), 77 | ) 78 | leaves, treedef = jtu.tree_flatten(structures) # handle nonhashable pytrees 79 | return PackedStructures(eqxi.Static((leaves, treedef))) 80 | 81 | 82 | def ravel_vector( 83 | pytree: PyTree[Array], packed_structures: PackedStructures 84 | ) -> Shaped[Array, " size"]: 85 | leaves, treedef = packed_structures.value 86 | out_structure, _ = jtu.tree_unflatten(treedef, leaves) 87 | # `is` in case `tree_equal` returns a Tracer. 88 | if not structure_equal(pytree, out_structure): 89 | raise ValueError("pytree does not match out_structure") 90 | # not using `ravel_pytree` as that doesn't come with guarantees about order 91 | leaves = jtu.tree_leaves(pytree) 92 | dtype = jnp.result_type(*leaves) 93 | return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves]) 94 | 95 | 96 | def unravel_solution( 97 | solution: Shaped[Array, " size"], packed_structures: PackedStructures 98 | ) -> PyTree[Array]: 99 | leaves, treedef = packed_structures.value 100 | _, in_structure = jtu.tree_unflatten(treedef, leaves) 101 | leaves, treedef = jtu.tree_flatten(in_structure) 102 | sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) 103 | split = jnp.split(solution, sizes) 104 | assert len(split) == len(leaves) 105 | with warnings.catch_warnings(): 106 | warnings.simplefilter("ignore") # ignore complex-to-real cast warning 107 | shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)] 108 | return jtu.tree_unflatten(treedef, shaped) 109 | 110 | 111 | def transpose_packed_structures( 112 | packed_structures: PackedStructures, 113 | ) -> PackedStructures: 114 | 
leaves, treedef = packed_structures.value 115 | out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) 116 | leaves, treedef = jtu.tree_flatten((in_structure, out_structure)) 117 | return PackedStructures(eqxi.Static((leaves, treedef))) 118 | -------------------------------------------------------------------------------- /docs/api/tags.md: -------------------------------------------------------------------------------- 1 | # Tags 2 | 3 | Lineax offers a way to "tag" linear operators as exhibiting certain properties, e.g. that they are positive semidefinite. 4 | 5 | If a linear operator is known to have a particular property, then this can be used to dispatch to a more efficient implementation, e.g. when solving a linear system. 6 | 7 | Generally speaking, tags are an *optional* tool that can be used to improve your run time and/or compile time, by statically telling the linear solvers what properties they may assume about your system. However, if misused then you may find that the wrong result is silently returned. 8 | 9 | In this way they are analogous to flags like `scipy.linalg.solve(..., assume_a="pos")`. 10 | 11 | !!! Example 12 | 13 | ```python 14 | # Some rank-2 JAX array. 15 | matrix = ... 16 | # Some rank-1 JAX array. 17 | vector = ... 18 | 19 | # Declare that this matrix is positive semidefinite. 20 | operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) 21 | 22 | # This tag is used to dispatch to a maximally-efficient linear solver. 23 | # In this case, a Cholesky solver is used: 24 | solution = lx.linear_solve(operator, vector) 25 | 26 | # Whether operators are tagged can be checked: 27 | assert lx.is_positive_semidefinite(operator) 28 | ``` 29 | 30 | !!! 
Warning 31 | 32 | Be careful, only the tag is actually checked, not the actual value of the matrix: 33 | ```python 34 | # Not a positive semidefinite matrix 35 | matrix = jax.numpy.array([[1, 2], [3, 4]]) 36 | 37 | operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) 38 | lx.is_positive_semidefinite(operator) # True 39 | lx.linear_solve(operator, vector) # Returns the wrong solution! 40 | ``` 41 | 42 | Of the built-in operators: [`lineax.MatrixLinearOperator`][], [`lineax.PyTreeLinearOperator`][], [`lineax.JacobianLinearOperator`][], [`lineax.FunctionLinearOperator`][], [`lineax.TaggedLinearOperator`][] directly support a `tags` argument that marks them as having certain characteristics: 43 | ```python 44 | operator = lx.MatrixLinearOperator(matrix, lx.symmetric_tag) 45 | ``` 46 | 47 | You can pass multiple tags at once: 48 | ```python 49 | operator = lx.MatrixLinearOperator(matrix, (lx.symmetric_tag, lx.unit_diagonal_tag)) 50 | ``` 51 | 52 | Other linear operators can be wrapped into a [`lineax.TaggedLinearOperator`][] if necessary: 53 | ```python 54 | operator = lx.MatrixLinearOperator(...) 55 | symmetric_operator = operator + operator.T 56 | lx.is_symmetric(symmetric_operator) # False 57 | symmetric_operator = lx.TaggedLinearOperator(symmetric_operator, lx.symmetric_tag) 58 | lx.is_symmetric(symmetric_operator) # True 59 | ``` 60 | 61 | Some linear operators are known to exhibit certain properties by construction, and need no additional tags: 62 | ```python 63 | lx.is_symmetric(lx.DiagonalLinearOperator(...)) # True 64 | lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True 65 | ``` 66 | 67 | ## List of available tags 68 | 69 | ::: lineax.symmetric_tag 70 | 71 | Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) 72 | 73 | --- 74 | 75 | ::: lineax.diagonal_tag 76 | 77 | Marks that an operator is diagonal. (As a matrix, it must have zeros in the off-diagonal entries.) 
78 | 79 | For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Diagonal`][] as the solver. 80 | 81 | --- 82 | 83 | ::: lineax.unit_diagonal_tag 84 | 85 | Marks that an operator has $1$ for every diagonal element. (As a matrix $A$, then it must have $A_{ii} = 1$ for all $i$.) Note that the whole matrix need not be diagonal. 86 | 87 | For example, [`lineax.Triangular`][] uses this to cheapen its solve. 88 | 89 | --- 90 | 91 | ::: lineax.lower_triangular_tag 92 | 93 | Marks that an operator is lower triangular. (As a matrix $A$, then it must have $A_{ij} = 0$ for all $i < j$.) Note that the diagonal may still have nonzero entries. 94 | 95 | For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver. 96 | 97 | --- 98 | 99 | ::: lineax.upper_triangular_tag 100 | 101 | Marks that an operator is upper triangular. (As a matrix $A$, then it must have $A_{ij} = 0$ for all $i > j$.) Note that the diagonal may still have nonzero entries. 102 | 103 | For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver. 104 | 105 | --- 106 | 107 | ::: lineax.positive_semidefinite_tag 108 | 109 | Marks that an operator is positive **semidefinite**. 110 | 111 | For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver. 112 | 113 | --- 114 | 115 | ::: lineax.negative_semidefinite_tag 116 | 117 | Marks that an operator is negative **semidefinite**. 118 | 119 | For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver. 
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_jvp(
    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
    """Check JVPs of `lx.linear_solve` (w.r.t. the operator, the vector, and
    both together) against JVPs of `jnp.linalg.lstsq` on the materialised
    matrix. Singular matrices are only tested against solvers that compute a
    pseudoinverse solution.
    """
    t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None

    if (make_matrix is construct_matrix) or pseudoinverse:
        matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)

        out_size, _ = matrix.shape
        vec = jr.normal(getkey(), (out_size,), dtype=dtype)
        t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

        if has_tag(tags, lx.unit_diagonal_tag):
            # For all the other tags, A + εB with A, B \in {matrices satisfying
            # the tag} still satisfies the tag itself.
            # This is the exception: the tangent must have a zero diagonal.
            # NOTE: JAX arrays are immutable, so `.at[...].set(...)` returns a
            # new array -- the result must be reassigned. (Previously the
            # returned array was silently discarded, so the diagonal was never
            # actually zeroed.) Using `out_size` rather than a hard-coded 3
            # keeps this correct for any matrix size.
            diag = jnp.arange(out_size)
            t_matrix = t_matrix.at[diag, diag].set(0)

        make_op = ft.partial(make_operator, getkey)
        operator, t_operator = eqx.filter_jvp(
            make_op, (matrix, tags), (t_matrix, t_tags)
        )

        if use_state:
            state = solver.init(operator, options={})
            linear_solve = ft.partial(lx.linear_solve, state=state)
        else:
            linear_solve = lx.linear_solve

        solve_vec_only = lambda v: linear_solve(operator, v, solver).value
        solve_op_only = lambda op: linear_solve(op, vec, solver).value
        solve_op_vec = lambda op, v: linear_solve(op, v, solver).value

        vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))
        op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))
        op_vec_out, t_op_vec_out = eqx.filter_jvp(
            solve_op_vec,
            (operator, vec),
            (t_operator, t_vec),
        )
        # Reference values: differentiate the least-squares solve directly.
        (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(
            lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore
            (matrix,),
            (t_matrix,),
        )
        (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(
            jnp.linalg.lstsq,
            (matrix, vec),
            (t_matrix, t_vec),  # pyright: ignore
        )

        # Work around JAX issue #14868: lstsq JVPs can produce NaNs, in which
        # case fall back to a finite-difference reference.
        if jnp.any(jnp.isnan(t_expected_op_out)):
            _, (t_expected_op_out, *_) = finite_difference_jvp(
                lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore
                (matrix,),
                (t_matrix,),
            )
        if jnp.any(jnp.isnan(t_expected_op_vec_out)):
            _, (t_expected_op_vec_out, *_) = finite_difference_jvp(
                jnp.linalg.lstsq,
                (matrix, vec),
                (t_matrix, t_vec),  # pyright: ignore
            )

        pinv_matrix = jnp.linalg.pinv(matrix)  # pyright: ignore
        expected_vec_out = pinv_matrix @ vec
        assert tree_allclose(vec_out, expected_vec_out)
        assert tree_allclose(op_out, expected_op_out)
        assert tree_allclose(op_vec_out, expected_op_vec_out)

        t_expected_vec_out = pinv_matrix @ t_vec
        # For singular matrices the solution is only determined up to the null
        # space, so compare after mapping back through the matrix.
        assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3)
        assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3)
        assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3)
def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Array, ""]:
    """Return the dot product of two pytrees of arrays that share the same
    pytree structure.

    The first argument is conjugated, so for complex inputs this is the
    Hermitian inner product `<tree1, tree2>`.
    """
    flat1, struct1 = jtu.tree_flatten(tree1)
    flat2, struct2 = jtu.tree_flatten(tree2)
    if struct1 != struct2:
        raise ValueError("trees must have the same structure")
    assert len(flat1) == len(flat2)
    # One flattened dot product per leaf pair; HIGHEST precision avoids any
    # accelerator-side downcasting (e.g. TF32) of the accumulation.
    partial_dots = [
        jnp.dot(
            jnp.conj(a).reshape(-1),
            jnp.reshape(b, -1),
            precision=jax.lax.Precision.HIGHEST,  # pyright: ignore
        )
        for a, b in zip(flat1, flat2)
    ]
    if len(partial_dots) == 0:
        # Empty pytree: no leaf to infer a dtype from, so pick the default.
        return jnp.array(0, default_floating_dtype())
    return ft.reduce(jnp.add, partial_dots)
def rms_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Compute the RMS (root-mean-squared) norm of a PyTree of arrays.

    Treating the input `x` as a flat vector `(x_1, ..., x_n)`, this computes
    `sqrt((Σ_i x_i^2)/n)`, i.e. the L2 norm scaled down by `sqrt(n)`.
    """
    flat = jtu.tree_leaves(x)
    total_size = sum(jnp.size(leaf) for leaf in flat)
    if total_size > 0:
        return two_norm(x) / math.sqrt(total_size)
    # Degenerate case: zero elements overall. Return a real zero, inferring
    # the dtype from the (possibly empty-shaped) leaves when there are any.
    if len(flat) == 0:
        dtype = default_floating_dtype()
    else:
        dtype = complex_to_real_dtype(jnp.result_type(*flat))
    return jnp.array(0.0, dtype)
128 | """ 129 | leaves = jtu.tree_leaves(x) 130 | leaf_maxes = [jnp.max(jnp.abs(xi)) for xi in leaves if jnp.size(xi) > 0] 131 | if len(leaf_maxes) == 0: 132 | if len(leaves) == 0: 133 | dtype = default_floating_dtype() 134 | else: 135 | dtype = complex_to_real_dtype(jnp.result_type(*leaves)) 136 | return jnp.array(0.0, dtype) 137 | else: 138 | out = ft.reduce(jnp.maximum, leaf_maxes) 139 | return _zero_grad_at_zero(out) 140 | 141 | 142 | @jax.custom_jvp 143 | def _zero_grad_at_zero(x): 144 | return x 145 | 146 | 147 | @_zero_grad_at_zero.defjvp 148 | def _zero_grad_at_zero_jvp(primals, tangents): 149 | (out,) = primals 150 | (t_out,) = tangents 151 | return out, jnp.where(out == 0, 0, t_out) 152 | -------------------------------------------------------------------------------- /tests/test_vmap_vmap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_vmap_vmap(
    getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
    """Check `lx.linear_solve` under two nested `eqx.filter_vmap`s, comparing
    against `jnp.linalg.solve`/`lstsq` applied to the materialised matrices.

    Each entry of `axes` is `(vmap2_op, vmap1_op, vmap2_vec, vmap1_vec)`: the
    batch axis (or `None` for unbatched) of the operator and the vector, for
    the outer (2) and inner (1) vmap respectively. Singular matrices are only
    tested against solvers that compute a pseudoinverse solution.
    """
    if (make_matrix is construct_matrix) or pseudoinverse:
        # combinations with nontrivial application across both vmaps
        axes = [
            (eqx.if_array(0), eqx.if_array(0), None, None),
            (None, None, 0, 0),
            (eqx.if_array(0), eqx.if_array(0), None, 0),
            (eqx.if_array(0), eqx.if_array(0), 0, 0),
            (None, eqx.if_array(0), 0, 0),
        ]

        for vmap2_op, vmap1_op, vmap2_vec, vmap1_vec in axes:
            # Each vmap level gets a batch of 10 iff it batches the operator.
            if vmap1_op is not None:
                axis_size1 = 10
                out_axis1 = eqx.if_array(0)
            else:
                axis_size1 = None
                out_axis1 = None

            if vmap2_op is not None:
                axis_size2 = 10
                out_axis2 = eqx.if_array(0)
            else:
                axis_size2 = None
                out_axis2 = None

            # Build the (possibly doubly-batched) matrix via the same nested
            # vmaps, so its leading batch dims match the operator's.
            (matrix,) = eqx.filter_vmap(
                eqx.filter_vmap(
                    lambda getkey, solver, tags: make_matrix(
                        getkey, solver, tags, dtype=dtype
                    ),
                    axis_size=axis_size1,
                    out_axes=out_axis1,
                ),
                axis_size=axis_size2,
                out_axes=out_axis2,
            )(getkey, solver, tags)

            # Peel off the batch dimensions to find the core output size.
            if vmap1_op is not None:
                if vmap2_op is not None:
                    _, _, out_size, _ = matrix.shape
                else:
                    _, out_size, _ = matrix.shape
            else:
                out_size, _ = matrix.shape

            # The vector carries one batch dim per vmap level that maps it.
            if vmap1_vec is None:
                vec = jr.normal(getkey(), (out_size,), dtype=dtype)
            elif (vmap1_vec is not None) and (vmap2_vec is None):
                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
            else:
                vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype)

            make_op = ft.partial(make_operator, getkey)
            operator = eqx.filter_vmap(
                eqx.filter_vmap(
                    make_op,
                    in_axes=vmap1_op,
                    out_axes=out_axis1,
                ),
                in_axes=vmap2_op,
                out_axes=out_axis2,
            )(matrix, tags)

            if use_state:

                def linear_solve(operator, vector):
                    # Exercise the pre-initialised-state code path.
                    state = solver.init(operator, options={})
                    return lx.linear_solve(operator, vector, state=state, solver=solver)

            else:

                def linear_solve(operator, vector):
                    return lx.linear_solve(operator, vector, solver)

            # Materialise the batched operator back into dense matrices for
            # the reference solve below.
            as_matrix_vmapped = eqx.filter_vmap(
                eqx.filter_vmap(
                    lambda x: x.as_matrix(),
                    in_axes=vmap1_op,
                    out_axes=None if vmap1_op is None else 0,
                ),
                in_axes=vmap2_op,
                out_axes=None if vmap2_op is None else 0,
            )(operator)

            vmap1_axes = (vmap1_op, vmap1_vec)
            vmap2_axes = (vmap2_op, vmap2_vec)

            result = eqx.filter_vmap(
                eqx.filter_vmap(linear_solve, in_axes=vmap1_axes), in_axes=vmap2_axes
            )(operator, vec).value

            solve_with = lambda x: eqx.filter_vmap(
                eqx.filter_vmap(x, in_axes=vmap1_axes), in_axes=vmap2_axes
            )(as_matrix_vmapped, vec)

            if make_matrix is construct_singular_matrix:
                true_result, _, _, _ = solve_with(jnp.linalg.lstsq)  # pyright: ignore
            else:
                true_result = solve_with(jnp.linalg.solve)  # pyright: ignore
            assert tree_allclose(result, true_result, rtol=1e-3)
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_vmap_jvp(
    getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype
):
    """Check JVPs of `lx.linear_solve` composed with `eqx.filter_vmap`, in
    both composition orders (vmap-of-jvp and jvp-of-vmap), against the same
    compositions applied to `jnp.linalg.solve`/`lstsq`.

    `mode` selects which of the operator and/or the vector carries the batch
    axis (and is differentiated). Singular matrices are only tested against
    solvers that compute a pseudoinverse solution.
    """
    if (make_matrix is construct_matrix) or pseudoinverse:
        t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
        if pseudoinverse:
            jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore
        else:
            jnp_solve1 = jnp.linalg.solve  # pyright: ignore
        if use_state:

            def linear_solve1(operator, vector):
                state = solver.init(operator, options={})
                # Stop gradients flowing through the cached solver state, so
                # differentiation only sees the operator/vector inputs.
                state_dynamic, state_static = eqx.partition(state, eqx.is_inexact_array)
                state_dynamic = lax.stop_gradient(state_dynamic)
                state = eqx.combine(state_dynamic, state_static)

                return lx.linear_solve(operator, vector, state=state, solver=solver)

        else:
            linear_solve1 = ft.partial(lx.linear_solve, solver=solver)

        for mode in ("vec", "op", "op_vec"):
            # Only batch the operator construction when the operator is in
            # the vmapped arguments for this mode.
            if "op" in mode:
                axis_size = 10
                out_axes = eqx.if_array(0)
            else:
                axis_size = None
                out_axes = None

            def _make():
                # Primal and tangent matrices, plus the corresponding
                # (operator, tangent-operator) pair built via JVP.
                matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)
                make_op = ft.partial(make_operator, getkey)
                operator, t_operator = eqx.filter_jvp(
                    make_op, (matrix, tags), (t_matrix, t_tags)
                )
                return matrix, t_matrix, operator, t_operator

            matrix, t_matrix, operator, t_operator = eqx.filter_vmap(
                _make, axis_size=axis_size, out_axes=out_axes
            )()

            if "op" in mode:
                _, out_size, _ = matrix.shape
            else:
                out_size, _ = matrix.shape

            if "vec" in mode:
                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
                t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
            else:
                vec = jr.normal(getkey(), (out_size,), dtype=dtype)
                t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

            # Close over whichever argument is *not* being vmapped/differentiated.
            if mode == "op":
                linear_solve2 = lambda op: linear_solve1(op, vector=vec)
                jnp_solve2 = lambda mat: jnp_solve1(mat, vec)
            elif mode == "vec":
                linear_solve2 = lambda vector: linear_solve1(operator, vector)
                jnp_solve2 = lambda vector: jnp_solve1(matrix, vector)
            elif mode == "op_vec":
                linear_solve2 = linear_solve1
                jnp_solve2 = jnp_solve1
            else:
                assert False
            # Try both composition orders: jvp inside vmap, and vmap inside jvp.
            for jvp_first in (True, False):
                if jvp_first:
                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve2)
                else:
                    linear_solve3 = linear_solve2
                linear_solve3 = eqx.filter_vmap(linear_solve3)
                if not jvp_first:
                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3)
                linear_solve3 = eqx.filter_jit(linear_solve3)
                # The reference always uses jvp-inside-vmap.
                jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve2)
                jnp_solve3 = eqx.filter_vmap(jnp_solve3)
                jnp_solve3 = eqx.filter_jit(jnp_solve3)
                if mode == "op":
                    out, t_out = linear_solve3((operator,), (t_operator,))
                    true_out, true_t_out = jnp_solve3((matrix,), (t_matrix,))
                elif mode == "vec":
                    out, t_out = linear_solve3((vec,), (t_vec,))
                    true_out, true_t_out = jnp_solve3((vec,), (t_vec,))
                elif mode == "op_vec":
                    out, t_out = linear_solve3((operator, vec), (t_operator, t_vec))
                    true_out, true_t_out = jnp_solve3((matrix, vec), (t_matrix, t_vec))
                else:
                    assert False
                assert tree_allclose(out.value, true_out, atol=1e-4)
                assert tree_allclose(t_out.value, true_t_out, atol=1e-4)
def jax_tridiagonal(matrix, vector):
    """Solve the tridiagonal system `matrix @ x = vector` using JAX's
    dedicated `tridiagonal_solve` primitive, extracting the three bands
    from the dense matrix."""
    # The primitive expects all three bands at full length: the subdiagonal
    # padded with a leading zero, the superdiagonal with a trailing zero.
    lower = jnp.append(0.0, matrix.diagonal(-1))
    main = matrix.diagonal(0)
    upper = jnp.append(matrix.diagonal(1), 0.0)
    # It also expects a 2D right-hand side, so add (and then drop) a
    # trailing unit dimension.
    solution = jax.lax.linalg.tridiagonal_solve(lower, main, upper, vector[:, None])
    return solution[:, 0]
def test_solvers(vmap_size, mat_size):
    """Benchmark every Lineax solver against its closest JAX equivalent.

    For each entry of `named_solvers`, builds one `mat_size`-sized problem
    (batched over `vmap_size` problems when `vmap_size > 1`), solves it with
    both libraries, and prints the wall-clock time — or the attained error,
    if the solver failed to reach the true solution.
    """
    for lx_name, jax_name, _lx_solver, jax_solver, tags in named_solvers:
        lx_solver = ft.partial(base_wrapper, solver=_lx_solver)
        if vmap_size == 1:
            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):
                # Iterative solvers get a well-conditioned problem so they
                # have a realistic chance of converging.
                matrix, true_x, b = create_easy_iterative_problem(mat_size, tags)
            else:
                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)
        else:
            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):
                matrix, true_x, b = eqx.filter_vmap(
                    create_easy_iterative_problem,
                    axis_size=vmap_size,
                    out_axes=eqx.if_array(0),
                )(mat_size, tags)
            else:
                # Fix: this branch previously also called `create_problem`
                # once un-vmapped and immediately overwrote the result with
                # the vmapped call below; the redundant call has been removed.
                _create_problem = ft.partial(create_problem, size=mat_size)
                matrix, true_x, b = eqx.filter_vmap(
                    _create_problem, axis_size=vmap_size, out_axes=eqx.if_array(0)
                )(lx_solver, tags)

            lx_solver = jax.vmap(lx_solver)
            if jax_solver is not None:
                jax_solver = jax.vmap(jax_solver)

        lx_solver = jax.jit(lx_solver)
        bench_lx = ft.partial(lx_solver, matrix, b)

        if vmap_size == 1:
            batch_msg = "problem"
        else:
            batch_msg = f"batch of {vmap_size} problems"

        # First call compiles (and checks correctness); the timed call below
        # then reuses the jitted computation.
        lx_soln = bench_lx()
        if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):
            lx_solve_time = timeit.timeit(bench_lx, number=1)

            print(
                f"Lineax's {lx_name} solved {batch_msg} of "
                f"size {mat_size} in {lx_solve_time} seconds."
            )
        else:
            fail_time = timeit.timeit(bench_lx, number=1)
            err = jnp.abs(lx_soln - true_x).max()
            print(
                f"Lineax's {lx_name} failed to solve {batch_msg} of "
                f"size {mat_size} with error {err} in {fail_time} seconds"
            )
        if jax_solver is None:
            print("JAX has no equivalent solver. \n")

        else:
            jax_solver = jax.jit(jax_solver)
            bench_jax = ft.partial(jax_solver, matrix, b)
            jax_soln = bench_jax()
            if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):
                jax_solve_time = timeit.timeit(bench_jax, number=1)
                print(
                    f"JAX's {jax_name} solved {batch_msg} of "
                    f"size {mat_size} in {jax_solve_time} seconds. \n"
                )
            else:
                fail_time = timeit.timeit(bench_jax, number=1)
                err = jnp.abs(jax_soln - true_x).max()
                print(
                    f"JAX's {jax_name} failed to solve {batch_msg} of "
                    f"size {mat_size} with error {err} in {fail_time} seconds. \n"
                )
def test_nontrivial_diagonal_operator(): 52 | x = (8.0, jnp.array([1, 2, 3]), {"a": jnp.array([4, 5]), "b": 6}) 53 | y = (4.0, jnp.array([7, 8, 9]), {"a": jnp.array([2, 10]), "b": 12}) 54 | operator = lx.DiagonalLinearOperator(x) 55 | out = lx.linear_solve(operator, y).value 56 | true_out = ( 57 | jnp.array(0.5), 58 | jnp.array([7.0, 4.0, 3.0]), 59 | {"a": jnp.array([0.5, 2.0]), "b": jnp.array(2.0)}, 60 | ) 61 | assert tree_allclose(out, true_out) 62 | 63 | 64 | @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) 65 | def test_mixed_dtypes(solver): 66 | f32 = lambda x: jnp.array(x, dtype=jnp.float32) 67 | f64 = lambda x: jnp.array(x, dtype=jnp.float64) 68 | x = [[f32(1), f64(5)], [f32(-2), f64(-2)]] 69 | y = [f64(3), f64(4)] 70 | struct = jax.eval_shape(lambda: y) 71 | operator = lx.PyTreeLinearOperator(x, struct) 72 | out = lx.linear_solve(operator, y, solver=solver).value 73 | true_out = [f32(-3.25), f64(1.25)] 74 | assert tree_allclose(out, true_out) 75 | 76 | 77 | @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) 78 | def test_mixed_dtypes_complex(solver): 79 | c64 = lambda x: jnp.array(x, dtype=jnp.complex64) 80 | c128 = lambda x: jnp.array(x, dtype=jnp.complex128) 81 | x = [[c64(1), c128(5.0j)], [c64(2.0j), c128(-2)]] 82 | y = [c128(3), c128(4)] 83 | struct = jax.eval_shape(lambda: y) 84 | operator = lx.PyTreeLinearOperator(x, struct) 85 | out = lx.linear_solve(operator, y, solver=solver).value 86 | true_out = [c64(-0.75 - 2.5j), c128(0.5 - 0.75j)] 87 | assert tree_allclose(out, true_out) 88 | 89 | 90 | @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) 91 | def test_mixed_dtypes_complex_real(solver): 92 | f64 = lambda x: jnp.array(x, dtype=jnp.float64) 93 | c128 = lambda x: jnp.array(x, dtype=jnp.complex128) 94 | x = [[f64(1), c128(-5.0j)], [f64(2.0), c128(-2j)]] 95 | y = [c128(3), c128(4)] 96 | struct = jax.eval_shape(lambda: y) 97 | operator = lx.PyTreeLinearOperator(x, struct) 98 | out = lx.linear_solve(operator, 
y, solver=solver).value 99 | true_out = [f64(1.75), c128(0.25j)] 100 | assert tree_allclose(out, true_out) 101 | 102 | 103 | def test_mixed_dtypes_triangular(): 104 | f32 = lambda x: jnp.array(x, dtype=jnp.float32) 105 | f64 = lambda x: jnp.array(x, dtype=jnp.float64) 106 | x = [[f32(1), f64(0)], [f32(-2), f64(-2)]] 107 | y = [f64(3), f64(4)] 108 | struct = jax.eval_shape(lambda: y) 109 | operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) 110 | out = lx.linear_solve(operator, y, solver=lx.Triangular()).value 111 | true_out = [f32(3), f64(-5)] 112 | assert tree_allclose(out, true_out) 113 | 114 | 115 | def test_mixed_dtypes_complex_triangular(): 116 | c64 = lambda x: jnp.array(x, dtype=jnp.complex64) 117 | c128 = lambda x: jnp.array(x, dtype=jnp.complex128) 118 | x = [[c64(1), c128(0)], [c64(2.0j), c128(-2)]] 119 | y = [c128(3), c128(4)] 120 | struct = jax.eval_shape(lambda: y) 121 | operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) 122 | out = lx.linear_solve(operator, y, solver=lx.Triangular()).value 123 | true_out = [c64(3), c128(-2 + 3.0j)] 124 | assert tree_allclose(out, true_out) 125 | 126 | 127 | def test_mixed_dtypes_complex_real_triangular(): 128 | f64 = lambda x: jnp.array(x, dtype=jnp.float64) 129 | c128 = lambda x: jnp.array(x, dtype=jnp.complex128) 130 | x = [[f64(1), c128(0)], [f64(2.0), c128(2j)]] 131 | y = [c128(3), c128(4)] 132 | struct = jax.eval_shape(lambda: y) 133 | operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) 134 | out = lx.linear_solve(operator, y, solver=lx.Triangular()).value 135 | true_out = [f64(3), c128(1j)] 136 | assert tree_allclose(out, true_out) 137 | 138 | 139 | def test_ad_closure_function_linear_operator(getkey): 140 | def f(x, z): 141 | def fn(y): 142 | return x * y 143 | 144 | op = lx.FunctionLinearOperator(fn, jax.eval_shape(lambda: z)) 145 | sol = lx.linear_solve(op, z).value 146 | return jnp.sum(sol), sol 147 | 148 | x = jr.normal(getkey(), (3,)) 149 | x = 
jnp.where(jnp.abs(x) < 1e-6, 0.7, x) 150 | z = jr.normal(getkey(), (3,)) 151 | grad, sol = jax.grad(f, has_aux=True)(x, z) 152 | assert tree_allclose(grad, -z / (x**2)) 153 | assert tree_allclose(sol, z / x) 154 | 155 | 156 | def test_grad_vmap_symbolic_cotangent(): 157 | def f(x): 158 | return x[0], x[1] 159 | 160 | @jax.vmap 161 | def to_vmap(x): 162 | op = lx.FunctionLinearOperator(f, jax.eval_shape(lambda: x)) 163 | sol = lx.linear_solve(op, x) 164 | return sol.value[0] 165 | 166 | @jax.grad 167 | def to_grad(x): 168 | return jnp.sum(to_vmap(x)) 169 | 170 | x = (jnp.arange(3.0), jnp.arange(3.0)) 171 | to_grad(x) 172 | 173 | 174 | @pytest.mark.parametrize( 175 | "solver", 176 | ( 177 | lx.CG(0.0, 0.0, max_steps=2), 178 | lx.NormalCG(0.0, 0.0, max_steps=2), 179 | lx.BiCGStab(0.0, 0.0, max_steps=2), 180 | lx.GMRES(0.0, 0.0, max_steps=2), 181 | lx.LSMR(0.0, 0.0, max_steps=2), 182 | ), 183 | ) 184 | def test_iterative_solver_max_steps_only(solver): 185 | """Iterative solvers should work with max_steps only (no Equinox errors).""" 186 | SIZE = 100 187 | 188 | poisson_matrix = construct_poisson_matrix(SIZE) 189 | poisson_operator = lx.MatrixLinearOperator( 190 | poisson_matrix, tags=(lx.negative_semidefinite_tag, lx.symmetric_tag) 191 | ) 192 | rhs = jax.random.normal(jax.random.key(0), (SIZE,)) 193 | 194 | lx.linear_solve(poisson_operator, rhs, solver) 195 | 196 | 197 | def test_nonfinite_input(): 198 | operator = lx.DiagonalLinearOperator((1.0, 1.0)) 199 | vector = (1.0, jnp.inf) 200 | sol = lx.linear_solve(operator, vector, throw=False) 201 | assert sol.result == lx.RESULTS.nonfinite_input 202 | 203 | vector = (1.0, jnp.nan) 204 | sol = lx.linear_solve(operator, vector, throw=False) 205 | assert sol.result == lx.RESULTS.nonfinite_input 206 | 207 | vector = (jnp.nan, jnp.inf) 208 | sol = lx.linear_solve(operator, vector, throw=False) 209 | assert sol.result == lx.RESULTS.nonfinite_input 210 | 
-------------------------------------------------------------------------------- /docs/examples/structured_matrices.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e2573d62-a505-4998-8796-b0f1bc889433", 6 | "metadata": {}, 7 | "source": [ 8 | "# Structured matrices\n", 9 | "\n", 10 | "Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.\n", 11 | "\n", 12 | "Typically, that means using a particular operator type:" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "8e275652-dd80-4a9a-b3ac-b96dc16d3334", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "[[ 4. 2. 0. 0. ]\n", 28 | " [ 1. -0.5 -1. 0. ]\n", 29 | " [ 0. 3. 7. -5. ]\n", 30 | " [ 0. 0. -0.7 1. ]]\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "import jax.numpy as jnp\n", 36 | "import jax.random as jr\n", 37 | "import lineax as lx\n", 38 | "\n", 39 | "\n", 40 | "diag = jnp.array([4.0, -0.5, 7.0, 1.0])\n", 41 | "lower_diag = jnp.array([1.0, 3.0, -0.7])\n", 42 | "upper_diag = jnp.array([2.0, -1.0, -5.0])\n", 43 | "\n", 44 | "operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n", 45 | "print(operator.as_matrix())" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "ba23ecc4-bdea-4293-a138-ce77bc83082c", 52 | "metadata": { 53 | "tags": [] 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "vector = jnp.array([1.0, -0.5, 2.0, 0.8])\n", 58 | "# Will automatically dispatch to a tridiagonal solver.\n", 59 | "solution = lx.linear_solve(operator, vector)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "cd58979d-b619-4ddf-9a17-12e8babae3e8", 65 | "metadata": {}, 66 | "source": [ 67 | "If you're uncertain which solver is being dispatched to, then 
you can check:" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "6984f62f-75fc-4d6e-ab42-fdade471be5b", 74 | "metadata": { 75 | "tags": [] 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Tridiagonal()\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "default_solver = lx.AutoLinearSolver(well_posed=True)\n", 88 | "print(default_solver.select_solver(operator))" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "164a5bd5-5d48-4b28-bcc5-d276ab49c780", 94 | "metadata": {}, 95 | "source": [ 96 | "If you want to enforce that a particular solver is used, then it can be passed manually:" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "id": "102ada9a-0533-40cf-9bad-02918fffb6b1", 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "1b4ebf09-e138-43f6-973c-c9f005ffb55e", 114 | "metadata": {}, 115 | "source": [ 116 | "Trying to use a solver with an unsupported operator will raise an error:" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "id": "d8f5bf66-53cd-4e81-a8d7-a19e86307ad3", 123 | "metadata": { 124 | "tags": [] 125 | }, 126 | "outputs": [ 127 | { 128 | "ename": "ValueError", 129 | "evalue": "`Tridiagonal` may only be used for linear solves with tridiagonal matrices", 130 | "output_type": "error", 131 | "traceback": [ 132 | "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n", 138 | "not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)\n", 139 | "solution = lx.linear_solve(not_tridiagonal_operator, vector, 
solver=lx.Tridiagonal())" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "03c4c531-58fa-4b56-8b0a-6e611c8c5912", 145 | "metadata": {}, 146 | "source": [ 147 | "---\n", 148 | "\n", 149 | "Besides using a particular operator type, the structure of the matrix can also be expressed by [adding particular tags](../api/tags.md). These tags act as a manual override mechanism, and the values of the matrix are not checked.\n", 150 | "\n", 151 | "For example, let's construct a positive definite matrix:" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "id": "b5add874-7a2c-4000-84c3-8c94a121a831", 158 | "metadata": { 159 | "tags": [] 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n", 164 | "operator = lx.MatrixLinearOperator(matrix.T @ matrix)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "5459b2d6-ddb9-4a37-bb51-3f5c204bab0d", 170 | "metadata": {}, 171 | "source": [ 172 | "Unfortunately, Lineax has no way of knowing that this matrix is positive definite. 
It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "id": "78400416-e774-4f74-a530-e368db84af0e", 179 | "metadata": { 180 | "tags": [] 181 | }, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "LU()\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "solution = lx.linear_solve(operator, vector)\n", 193 | "print(default_solver.select_solver(operator))" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "e108bdff-1cf1-4751-8c9d-3baae82ca9a7", 199 | "metadata": {}, 200 | "source": [ 201 | "But if we add a tag:" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 9, 207 | "id": "f6dc2966-1dfa-4a3c-be6a-974926695547", 208 | "metadata": { 209 | "tags": [] 210 | }, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "Cholesky()\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\n", 222 | "solution2 = lx.linear_solve(operator, vector)\n", 223 | "print(default_solver.select_solver(operator))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "7274d17b-a7d3-45bf-9042-785ac25e2d74", 229 | "metadata": {}, 230 | "source": [ 231 | "Then a more efficient solver can be selected. 
We can check that the solutions returned from these two approaches are equal:" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 10, 237 | "id": "fdcde152-9ac1-4532-a174-3fc39d83d289", 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "[ 1.400575 -0.41042092 0.5313305 0.28422552]\n", 247 | "[ 1.4005749 -0.41042086 0.53133047 0.2842255 ]\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "print(solution.value)\n", 253 | "print(solution2.value)" 254 | ] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "py39", 260 | "language": "python", 261 | "name": "py39" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.9.16" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 5 278 | } 279 | -------------------------------------------------------------------------------- /lineax/_solver/bicgstab.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from collections.abc import Callable 16 | from typing import Any, TypeAlias 17 | 18 | import jax 19 | import jax.lax as lax 20 | import jax.numpy as jnp 21 | import jax.tree_util as jtu 22 | from equinox.internal import ω 23 | from jaxtyping import Array, PyTree 24 | 25 | from .._norm import max_norm, tree_dot 26 | from .._operator import AbstractLinearOperator, conj 27 | from .._solution import RESULTS 28 | from .._solve import AbstractLinearSolver 29 | from .misc import preconditioner_and_y0 30 | 31 | 32 | _BiCGStabState: TypeAlias = AbstractLinearOperator 33 | 34 | 35 | class BiCGStab(AbstractLinearSolver[_BiCGStabState]): 36 | """Biconjugate gradient stabilised method for linear systems. 37 | 38 | The operator should be square. 39 | 40 | Equivalent to `jax.scipy.sparse.linalg.bicgstab`. 41 | 42 | This supports the following `options` (as passed to 43 | `lx.linear_solve(..., options=...)`). 44 | 45 | - `preconditioner`: A [`lineax.AbstractLinearOperator`][] 46 | to be used as a preconditioner. Defaults to 47 | [`lineax.IdentityLinearOperator`][]. This method uses right preconditioning. 48 | - `y0`: The initial estimate of the solution to the linear system. Defaults to all 49 | zeros. 50 | """ 51 | 52 | rtol: float 53 | atol: float 54 | norm: Callable = max_norm 55 | max_steps: int | None = None 56 | 57 | def __check_init__(self): 58 | if isinstance(self.rtol, (int, float)) and self.rtol < 0: 59 | raise ValueError("Tolerances must be non-negative.") 60 | if isinstance(self.atol, (int, float)) and self.atol < 0: 61 | raise ValueError("Tolerances must be non-negative.") 62 | 63 | if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)): 64 | if self.atol == 0 and self.rtol == 0 and self.max_steps is None: 65 | raise ValueError( 66 | "Must specify `rtol`, `atol`, or `max_steps` (or some combination " 67 | "of all three)." 
68 | ) 69 | 70 | def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): 71 | if operator.in_structure() != operator.out_structure(): 72 | raise ValueError( 73 | "`BiCGstab(..., normal=False)` may only be used for linear solves with " 74 | "square matrices." 75 | ) 76 | return operator 77 | 78 | def compute( 79 | self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any] 80 | ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: 81 | operator = state 82 | preconditioner, y0 = preconditioner_and_y0(operator, vector, options) 83 | leaves, _ = jtu.tree_flatten(vector) 84 | if self.max_steps is None: 85 | size = sum(leaf.size for leaf in leaves) 86 | max_steps = 10 * size 87 | else: 88 | max_steps = self.max_steps 89 | has_scale = not ( 90 | isinstance(self.atol, (int, float)) 91 | and isinstance(self.rtol, (int, float)) 92 | and self.atol == 0 93 | and self.rtol == 0 94 | ) 95 | if has_scale: 96 | b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω 97 | 98 | # This implementation is the same a jax.scipy.sparse.linalg.bicgstab 99 | # but with AbstractLinearOperator. 100 | # We use the notation found on the wikipedia except with y instead of x: 101 | # https://en.wikipedia.org/wiki/ 102 | # Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB 103 | # preconditioner in this case is K2^(-1) (i.e., right preconditioning) 104 | 105 | r0 = (vector**ω - operator.mv(y0) ** ω).ω 106 | 107 | def breakdown_occurred(omega, alpha, rho): 108 | # Empirically, the tolerance checks for breakdown are very tight. 109 | # These specific tolerances are heuristic. 110 | if jax.config.jax_enable_x64: # pyright: ignore 111 | return (omega == 0.0) | (alpha == 0.0) | (rho == 0.0) 112 | else: 113 | return (omega < 1e-16) | (alpha < 1e-16) | (rho < 1e-16) 114 | 115 | def not_converged(r, diff, y): 116 | # The primary tolerance check. 117 | # Given Ay=b, then we have to be doing better than `scale` in both 118 | # the `y` and the `b` spaces. 
119 | if has_scale: 120 | with jax.numpy_dtype_promotion("standard"): 121 | y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω 122 | norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore 123 | norm2 = self.norm((diff**ω / y_scale**ω).ω) 124 | return (norm1 > 1) | (norm2 > 1) 125 | else: 126 | return True 127 | 128 | def cond_fun(carry): 129 | y, r, alpha, omega, rho, _, _, diff, step = carry 130 | out = jnp.invert(breakdown_occurred(omega, alpha, rho)) 131 | out = out & not_converged(r, diff, y) 132 | out = out & (step < max_steps) 133 | return out 134 | 135 | def body_fun(carry): 136 | y, r, alpha, omega, rho, p, v, diff, step = carry 137 | 138 | rho_new = tree_dot(r0, r) 139 | beta = (rho_new / rho) * (alpha / omega) 140 | p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω 141 | 142 | # TODO(raderj): reduce this to a single operator.mv call 143 | # by using the scan trick. 144 | x = preconditioner.mv(p_new) 145 | v_new = operator.mv(x) 146 | 147 | alpha_new = rho_new / tree_dot(r0, v_new) 148 | s = (r**ω - alpha_new * v_new**ω).ω 149 | 150 | z = preconditioner.mv(s) 151 | t = operator.mv(z) 152 | 153 | omega_new = tree_dot(s, t) / tree_dot(t, t) 154 | 155 | diff = (alpha_new * x**ω + omega_new * z**ω).ω 156 | y_new = (y**ω + diff**ω).ω 157 | r_new = (s**ω - omega_new * t**ω).ω 158 | return ( 159 | y_new, 160 | r_new, 161 | alpha_new, 162 | omega_new, 163 | rho_new, 164 | p_new, 165 | v_new, 166 | diff, 167 | step + 1, 168 | ) 169 | 170 | p0 = v0 = jtu.tree_map(jnp.zeros_like, vector) 171 | alpha = omega = rho = jnp.array(1.0) 172 | 173 | init_carry = ( 174 | y0, 175 | r0, 176 | alpha, 177 | omega, 178 | rho, 179 | p0, 180 | v0, 181 | ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, 182 | 0, 183 | ) 184 | solution, residual, alpha, omega, rho, _, _, diff, num_steps = lax.while_loop( 185 | cond_fun, body_fun, init_carry 186 | ) 187 | 188 | if self.max_steps is None: 189 | result = RESULTS.where( 190 | (num_steps == max_steps), RESULTS.singular, 
RESULTS.successful 191 | ) 192 | else: 193 | result = RESULTS.where( 194 | (num_steps == self.max_steps), 195 | RESULTS.max_steps_reached if has_scale else RESULTS.successful, 196 | RESULTS.successful, 197 | ) 198 | # breakdown is only an issue if we did not converge 199 | breakdown = breakdown_occurred(omega, alpha, rho) & not_converged( 200 | residual, diff, solution 201 | ) 202 | result = RESULTS.where(breakdown, RESULTS.breakdown, result) 203 | 204 | stats = {"num_steps": num_steps, "max_steps": self.max_steps} 205 | return solution, result, stats 206 | 207 | def transpose(self, state: _BiCGStabState, options: dict[str, Any]): 208 | transpose_options = {} 209 | if "preconditioner" in options: 210 | transpose_options["preconditioner"] = options["preconditioner"].transpose() 211 | operator = state 212 | return operator.transpose(), transpose_options 213 | 214 | def conj(self, state: _BiCGStabState, options: dict[str, Any]): 215 | conj_options = {} 216 | if "preconditioner" in options: 217 | conj_options["preconditioner"] = conj(options["preconditioner"]) 218 | operator = state 219 | return conj(operator), conj_options 220 | 221 | def allow_dependent_columns(self, operator): 222 | return False 223 | 224 | def allow_dependent_rows(self, operator): 225 | return False 226 | 227 | 228 | BiCGStab.__init__.__doc__ = r"""**Arguments:** 229 | 230 | - `rtol`: Relative tolerance for terminating solve. 231 | - `atol`: Absolute tolerance for terminating solve. 232 | - `norm`: The norm to use when computing whether the error falls within the tolerance. 233 | Defaults to the max norm. 234 | - `max_steps`: The maximum number of iterations to run the solver for. If more steps 235 | than this are required, then the solve is halted with a failure. 
236 | """ 237 | -------------------------------------------------------------------------------- /tests/test_singular.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import contextlib 16 | import functools as ft 17 | 18 | import equinox as eqx 19 | import jax 20 | import jax.numpy as jnp 21 | import jax.random as jr 22 | import lineax as lx 23 | import pytest 24 | 25 | from .helpers import ( 26 | construct_singular_matrix, 27 | finite_difference_jvp, 28 | make_jac_operator, 29 | make_matrix_operator, 30 | ops, 31 | params, 32 | tol, 33 | tree_allclose, 34 | ) 35 | 36 | 37 | @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True)) 38 | @pytest.mark.parametrize("ops", ops) 39 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 40 | def test_small_singular(make_operator, solver, tags, ops, getkey, dtype): 41 | if jax.config.jax_enable_x64: # pyright: ignore 42 | tol = 1e-10 43 | else: 44 | tol = 1e-4 45 | (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype) 46 | operator = make_operator(getkey, matrix, tags) 47 | operator, matrix = ops(operator, matrix) 48 | assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) 49 | out_size, in_size = matrix.shape 50 | true_x = jr.normal(getkey(), (in_size,), dtype=dtype) 51 | b = matrix @ true_x 52 
| x = lx.linear_solve(operator, b, solver=solver, throw=False).value 53 | jax_x, *_ = jnp.linalg.lstsq(matrix, b) # pyright: ignore 54 | assert tree_allclose(x, jax_x, atol=tol, rtol=tol) 55 | 56 | 57 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 58 | def test_bicgstab_breakdown(getkey, dtype): 59 | if jax.config.jax_enable_x64: # pyright: ignore 60 | tol = 1e-10 61 | else: 62 | tol = 1e-4 63 | solver = lx.GMRES(atol=tol, rtol=tol, restart=2) 64 | 65 | matrix = jr.normal(jr.PRNGKey(0), (100, 100), dtype=dtype) 66 | true_x = jr.normal(jr.PRNGKey(0), (100,), dtype=dtype) 67 | b = matrix @ true_x 68 | operator = lx.MatrixLinearOperator(matrix) 69 | 70 | # result != 0 implies lineax reported failure 71 | lx_soln = lx.linear_solve(operator, b, solver, throw=False) 72 | 73 | assert jnp.all(lx_soln.result != lx.RESULTS.successful) 74 | 75 | 76 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 77 | def test_gmres_stagnation_or_breakdown(getkey, dtype): 78 | if jax.config.jax_enable_x64: # pyright: ignore 79 | tol = 1e-10 80 | else: 81 | tol = 1e-4 82 | solver = lx.GMRES(atol=tol, rtol=tol, restart=2) 83 | 84 | matrix = jnp.array( 85 | [ 86 | [0.15892892, 0.05884365, -0.60427412, 0.1891916], 87 | [-1.5484863, 0.93608822, 1.94888868, 1.37069667], 88 | [0.62687318, -0.13996738, -0.6824359, 0.30975754], 89 | [-0.67428635, 1.52372255, -0.88277754, 0.69633816], 90 | ], 91 | dtype=dtype, 92 | ) 93 | true_x = jnp.array([0.51383273, 1.72983427, -0.43251078, -1.11764668], dtype=dtype) 94 | b = matrix @ true_x 95 | operator = lx.MatrixLinearOperator(matrix) 96 | 97 | # result != 0 implies lineax reported failure 98 | lx_soln = lx.linear_solve(operator, b, solver, throw=False) 99 | 100 | assert jnp.all(lx_soln.result != lx.RESULTS.successful) 101 | 102 | 103 | @pytest.mark.parametrize( 104 | "solver", 105 | ( 106 | lx.AutoLinearSolver(well_posed=None), 107 | lx.QR(), 108 | lx.SVD(), 109 | lx.LSMR(atol=tol, rtol=tol), 110 | ), 111 | ) 112 | def 
test_nonsquare_pytree_operator1(solver): 113 | x = [[1, 5.0, jnp.array(-1.0)], [jnp.array(-2), jnp.array(-2.0), 3.0]] 114 | y = [3.0, 4] 115 | struct = jax.eval_shape(lambda: y) 116 | operator = lx.PyTreeLinearOperator(x, struct) 117 | out = lx.linear_solve(operator, y, solver=solver).value 118 | matrix = jnp.array([[1.0, 5.0, -1.0], [-2.0, -2.0, 3.0]]) 119 | true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y)) # pyright: ignore 120 | true_out = [true_out[0], true_out[1], true_out[2]] 121 | assert tree_allclose(out, true_out) 122 | 123 | 124 | @pytest.mark.parametrize( 125 | "solver", 126 | ( 127 | lx.AutoLinearSolver(well_posed=None), 128 | lx.QR(), 129 | lx.SVD(), 130 | lx.LSMR(atol=tol, rtol=tol), 131 | ), 132 | ) 133 | def test_nonsquare_pytree_operator2(solver): 134 | x = [[1, jnp.array(-2)], [5.0, jnp.array(-2.0)], [jnp.array(-1.0), 3.0]] 135 | y = [3.0, 4, 5.0] 136 | struct = jax.eval_shape(lambda: y) 137 | operator = lx.PyTreeLinearOperator(x, struct) 138 | out = lx.linear_solve(operator, y, solver=solver).value 139 | matrix = jnp.array([[1.0, -2.0], [5.0, -2.0], [-1.0, 3.0]]) 140 | true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y)) # pyright: ignore 141 | true_out = [true_out[0], true_out[1]] 142 | assert tree_allclose(out, true_out) 143 | 144 | 145 | @pytest.mark.parametrize("full_rank", (True, False)) 146 | @pytest.mark.parametrize("jvp", (False, True)) 147 | @pytest.mark.parametrize("wide", (False, True)) 148 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 149 | def test_qr_nonsquare_mat_vec(full_rank, jvp, wide, dtype, getkey): 150 | if wide: 151 | out_size = 3 152 | in_size = 6 153 | else: 154 | out_size = 6 155 | in_size = 3 156 | matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) 157 | if full_rank: 158 | context = contextlib.nullcontext() 159 | else: 160 | context = pytest.raises(Exception) 161 | if wide: 162 | matrix = matrix.at[:, 2:].set(0) 163 | else: 164 | matrix = matrix.at[2:, :].set(0) 165 | 
vector = jr.normal(getkey(), (out_size,), dtype=dtype) 166 | lx_solve = lambda mat, vec: lx.linear_solve( 167 | lx.MatrixLinearOperator(mat), vec, lx.QR() 168 | ).value 169 | jnp_solve = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0] # pyright: ignore 170 | if jvp: 171 | lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve)) 172 | jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve)) 173 | t_matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) 174 | t_vector = jr.normal(getkey(), (out_size,), dtype=dtype) 175 | args = ((matrix, vector), (t_matrix, t_vector)) 176 | else: 177 | args = (matrix, vector) 178 | with context: 179 | x = lx_solve(*args) # pyright: ignore 180 | if full_rank: 181 | true_x = jnp_solve(*args) 182 | assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) 183 | 184 | 185 | @pytest.mark.parametrize("full_rank", (True, False)) 186 | @pytest.mark.parametrize("jvp", (False, True)) 187 | @pytest.mark.parametrize("wide", (False, True)) 188 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 189 | def test_qr_nonsquare_vec(full_rank, jvp, wide, dtype, getkey): 190 | if wide: 191 | out_size = 3 192 | in_size = 6 193 | else: 194 | out_size = 6 195 | in_size = 3 196 | matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) 197 | if full_rank: 198 | context = contextlib.nullcontext() 199 | else: 200 | context = pytest.raises(Exception) 201 | if wide: 202 | matrix = matrix.at[:, 2:].set(0) 203 | else: 204 | matrix = matrix.at[2:, :].set(0) 205 | vector = jr.normal(getkey(), (out_size,), dtype=dtype) 206 | lx_solve = lambda vec: lx.linear_solve( 207 | lx.MatrixLinearOperator(matrix), vec, lx.QR() 208 | ).value 209 | jnp_solve = lambda vec: jnp.linalg.lstsq(matrix, vec)[0] # pyright: ignore 210 | if jvp: 211 | lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve)) 212 | jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve)) 213 | t_vector = jr.normal(getkey(), (out_size,), 
dtype=dtype) 214 | args = ((vector,), (t_vector,)) 215 | else: 216 | args = (vector,) 217 | with context: 218 | x = lx_solve(*args) # pyright: ignore 219 | if full_rank: 220 | true_x = jnp_solve(*args) 221 | assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) 222 | 223 | 224 | _iterative_solvers = ( 225 | (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag), 226 | (lx.CG(rtol=tol, atol=tol, max_steps=512), lx.negative_semidefinite_tag), 227 | (lx.GMRES(rtol=tol, atol=tol), ()), 228 | (lx.BiCGStab(rtol=tol, atol=tol), ()), 229 | ) 230 | 231 | 232 | @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) 233 | @pytest.mark.parametrize("solver, tags", _iterative_solvers) 234 | @pytest.mark.parametrize("use_state", (False, True)) 235 | @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) 236 | def test_iterative_singular(getkey, solver, tags, use_state, make_operator, dtype): 237 | (matrix,) = construct_singular_matrix(getkey, solver, tags) 238 | operator = make_operator(getkey, matrix, tags) 239 | 240 | out_size, _ = matrix.shape 241 | vec = jr.normal(getkey(), (out_size,), dtype=dtype) 242 | 243 | if use_state: 244 | state = solver.init(operator, options={}) 245 | linear_solve = ft.partial(lx.linear_solve, state=state) 246 | else: 247 | linear_solve = lx.linear_solve 248 | 249 | with pytest.raises(Exception): 250 | linear_solve(operator, vec, solver) 251 | -------------------------------------------------------------------------------- /docs/examples/operators.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2fe0b1e4-35cb-4c39-b324-65253aab005a", 6 | "metadata": {}, 7 | "source": [ 8 | "# Manipulating linear operators\n", 9 | "\n", 10 | "Lineax offers a sophisticated system of linear operators, supporting many operations.\n", 11 | "\n", 12 | "## Arithmetic\n", 13 | "\n", 14 | "To begin with, they support arithmetic, 
like addition and multiplication:" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "552021d3-dadf-49f3-bd17-84a18513bfcc", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import jax\n", 27 | "import jax.numpy as jnp\n", 28 | "import jax.random as jr\n", 29 | "import lineax as lx\n", 30 | "import numpy as np\n", 31 | "\n", 32 | "\n", 33 | "np.set_printoptions(precision=3)\n", 34 | "\n", 35 | "matrix = jnp.zeros((5, 5))\n", 36 | "matrix = matrix.at[0, 4].set(3) # top right corner\n", 37 | "sparse_operator = lx.MatrixLinearOperator(matrix)\n", 38 | "\n", 39 | "key0, key1, key = jr.split(jr.PRNGKey(0), 3)\n", 40 | "diag = jr.normal(key0, (5,))\n", 41 | "lower_diag = jr.normal(key0, (4,))\n", 42 | "upper_diag = jr.normal(key0, (4,))\n", 43 | "tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n", 44 | "\n", 45 | "identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "a4bb9825-73cc-447e-bc4c-c3e1a121a0a3", 52 | "metadata": { 53 | "tags": [] 54 | }, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "[[-1.149 0.963 0. 0. 3. ]\n", 61 | " [ 0.963 -2.007 0.155 0. 0. ]\n", 62 | " [ 0. 0.155 0.988 -0.261 0. ]\n", 63 | " [ 0. 0. -0.261 0.931 0.899]\n", 64 | " [ 0. 0. 0. 0.899 -0.288]]\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "print((sparse_operator + tridiag_operator).as_matrix())" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "id": "759c78a1-eee7-40e9-be6c-ea8c97c29e95", 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "[[-101.149 0.963 0. 0. 0. ]\n", 85 | " [ 0.963 -102.007 0.155 0. 0. ]\n", 86 | " [ 0. 0.155 -99.012 -0.261 0. ]\n", 87 | " [ 0. 0. 
-0.261 -99.069 0.899]\n", 88 | " [ 0. 0. 0. 0.899 -100.288]]\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "print((tridiag_operator - 100 * identity_operator).as_matrix())" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "84412bfa-00ec-41d4-87d7-def781145a90", 99 | "metadata": {}, 100 | "source": [ 101 | "Or they can be composed together. (I.e. matrix multiplication.)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "id": "8081d97f-5579-464f-8780-ffaa1d9c5f95", 108 | "metadata": { 109 | "tags": [] 110 | }, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "[[ 0. 0. 0. 0. -3.447]\n", 117 | " [ 0. 0. 0. 0. 2.888]\n", 118 | " [ 0. 0. 0. 0. 0. ]\n", 119 | " [ 0. 0. 0. 0. 0. ]\n", 120 | " [ 0. 0. 0. 0. 0. ]]\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "print((tridiag_operator @ sparse_operator).as_matrix())" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "d2c2b580-616f-4abd-a732-7f4a9b13335f", 131 | "metadata": {}, 132 | "source": [ 133 | "Or they can be transposed:" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 5, 139 | "id": "ae0393eb-3f43-490b-9842-bb374633633a", 140 | "metadata": { 141 | "tags": [] 142 | }, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "[[0. 0. 0. 0. 0.]\n", 149 | " [0. 0. 0. 0. 0.]\n", 150 | " [0. 0. 0. 0. 0.]\n", 151 | " [0. 0. 0. 0. 0.]\n", 152 | " [3. 0. 0. 0. 
0.]]\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "print(sparse_operator.transpose().as_matrix()) # or sparse_operator.T will work" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "ddbbbb0f-7983-4e35-b92d-2512c9612d19", 163 | "metadata": {}, 164 | "source": [ 165 | "## Different operator types\n", 166 | "\n", 167 | "Lineax has many different operator types:\n", 168 | "\n", 169 | "- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n", 170 | "- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n", 171 | "- Given a function $f \colon \mathbb{R}^n \to \mathbb{R}^m$ and a point $x \in \mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\frac{\mathrm{d}f}{\mathrm{d}x}(x) \in \mathbb{R}^{m \times n}$.\n", 172 | "- Given a linear function $g \colon \mathbb{R}^n \to \mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. 
the unique matrix $A$ for which $g(x) = Ax$.\n", 173 | "- etc!\n", 174 | "\n", 175 | "See the [operators](../api/operators.md) page for details on all supported operators.\n", 176 | "\n", 177 | "As above these can be freely combined:" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "id": "75ad4480-8ce0-4a88-9c76-bc054b1a0eaf", 184 | "metadata": { 185 | "tags": [] 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n", 190 | "\n", 191 | "\n", 192 | "def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n", 193 | " y0, y1, y2 = y\n", 194 | " f0 = 5 * y0 + y1**2\n", 195 | " f1 = y1 - y2 + 5\n", 196 | " f2 = y0 / (1 + 5 * y2**2)\n", 197 | " return jnp.stack([f0, f1, f2])\n", 198 | "\n", 199 | "\n", 200 | "def g(y: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n", 201 | " # Must be linear!\n", 202 | " y0, y1, y2 = y\n", 203 | " f0 = y0 - y2\n", 204 | " f1 = 0.0\n", 205 | " f2 = 5 * y1\n", 206 | " return jnp.stack([f0, f1, f2])\n", 207 | "\n", 208 | "\n", 209 | "y = jnp.array([1.0, 2.0, 3.0])\n", 210 | "in_structure = jax.eval_shape(lambda: y)\n", 211 | "jac_operator = lx.JacobianLinearOperator(f, y, args=None)\n", 212 | "fn_operator = lx.FunctionLinearOperator(g, in_structure)\n", 213 | "identity_operator = lx.IdentityLinearOperator(in_structure)\n", 214 | "\n", 215 | "operator = jac_operator @ fn_operator + 0.9 * identity_operator" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "5e528057-29ff-468d-aa3d-7155dd57082d", 221 | "metadata": {}, 222 | "source": [ 223 | "This composition does not instantiate a matrix for them by default. (This is sometimes important for efficiency when working with many operators.) 
Instead, the composition is stored as another linear operator:" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 7, 229 | "id": "5d15150d-955f-4006-bd36-58e2e6663307", 230 | "metadata": { 231 | "tags": [] 232 | }, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "AddLinearOperator(\n", 239 | " operator1=ComposedLinearOperator(\n", 240 | " operator1=JacobianLinearOperator(...),\n", 241 | " operator2=FunctionLinearOperator(...)\n", 242 | " ),\n", 243 | " operator2=MulLinearOperator(\n", 244 | " operator=IdentityLinearOperator(...),\n", 245 | " scalar=f32[]\n", 246 | " )\n", 247 | ")\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", 253 | "\n", 254 | "\n", 255 | "truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)\n", 256 | "eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "ff7b0591-1203-4f5e-886e-399822c68a15", 262 | "metadata": { 263 | "tags": [] 264 | }, 265 | "source": [ 266 | "If you want to materialise them into a matrix, then this can be done:" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 8, 272 | "id": "3713589f-1ac4-4e08-946b-ecc3fcf6a4c3", 273 | "metadata": { 274 | "tags": [] 275 | }, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "Array([[ 5.9 , 0. , -5. ],\n", 281 | " [ 0. , -4.1 , 0. 
],\n", 282 | " [ 0.022, -0.071, 0.878]], dtype=float32)" 283 | ] 284 | }, 285 | "execution_count": 8, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "operator.as_matrix()" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "id": "a483517e-89d7-4e9e-ad89-1915d886c14c", 297 | "metadata": {}, 298 | "source": [ 299 | "Which can in turn be treated as another linear operator, if desired:" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 9, 305 | "id": "fccddc81-d50e-4abe-a354-38402e462b1f", 306 | "metadata": { 307 | "tags": [] 308 | }, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "MatrixLinearOperator(\n", 315 | " matrix=Array([[ 5.9 , 0. , -5. ],\n", 316 | " [ 0. , -4.1 , 0. ],\n", 317 | " [ 0.022, -0.071, 0.878]], dtype=float32),\n", 318 | " tags=frozenset()\n", 319 | ")\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())\n", 325 | "eqx.tree_pprint(operator_fully_materialised, short_arrays=False)" 326 | ] 327 | } 328 | ], 329 | "metadata": { 330 | "kernelspec": { 331 | "display_name": "py39", 332 | "language": "python", 333 | "name": "py39" 334 | }, 335 | "language_info": { 336 | "codemirror_mode": { 337 | "name": "ipython", 338 | "version": 3 339 | }, 340 | "file_extension": ".py", 341 | "mimetype": "text/x-python", 342 | "name": "python", 343 | "nbconvert_exporter": "python", 344 | "pygments_lexer": "ipython3", 345 | "version": "3.9.16" 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 5 350 | } 351 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 
1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------