├── .github
│   └── workflows
│       └── makefile.yml
├── .gitignore
├── LICENSE
├── README.md
├── jaxeigs.py
└── test.py


/.github/workflows/makefile.yml:
--------------------------------------------------------------------------------
1 | name: Makefile CI
2 | 
3 | on:
4 |   push:
5 |     branches: [ "main" ]
6 |   pull_request:
7 |     branches: [ "main" ]
8 | 
9 | jobs:
10 |   build:
11 | 
12 |     runs-on: ubuntu-latest
13 | 
14 |     steps:
15 |     - uses: actions/checkout@v3
16 | 
17 |     - name: configure
18 |       run: ./configure
19 | 
20 |     - name: Install dependencies
21 |       run: make
22 | 
23 |     - name: Run check
24 |       run: make check
25 | 
26 |     - name: Run distcheck
27 |       run: make distcheck
28 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | #  Usually these files are written by a python script from a template
31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | #   For a library or package, you might want to ignore these files since the code is
87 | #   intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | #   install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
100 | #   commonly ignored for libraries.
101 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | #   in version control.
109 | #   https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # jaxeigs
2 | 
3 | As is well known, a GPU-compatible `eigs` is not yet available in `JAX`. This repo provides a simple implementation of `eigs` in `JAX`, assembled by salvaging useful code from Google's TensorNetwork library.
4 | 
5 | ## Installation
6 | 
7 | Copy `jaxeigs.py` to your project.
8 | 
9 | ## Usage
10 | 
11 | See `test.py` for usage.
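Below is a minimal sketch of what a call looks like (illustrative only; `matvec` and the random matrix `H` are stand-ins for your own linear operator):

```python
import numpy as np
import jax.numpy as jnp
from jaxeigs import eigs

# The operator is supplied as a matrix-vector product with
# call signature A(vector, *args).
def matvec(x, H):
    return jnp.dot(H, x)

H = jnp.array(np.random.rand(64, 64))
x0 = jnp.array(np.random.rand(64))

# Largest-magnitude eigenpairs via the implicitly restarted Arnoldi method.
eta, U = eigs(matvec, [H], initial_state=x0,
              num_krylov_vecs=20, numeig=4, which='LM')
```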
--------------------------------------------------------------------------------
/jaxeigs.py:
--------------------------------------------------------------------------------
1 | __all__ = ['eigs', 'eigsh']
2 | 
3 | import jax
4 | import jax.numpy as jnp
5 | import warnings
6 | from typing import Optional, Type, Any, List, Tuple, Callable, Sequence, Text
7 | import functools
8 | import types
9 | import numpy as np
10 | Tensor = Any
11 | 
12 | 
13 | 
14 | _CACHED_MATVECS = {}
15 | _CACHED_FUNCTIONS = {}
16 | 
17 | def randn(shape: Tuple[int, ...],
18 |           dtype: Optional[np.dtype] = None,
19 |           seed: Optional[int] = None) -> Tensor:
20 |   if seed is None:  # note: `if not seed:` would wrongly re-seed on seed=0
21 |     seed = np.random.randint(0, 2**63)
22 |   key = jax.random.PRNGKey(seed)
23 | 
24 |   dtype = dtype if dtype is not None else np.dtype(np.float64)
25 | 
26 |   def cmplx_randn(complex_dtype, real_dtype):
27 |     real_dtype = np.dtype(real_dtype)
28 |     complex_dtype = np.dtype(complex_dtype)
29 | 
30 |     key_2 = jax.random.PRNGKey(seed + 1)
31 | 
32 |     real_part = jax.random.normal(key, shape, dtype=real_dtype)
33 |     complex_part = jax.random.normal(key_2, shape, dtype=real_dtype)
34 |     unit = (
35 |         np.complex64(1j)
36 |         if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
37 |     return real_part + unit * complex_part
38 | 
39 |   if np.dtype(dtype) is np.dtype(jnp.complex128):
40 |     return cmplx_randn(dtype, jnp.float64)
41 |   if np.dtype(dtype) is np.dtype(jnp.complex64):
42 |     return cmplx_randn(dtype, jnp.float32)
43 | 
44 |   return jax.random.normal(key, shape).astype(dtype)
45 | 
46 | def random_uniform(shape: Tuple[int, ...],
47 |                    boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
48 |                    dtype: Optional[np.dtype] = None,
49 |                    seed: Optional[int] = None) -> Tensor:
50 |   if seed is None:
51 |     seed = np.random.randint(0, 2**63)
52 |   key = jax.random.PRNGKey(seed)
53 | 
54 |   dtype = dtype if dtype is not None else np.dtype(np.float64)
55 | 
56 |   def cmplx_random_uniform(complex_dtype, real_dtype):
57 |     real_dtype = np.dtype(real_dtype)
58 |     complex_dtype = np.dtype(complex_dtype)
59 | 
60 |     key_2 = jax.random.PRNGKey(seed + 1)
61 | 
62 |     real_part = jax.random.uniform(
63 |         key,
64 |         shape,
65 |         dtype=real_dtype,
66 |         minval=boundaries[0],
67 |         maxval=boundaries[1])
68 |     complex_part = jax.random.uniform(
69 |         key_2,
70 |         shape,
71 |         dtype=real_dtype,
72 |         minval=boundaries[0],
73 |         maxval=boundaries[1])
74 |     unit = (
75 |         np.complex64(1j)
76 |         if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
77 |     return real_part + unit * complex_part
78 | 
79 |   if np.dtype(dtype) is np.dtype(jnp.complex128):
80 |     return cmplx_random_uniform(dtype, jnp.float64)
81 |   if np.dtype(dtype) is np.dtype(jnp.complex64):
82 |     return cmplx_random_uniform(dtype, jnp.float32)
83 | 
84 |   return jax.random.uniform(
85 |       key, shape, minval=boundaries[0], maxval=boundaries[1]).astype(dtype)
86 | 
87 | """
88 | Implicitly restarted Arnoldi method for finding the dominant
89 | eigenvector-eigenvalue pairs of a linear operator `A`.
90 | `A` is a function implementing the matrix-vector
91 | product.
92 | 
93 | WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
94 | at the first invocation of `eigs`, and on any subsequent calls
95 | if the python `id` of `A` changes, even if the formal definition of `A`
96 | stays the same.
97 | Example: the following will jit once at the beginning, and then never again:
98 | 
99 | ```python
100 | import jax
101 | import numpy as np
102 | def A(x, H):
103 |   return jax.numpy.dot(H, x)
104 | for n in range(100):
105 |   H = jax.numpy.array(np.random.rand(10, 10))
106 |   x = jax.numpy.array(np.random.rand(10))
107 |   res = eigs(A, [H], x)  # jitting is triggered only at `n=0`
108 | ```
109 | 
110 | The following code triggers jitting at every iteration, which
111 | results in considerably reduced performance:
112 | 
113 | ```python
114 | import jax
115 | import numpy as np
116 | for n in range(100):
117 |   def A(x, H):
118 |     return jax.numpy.dot(H, x)
119 |   H = jax.numpy.array(np.random.rand(10, 10))
120 |   x = jax.numpy.array(np.random.rand(10))
121 |   res = eigs(A, [H], x)  # jitting is triggered at every step `n`
122 | ```
123 | 
124 | Args:
125 |   A: A (sparse) implementation of a linear operator.
126 |      Call signature of `A` is `res = A(vector, *args)`, where `vector`
127 |      can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
128 |   args: A list of arguments to `A`. `A` will be called as
129 |     `res = A(initial_state, *args)`.
130 |   initial_state: An initial vector for the algorithm. If `None`,
131 |     a random initial `Tensor` is created using the `randn` function.
132 |   shape: The shape of the input dimension of `A`.
133 |   dtype: The dtype of the input to `A`. If no `initial_state` is provided,
134 |     a random initial state with shape `shape` and dtype `dtype` is created.
135 |   num_krylov_vecs: The number of iterations (number of krylov vectors).
136 |   numeig: The number of eigenvector-eigenvalue pairs to be computed.
137 |   tol: The desired precision of the eigenvalues. For the jax backend
138 |     this has currently no effect, and precision of eigenvalues is not
139 |     guaranteed. This feature may be added at a later point. To increase
140 |     precision the caller can either increase `maxiter` or `num_krylov_vecs`.
141 |   which: Flag for targeting different types of eigenvalues. Currently
142 |     supported are `which = 'LR'` (largest real part) and `which = 'LM'`
143 |     (largest magnitude).
144 |   maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
145 |     equivalent to a simple Arnoldi method.
146 | Returns:
147 |   (eigvals, eigvecs)
148 |    eigvals: A list of `numeig` eigenvalues
149 |    eigvecs: A list of `numeig` eigenvectors
150 | """
151 | def eigs(A: Callable,
152 |          args: Optional[List] = None,
153 |          initial_state: Optional[Tensor] = None,
154 |          shape: Optional[Tuple[int, ...]] = None,
155 |          dtype: Optional[Type[np.number]] = None,
156 |          num_krylov_vecs: int = 50,
157 |          numeig: int = 6,
158 |          tol: float = 1E-8,
159 |          which: Text = 'LM',
160 |          maxiter: int = 20) -> Tuple[Tensor, List]:
161 |   if args is None:
162 |     args = []
163 |   if which not in ('LR', 'LM'):
164 |     raise ValueError(f'which = {which} is currently not supported.')
165 | 
166 |   if numeig > num_krylov_vecs:
167 |     raise ValueError('`num_krylov_vecs` >= `numeig` required!')
168 | 
169 |   if initial_state is None:
170 |     if (shape is None) or (dtype is None):
171 |       raise ValueError("if no `initial_state` is passed, then `shape` and "
172 |                        "`dtype` have to be provided")
173 |     initial_state = randn(shape, dtype)
174 | 
175 |   if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
176 |     raise TypeError("Expected a `jax.array`. Got {}".format(
177 |         type(initial_state)))
178 | 
179 |   if A not in _CACHED_MATVECS:
180 |     _CACHED_MATVECS[A] = jax.tree_util.Partial(jax.jit(A))
181 | 
182 |   if "imp_arnoldi" not in _CACHED_FUNCTIONS:
183 |     imp_arnoldi = _implicitly_restarted_arnoldi(jax)
184 |     _CACHED_FUNCTIONS["imp_arnoldi"] = imp_arnoldi
185 | 
186 |   eta, U, numits = _CACHED_FUNCTIONS["imp_arnoldi"](_CACHED_MATVECS[A], args,
187 |                                                     initial_state,
188 |                                                     num_krylov_vecs, numeig,
189 |                                                     which, tol, maxiter,
190 |                                                     jax.lax.Precision.DEFAULT)
191 |   if numeig > numits:
192 |     warnings.warn(
193 |         f"Arnoldi terminated early after numits = {numits}"
194 |         f" < numeig = {numeig} steps. For this value of `numeig` "
195 |         f"the routine will return spurious eigenvalues of value 0.0. "
196 |         f"Use a smaller value of `numeig`, or a smaller value for `tol`.")
197 |   return eta, U
198 | 
199 | def eigsh(A: Callable,
200 |           args: Optional[List] = None,
201 |           initial_state: Optional[Tensor] = None,
202 |           shape: Optional[Tuple[int, ...]] = None,
203 |           dtype: Optional[Type[np.number]] = None,
204 |           num_krylov_vecs: int = 50,
205 |           numeig: int = 6,
206 |           tol: float = 1E-8,
207 |           which: Text = 'SA',
208 |           maxiter: int = 20) -> Tuple[Tensor, List]:
209 |   """
210 |   Implicitly restarted Lanczos method for finding the lowest
211 |   eigenvector-eigenvalue pairs of a symmetric (hermitian) linear operator `A`.
212 |   `A` is a function implementing the matrix-vector
213 |   product.
214 | 
215 |   WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
216 |   at the first invocation of `eigsh`, and on any subsequent calls
217 |   if the python `id` of `A` changes, even if the formal definition of `A`
218 |   stays the same.
219 |   Example: the following will jit once at the beginning, and then never again:
220 | 
221 |   ```python
222 |   import jax
223 |   import numpy as np
224 |   def A(x, H):
225 |     return jax.numpy.dot(H, x)
226 |   for n in range(100):
227 |     M = np.random.rand(10, 10); H = jax.numpy.array(M + M.T)  # symmetric
228 |     x = jax.numpy.array(np.random.rand(10))
229 |     res = eigsh(A, [H], x)  # jitting is triggered only at `n=0`
230 |   ```
231 | 
232 |   The following code triggers jitting at every iteration, which
233 |   results in considerably reduced performance:
234 | 
235 |   ```python
236 |   import jax
237 |   import numpy as np
238 |   for n in range(100):
239 |     def A(x, H):
240 |       return jax.numpy.dot(H, x)
241 |     M = np.random.rand(10, 10); H = jax.numpy.array(M + M.T)  # symmetric
242 |     x = jax.numpy.array(np.random.rand(10))
243 |     res = eigsh(A, [H], x)  # jitting is triggered at every step `n`
244 |   ```
245 | 
246 |   Args:
247 |     A: A (sparse) implementation of a linear operator.
248 |        Call signature of `A` is `res = A(vector, *args)`, where `vector`
249 |        can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
250 |     args: A list of arguments to `A`. `A` will be called as
251 |       `res = A(initial_state, *args)`.
252 |     initial_state: An initial vector for the algorithm. If `None`,
253 |       a random initial `Tensor` is created using the `randn` function.
254 |     shape: The shape of the input dimension of `A`.
255 |     dtype: The dtype of the input to `A`. If no `initial_state` is provided,
256 |       a random initial state with shape `shape` and dtype `dtype` is created.
257 |     num_krylov_vecs: The number of iterations (number of krylov vectors).
258 |     numeig: The number of eigenvector-eigenvalue pairs to be computed.
259 |     tol: The desired precision of the eigenvalues. For the jax backend
260 |       this has currently no effect, and precision of eigenvalues is not
261 |       guaranteed. This feature may be added at a later point. To increase
262 |       precision the caller can either increase `maxiter` or `num_krylov_vecs`.
263 |     which: Flag for targeting different types of eigenvalues. Currently
264 |       supported are `which = 'SA'` (smallest algebraic), `which = 'LA'`
265 |       (largest algebraic) and `which = 'LM'` (largest magnitude).
266 |     maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
267 |       equivalent to a simple Lanczos method.
268 |   Returns:
269 |     (eigvals, eigvecs)
270 |      eigvals: A list of `numeig` eigenvalues
271 |      eigvecs: A list of `numeig` eigenvectors
272 |   """
273 | 
274 |   if args is None:
275 |     args = []
276 |   if which not in ('SA', 'LA', 'LM'):
277 |     raise ValueError(f'which = {which} is currently not supported.')
278 | 
279 |   if numeig > num_krylov_vecs:
280 |     raise ValueError('`num_krylov_vecs` >= `numeig` required!')
281 | 
282 |   if initial_state is None:
283 |     if (shape is None) or (dtype is None):
284 |       raise ValueError("if no `initial_state` is passed, then `shape` and "
285 |                        "`dtype` have to be provided")
286 |     initial_state = randn(shape, dtype)
287 | 
288 |   if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
289 |     raise TypeError("Expected a `jax.array`. Got {}".format(
290 |         type(initial_state)))
291 | 
292 |   if A not in _CACHED_MATVECS:
293 |     _CACHED_MATVECS[A] = jax.tree_util.Partial(jax.jit(A))
294 | 
295 |   if "imp_lanczos" not in _CACHED_FUNCTIONS:
296 |     imp_lanczos = _implicitly_restarted_lanczos(jax)
297 |     _CACHED_FUNCTIONS["imp_lanczos"] = imp_lanczos
298 | 
299 |   eta, U, numits = _CACHED_FUNCTIONS["imp_lanczos"](_CACHED_MATVECS[A], args,
300 |                                                     initial_state,
301 |                                                     num_krylov_vecs, numeig,
302 |                                                     which, tol, maxiter,
303 |                                                     jax.lax.Precision.DEFAULT)
304 |   if numeig > numits:
305 |     warnings.warn(
306 |         f"Lanczos terminated early after numits = {numits}"
307 |         f" < numeig = {numeig} steps. For this value of `numeig` "
308 |         f"the routine will return spurious eigenvalues of value 0.0. "
309 |         f"Use a smaller value of `numeig`, or a smaller value for `tol`.")
310 |   return eta, U
311 | 
312 | def cpu_eig_host(H):
313 |   return np.linalg.eig(H)
314 | 
315 | 
316 | 
317 | def cpu_eig(H):
318 |   result_shape = (jax.ShapeDtypeStruct(H.shape[0:1], H.dtype),
319 |                   jax.ShapeDtypeStruct(H.shape, H.dtype))
320 |   return jax.pure_callback(cpu_eig_host, result_shape, H)
321 | 
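# Editorial note (a sketch, not part of the original file): `cpu_eig` is
# the host-callback pattern this repo is built around. GPU `eig` is not
# implemented in JAX, so the small projected matrix is shipped to the host
# and diagonalized by `np.linalg.eig`; `jax.pure_callback` makes this
# usable inside jit-compiled code. Illustrative use, with a complex input
# so the declared result dtype matches what NumPy returns:
#
#   H = jnp.eye(4, dtype=jnp.complex64)
#   evals, evecs = cpu_eig(H)  # evaluated by np.linalg.eig on the CPU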
322 | def _iterative_classical_gram_schmidt(jax: types.ModuleType) -> Callable:
323 | 
324 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
325 |   def iterative_classical_gram_schmidt(
326 |       vector: jax.Array,
327 |       krylov_vectors: jax.Array,
328 |       precision: JaxPrecisionType,
329 |       iterations: int = 2,
330 |   ) -> Tuple[jax.Array, jax.Array]:
331 |     """
332 |     Orthogonalize `vector` to all rows of `krylov_vectors`.
333 | 
334 |     Args:
335 |       vector: Initial vector.
336 |       krylov_vectors: Matrix of krylov vectors, each row is treated as a
337 |         vector.
338 |       iterations: Number of iterations.
339 | 
340 |     Returns:
341 |       Tuple[jax.Array, jax.Array]: The orthogonalized vector and its accumulated overlaps with the rows of `krylov_vectors`.
342 |     """
343 |     i1 = list(range(1, len(krylov_vectors.shape)))
344 |     i2 = list(range(len(vector.shape)))
345 | 
346 |     vec = vector
347 |     overlaps = 0
348 |     for _ in range(iterations):
349 |       ov = jax.numpy.tensordot(
350 |           krylov_vectors.conj(), vec, (i1, i2), precision=precision)
351 |       vec = vec - jax.numpy.tensordot(
352 |           ov, krylov_vectors, ([0], [0]), precision=precision)
353 |       overlaps = overlaps + ov
354 |     return vec, overlaps
355 | 
356 |   return iterative_classical_gram_schmidt
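# Editorial sketch (illustrative, not part of the original file): the
# orthogonalizer returned above can be exercised on its own. With two
# orthonormal rows in `Q`, the projection of `v` onto them is removed:
#
#   icgs = _iterative_classical_gram_schmidt(jax)
#   Q = jnp.eye(4)[:2]
#   v = jnp.arange(4.0)
#   v_orth, overlaps = icgs(v, Q, jax.lax.Precision.DEFAULT)
#   # v_orth[:2] is now (numerically) zero; overlaps holds <Q_i, v>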
357 | 
358 | 
359 | def _generate_jitted_eigsh_lanczos(jax: types.ModuleType) -> Callable:
360 |   """
361 |   Helper function to generate the jitted lanczos function used
362 |   in JaxBackend.eigsh_lanczos. The function `jax_lanczos`
363 |   returned by this higher-order function has the following
364 |   call signature:
365 |   ```
366 |   eigenvalues, eigenvectors = jax_lanczos(matvec: Callable,
367 |                                           arguments: List[Tensor],
368 |                                           init: Tensor,
369 |                                           ncv: int,
370 |                                           neig: int,
371 |                                           landelta: float,
372 |                                           reortho: bool)
373 |   ```
374 |   `matvec`: A callable implementing the matrix-vector product of a
375 |   linear operator. `arguments`: Arguments to `matvec` additional to
376 |   an input vector. `matvec` will be called as `matvec(init, *args)`.
377 |   `init`: An initial input vector to `matvec`.
378 |   `ncv`: Number of krylov iterations (i.e. dimension of the Krylov space).
379 |   `neig`: Number of eigenvalue-eigenvector pairs to be computed.
380 |   `landelta`: Convergence parameter: if the norm of the current Lanczos vector
381 |     falls below `landelta`, iteration is stopped.
382 |   `reortho`: If `True`, reorthogonalize all krylov vectors at each step.
383 |   This should be used if `neig>1`.
384 | 
385 |   Args:
386 |     jax: The `jax` module.
387 |   Returns:
388 |     Callable: A jitted function that does a lanczos iteration.
389 | 
390 |   """
391 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
392 | 
393 |   @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7))
394 |   def jax_lanczos(matvec: Callable, arguments: List, init: jax.Array,
395 |                   ncv: int, neig: int, landelta: float, reortho: bool,
396 |                   precision: JaxPrecisionType) -> Tuple[jax.Array, List]:
397 |     """
398 |     Lanczos iteration for symmetric eigenvalue problems. If reortho = False,
399 |     the Krylov basis is constructed without explicit re-orthogonalization.
400 |     In infinite precision, all Krylov vectors would be orthogonal. Due to
401 |     finite precision arithmetic, orthogonality is usually quickly lost.
402 |     For reortho=True, the Krylov basis is explicitly reorthogonalized.
403 | 
404 |     Args:
405 |       matvec: A callable implementing the matrix-vector product of a
406 |         linear operator.
407 |       arguments: Arguments to `matvec` additional to an input vector.
408 |         `matvec` will be called as `matvec(init, *args)`.
409 |       init: An initial input vector to `matvec`.
410 |       ncv: Number of krylov iterations (i.e. dimension of the Krylov space).
411 |       neig: Number of eigenvalue-eigenvector pairs to be computed.
412 |       landelta: Convergence parameter: if the norm of the current Lanczos vector
413 |         falls below `landelta`, iteration is stopped.
414 |       reortho: If `True`, reorthogonalize all krylov vectors at each step.
415 |         This should be used if `neig>1`.
416 |       precision: jax.lax.Precision type used in jax.numpy.vdot
417 | 
418 |     Returns:
419 |       jax.Array: Eigenvalues
420 |       List: Eigenvectors
421 |       int: Number of iterations
422 |     """
423 |     shape = init.shape
424 |     dtype = init.dtype
425 |     iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
426 |     mask_slice = (slice(ncv + 2), ) + (None,) * len(shape)
427 |     def scalar_product(a, b):
428 |       i1 = list(range(len(a.shape)))
429 |       i2 = list(range(len(b.shape)))
430 |       return jax.numpy.tensordot(a.conj(), b, (i1, i2), precision=precision)
431 | 
432 |     def norm(a):
433 |       return jax.numpy.sqrt(scalar_product(a, a))
434 | 
435 |     def body_lanczos(vals):
436 |       krylov_vectors, alphas, betas, i = vals
437 |       previous_vector = krylov_vectors[i]
438 |       def body_while(vals):
439 |         pv, kv, _ = vals
440 |         pv = iterative_classical_gram_schmidt(
441 |             pv, (i > jax.numpy.arange(ncv + 2))[mask_slice] * kv, precision)[0]
442 |         return [pv, kv, False]
443 | 
444 |       def cond_while(vals):
445 |         return vals[2]
446 | 
447 |       previous_vector, krylov_vectors, _ = jax.lax.while_loop(
448 |           cond_while, body_while,
449 |           [previous_vector, krylov_vectors, reortho])
450 | 
451 |       beta = norm(previous_vector)
452 |       normalized_vector = previous_vector / beta
453 |       Av = matvec(normalized_vector, *arguments)
454 |       alpha = scalar_product(normalized_vector, Av)
455 |       alphas = alphas.at[i - 1].set(alpha)
456 |       betas = betas.at[i].set(beta)
457 | 
458 |       def while_next(vals):
459 |         Av, _ = vals
460 |         res = Av - normalized_vector * alpha - krylov_vectors[i - 1] * beta
461 |         return [res, False]
462 | 
463 |       def cond_next(vals):
464 |         return vals[1]
465 | 
466 |       next_vector, _ = jax.lax.while_loop(
467 |           cond_next, while_next,
468 |           [Av, jax.numpy.logical_not(reortho)])
469 |       next_vector = jax.numpy.reshape(next_vector, shape)
470 | 
471 |       krylov_vectors = krylov_vectors.at[i].set(normalized_vector)
472 |       krylov_vectors = krylov_vectors.at[i + 1].set(next_vector)
473 | 
474 |       return [krylov_vectors, alphas, betas, i + 1]
475 | 
476 |     def cond_fun(vals):
477 |       betas, i = vals[-2], vals[-1]
478 |       norm = betas[i - 1]
479 |       return jax.lax.cond(i <= ncv, lambda x: x[0] > x[1], lambda x: False,
480 |                           [norm, landelta])
481 | 
482 |     # note: ncv + 2 because the first vector is all zeros, and the
483 |     # last is the unnormalized residual.
484 |     krylov_vecs = jax.numpy.zeros((ncv + 2,) + shape, dtype=dtype)
485 |     # NOTE (mganahl): initial vector is normalized inside the loop
486 |     krylov_vecs = krylov_vecs.at[1].set(init)
487 | 
488 |     # betas are the upper and lower diagonal elements
489 |     # of the projected linear operator
490 |     # the first two beta-values can be discarded
491 |     # set betas[0] to 1.0 for initialization of loop
492 |     # betas[2] is set to the norm of the initial vector.
493 |     betas = jax.numpy.zeros(ncv + 1, dtype=dtype)
494 |     betas = betas.at[0].set(1.0)
495 |     # diagonal elements of the projected linear operator
496 |     alphas = jax.numpy.zeros(ncv, dtype=dtype)
497 |     initvals = [krylov_vecs, alphas, betas, 1]
498 |     krylov_vecs, alphas, betas, numits = jax.lax.while_loop(
499 |         cond_fun, body_lanczos, initvals)
500 |     # FIXME (mganahl): if the while_loop stops early at iteration i, alphas
501 |     # and betas are 0.0 at positions n >= i - 1. eigh will then wrongly give
502 |     # degenerate eigenvalues 0.0. JAX currently does not support
503 |     # dynamic slicing with variable slice sizes, so these beta values
504 |     # can't be truncated. Thus, if numeig >= i - 1, jitted_lanczos returns
505 |     # a set of spurious eigenvectors and eigenvalues.
506 |     # If algebraically small EVs are desired, one can initialize `alphas` with
507 |     # large positive values, thus pushing the spurious eigenvalues further
508 |     # away from the desired ones (similar for algebraically large EVs)
509 | 
510 |     # FIXME: replace with eigh_banded once JAX supports it
511 |     A_tridiag = jax.numpy.diag(alphas) + jax.numpy.diag(
512 |         betas[2:], 1) + jax.numpy.diag(jax.numpy.conj(betas[2:]), -1)
513 |     eigvals, U = jax.numpy.linalg.eigh(A_tridiag)
514 |     eigvals = eigvals.astype(dtype)
515 | 
516 |     # expand eigenvectors in krylov basis
517 |     def body_vector(i, vals):
518 |       krv, unitary, vectors = vals
519 |       dim = unitary.shape[1]
520 |       n, m = jax.numpy.divmod(i, dim)
521 |       vectors = vectors.at[n, :].set(vectors[n, :] + krv[m + 1] * unitary[m, n])
522 |       return [krv, unitary, vectors]
523 | 
524 |     _vectors = jax.numpy.zeros((neig,) + shape, dtype=dtype)
525 |     _, _, vectors = jax.lax.fori_loop(0, neig * (krylov_vecs.shape[0] - 1),
526 |                                       body_vector,
527 |                                       [krylov_vecs, U, _vectors])
528 | 
529 |     return jax.numpy.array(eigvals[0:neig]), [
530 |         vectors[n] / norm(vectors[n]) for n in range(neig)
531 |     ], numits
532 | 
533 |   return jax_lanczos
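# Editorial sketch (illustrative, not part of the original file): the
# generated function can be driven directly. Assuming a symmetric matrix
# `H` of shape (d, d) and a start vector `v0` of shape (d,) with matching
# dtypes:
#
#   lanczos = _generate_jitted_eigsh_lanczos(jax)
#   matvec = jax.tree_util.Partial(jax.jit(lambda x, H: H @ x))
#   evals, evecs, numits = lanczos(matvec, [H], v0, 20, 2, 1e-8, True,
#                                  jax.lax.Precision.DEFAULT)
#
# `evals` holds the 2 algebraically smallest Ritz values, `evecs` the
# corresponding vectors expanded back into the original basis.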
534 | 
535 | 
536 | def _generate_lanczos_factorization(jax: types.ModuleType) -> Callable:
537 |   """
538 |   Helper function to generate a jitted function that
539 |   computes a lanczos factorization of a linear operator.
540 |   Returns:
541 |     Callable: A jitted function that does a lanczos factorization.
542 | 
543 |   """
544 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
545 | 
546 |   @functools.partial(jax.jit, static_argnums=(6, 7, 8, 9))
547 |   def _lanczos_fact(
548 |       matvec: Callable, args: List, v0: jax.Array,
549 |       Vm: jax.Array, alphas: jax.Array, betas: jax.Array,
550 |       start: int, num_krylov_vecs: int, tol: float, precision: JaxPrecisionType
551 |   ):
552 |     """
553 |     Compute an m-step lanczos factorization of `matvec`, with
554 |     m <= `num_krylov_vecs`. The factorization will
555 |     do at most `num_krylov_vecs` steps, and terminate early
556 |     if an invariant subspace is encountered. The returned arrays
557 |     `alphas`, `betas` and `Vm` will satisfy the Lanczos recurrence relation
558 |     ```
559 |     matrix @ Vm - Vm @ Hm - fm * em = 0
560 |     ```
561 |     with `matrix` the matrix representation of `matvec`,
562 |     `Hm = jnp.diag(alphas) + jnp.diag(betas, -1) + jnp.diag(betas.conj(), 1)`,
563 |     `fm = residual * norm`, and `em` a cartesian basis vector of shape
564 |     `(1, kv.shape[1])` with `em[0, -1] == 1` and 0 elsewhere.
565 | 
566 |     Note that the caller is responsible for dtype consistency between
567 |     the inputs, i.e. dtypes between all input arrays have to match.
568 | 
569 |     Args:
570 |       matvec: The matrix vector product.
571 |       args: List of arguments to `matvec`.
572 |       v0: Initial state to `matvec`.
573 |       Vm: An array for storing the krylov vectors. The individual
574 |         vectors are stored as rows.
575 |         The shape of `krylov_vecs` has to be
576 |         (num_krylov_vecs, np.ravel(v0).shape[0]).
577 |       alphas: An array for storing the diagonal elements of the reduced
578 |         operator.
579 |       betas: An array for storing the lower diagonal elements of the
80 |         reduced operator.
581 |       start: Integer denoting the start position where the first
582 |         produced krylov_vector should be inserted into `Vm`
583 |       num_krylov_vecs: Number of krylov iterations, should be identical to
584 |         `Vm.shape[0]`
585 |       tol: Convergence parameter. Iteration is terminated if the norm of a
586 |         krylov-vector falls below `tol`.
587 | 
588 |     Returns:
589 |       jax.Array: An array of shape
590 |         `(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors.
591 |       jax.Array: The diagonal elements of the tridiagonal reduced
592 |         operator ("alphas")
593 |       jax.Array: The lower-diagonal elements of the tridiagonal reduced
594 |         operator ("betas")
595 |       jax.Array: The unnormalized residual of the Lanczos process.
596 |       float: The norm of the residual.
597 |       int: The number of performed iterations.
598 |       bool: if `True`: iteration hit an invariant subspace.
599 |         if `False`: iteration terminated without encountering
600 |         an invariant subspace.
601 |     """
602 | 
603 |     shape = v0.shape
604 |     iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
605 |     Z = jax.numpy.linalg.norm(v0)
606 |     # only normalize if norm > tol, else return zero vector
607 |     v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None)
608 |     Vm = Vm.at[start, :].set(jax.numpy.ravel(v))
609 |     betas = jax.lax.cond(
610 |         start > 0,
611 |         lambda x: betas.at[start - 1].set(Z),
612 |         lambda x: betas, start)
613 |     # body of the lanczos iteration
614 |     def body(vals):
615 |       Vm, alphas, betas, previous_vector, _, i = vals
616 |       Av = matvec(previous_vector, *args)
617 |       Av, overlaps = iterative_classical_gram_schmidt(
618 |           Av.ravel(),
619 |           (i >= jax.numpy.arange(Vm.shape[0]))[:, None] * Vm, precision)
620 |       alphas = alphas.at[i].set(overlaps[i])
621 |       norm = jax.numpy.linalg.norm(Av)
622 |       Av = jax.numpy.reshape(Av, shape)
623 |       # only normalize if norm is larger than threshold,
624 |       # otherwise return zero vector
625 |       Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None)
626 |       Vm, betas = jax.lax.cond(
627 |           i < num_krylov_vecs - 1,
628 |           lambda x: (Vm.at[i + 1, :].set(Av.ravel()), betas.at[i].set(norm)),
629 |           lambda x: (Vm, betas),
630 |           None)
631 | 
632 |       return [Vm, alphas, betas, Av, norm, i + 1]
633 | 
634 |     def cond_fun(vals):
635 |       # Continue loop while iteration < num_krylov_vecs and norm > tol
636 |       norm, iteration = vals[4], vals[5]
637 |       counter_done = (iteration >= num_krylov_vecs)
638 |       norm_not_too_small = norm > tol
639 |       continue_iteration = jax.lax.cond(counter_done, lambda x: False,
640 |                                         lambda x: norm_not_too_small, None)
641 |       return continue_iteration
642 |     initial_values = [Vm, alphas, betas, v, Z, start]
643 |     final_values = jax.lax.while_loop(cond_fun, body, initial_values)
644 |     Vm, alphas, betas, residual, norm, it = final_values
645 |     return Vm, alphas, betas, residual, norm, it, norm < tol
646 | 
647 |   return _lanczos_fact
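# Editorial sketch (illustrative, not part of the original file):
# `_lanczos_fact` is the single-expansion kernel used by the implicitly
# restarted Lanczos driver at the bottom of this module. One full
# expansion, under the assumption that `betas` has length `ncv - 1` and
# all dtypes match, would look like:
#
#   lanczos_fact = _generate_lanczos_factorization(jax)
#   matvec = jax.tree_util.Partial(jax.jit(lambda x, H: H @ x))
#   Vm = jnp.zeros((ncv, v0.size), dtype=v0.dtype)
#   alphas = jnp.zeros(ncv, dtype=v0.dtype)
#   betas = jnp.zeros(ncv - 1, dtype=v0.dtype)
#   Vm, alphas, betas, res, nrm, it, invariant = lanczos_fact(
#       matvec, [H], v0, Vm, alphas, betas, 0, ncv, 1e-8,
#       jax.lax.Precision.DEFAULT)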
648 | 
649 | 
650 | def _generate_arnoldi_factorization(jax: types.ModuleType) -> Callable:
651 |   """
652 |   Helper function to create a jitted arnoldi factorization.
653 |   The function returns a function `_arnoldi_fact` which
654 |   performs an m-step arnoldi factorization.
655 | 
656 |   `_arnoldi_fact` computes an m-step arnoldi factorization
657 |   of an input callable `matvec`, with m = min(`it`, `num_krylov_vecs`).
658 |   `_arnoldi_fact` will do at most `num_krylov_vecs` steps.
659 |   `_arnoldi_fact` returns arrays `kv` and `H` which satisfy
660 |   the Arnoldi recurrence relation
661 |   ```
662 |   matrix @ Vm - Vm @ Hm - fm * em = 0
663 |   ```
664 |   with `matrix` the matrix representation of `matvec` and
665 |   `Vm = jax.numpy.transpose(kv[:it, :])`,
666 |   `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1], 1)`
667 |   and `em` a cartesian basis vector of shape `(1, kv.shape[1])`
668 |   with `em[0, -1] == 1` and 0 elsewhere.
669 | 
670 |   Note that the caller is responsible for dtype consistency between
671 |   the inputs, i.e. dtypes between all input arrays have to match.
672 | 
673 |   Args:
674 |     matvec: The matrix vector product. This function has to be wrapped into
675 |       `jax.tree_util.Partial`. `matvec` will be called as `matvec(x, *args)`
676 |       for an input vector `x`.
677 |     args: List of arguments to `matvec`.
678 |     v0: Initial state to `matvec`.
679 |     Vm: An array for storing the krylov vectors. The individual
680 |       vectors are stored as rows. The shape of `krylov_vecs` has to be
681 |       (num_krylov_vecs, np.ravel(v0).shape[0]).
682 |     H: Matrix of overlaps. The shape has to be
683 |       (num_krylov_vecs, num_krylov_vecs).
684 |     start: Integer denoting the start position where the first
685 |       produced krylov_vector should be inserted into `Vm`
686 |     num_krylov_vecs: Number of krylov iterations, should be identical to
687 |       `Vm.shape[0]`
688 |     tol: Convergence parameter. Iteration is terminated if the norm of a
689 |       krylov-vector falls below `tol`.
690 | 
691 |   Returns:
692 |     kv: An array of krylov vectors
693 |     H: A matrix of overlaps
694 |     it: The number of performed iterations.
695 |     converged: Whether convergence was achieved.
696 | 
697 |   """
698 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
699 |   iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
700 | 
701 |   @functools.partial(jax.jit, static_argnums=(5, 6, 7, 8))
702 |   def _arnoldi_fact(
703 |       matvec: Callable, args: List, v0: jax.Array,
704 |       Vm: jax.Array, H: jax.Array, start: int,
705 |       num_krylov_vecs: int, tol: float, precision: JaxPrecisionType
706 |   ) -> Tuple[jax.Array, jax.Array, jax.Array, float, int,
707 |              bool]:
708 |     """
709 |     Compute an m-step arnoldi factorization of `matvec`, with
710 |     m = min(`it`, `num_krylov_vecs`). The factorization will
711 |     do at most `num_krylov_vecs` steps. The returned arrays
712 |     `kv` and `H` will satisfy the Arnoldi recurrence relation
713 |     ```
714 |     matrix @ Vm - Vm @ Hm - fm * em = 0
715 |     ```
716 |     with `matrix` the matrix representation of `matvec` and
717 |     `Vm = jax.numpy.transpose(kv[:it, :])`,
718 |     `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1], 1)`
719 |     and `em` a cartesian basis vector of shape `(1, kv.shape[1])`
720 |     with `em[0, -1] == 1` and 0 elsewhere.
721 | 
722 |     Note that the caller is responsible for dtype consistency between
723 |     the inputs, i.e. dtypes between all input arrays have to match.
724 | 
725 |     Args:
726 |       matvec: The matrix vector product.
727 |       args: List of arguments to `matvec`.
728 |       v0: Initial state to `matvec`.
729 |       Vm: An array for storing the krylov vectors. The individual
730 |         vectors are stored as rows.
731 |         The shape of `krylov_vecs` has to be
732 |         (num_krylov_vecs, np.ravel(v0).shape[0]).
733 |       H: Matrix of overlaps. The shape has to be
734 |         (num_krylov_vecs, num_krylov_vecs).
735 |       start: Integer denoting the start position where the first
736 |         produced krylov_vector should be inserted into `Vm`
737 |       num_krylov_vecs: Number of krylov iterations, should be identical to
738 |         `Vm.shape[0]`
739 |       tol: Convergence parameter. Iteration is terminated if the norm of a
740 |         krylov-vector falls below `tol`.
741 |     Returns:
742 |       jax.Array: An array of shape
743 |         `(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors.
744 |       jax.Array: Upper Hessenberg matrix of shape
745 |         `(num_krylov_vecs, num_krylov_vecs)` of the Arnoldi process.
746 |       jax.Array: The unnormalized residual of the Arnoldi process.
747 |       float: The norm of the residual.
748 |       int: The number of performed iterations.
749 |       bool: if `True`: iteration hit an invariant subspace.
750 |         if `False`: iteration terminated without encountering
751 |         an invariant subspace.
752 |     """
753 | 
754 |     # Note (mganahl): currently unused, but is very convenient to have
755 |     # for further development and tests (it's usually more accurate than
756 |     # classical gs)
757 |     # Call signature:
758 |     # ```python
759 |     # initial_vals = [Av.ravel(), Vm, i, H]
760 |     # Av, Vm, _, H = jax.lax.fori_loop(
761 |     #     0, i + 1, modified_gram_schmidt_step_arnoldi, initial_vals)
762 |     # ```
763 |     def modified_gram_schmidt_step_arnoldi(j, vals):  # pylint: disable=unused-variable
764 |       """
765 |       Single step of a modified gram-schmidt orthogonalization.
766 |       Substantially more accurate than classical gram-schmidt.
767 |       Args:
768 |         j: Integer value denoting the vector to be orthogonalized.
769 |         vals: A list of variables:
770 |           `vector`: The current vector to be orthogonalized
771 |             to all previous ones
772 |           `Vm`: jax.array of collected krylov vectors
773 |           `n`: integer denoting the column-position of the overlap
774 |             <`krylov_vector`|`vector`> within `H`.
775 |       Returns:
776 |         updated vals.
777 | 
778 |       """
779 |       vector, krylov_vectors, n, H = vals
780 |       v = krylov_vectors[j, :]
781 |       h = jax.numpy.vdot(v, vector, precision=precision)
782 |       H = H.at[j, n].set(h)
783 |       vector = vector - h * v
784 |       return [vector, krylov_vectors, n, H]
785 | 
786 |     shape = v0.shape
787 |     Z = jax.numpy.linalg.norm(v0)
788 |     # only normalize if norm > tol, else return zero vector
789 |     v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None)
790 |     Vm = Vm.at[start, :].set(jax.numpy.ravel(v))
791 |     H = jax.lax.cond(
792 |         start > 0,
793 |         lambda x: H.at[x, x - 1].set(Z),
794 |         lambda x: H, start)
795 |     # body of the arnoldi iteration
796 |     def body(vals):
797 |       Vm, H, previous_vector, _, i = vals
798 |       Av = matvec(previous_vector, *args)
799 | 
800 |       Av, overlaps = iterative_classical_gram_schmidt(
801 |           Av.ravel(),
802 |           (i >= jax.numpy.arange(Vm.shape[0]))[:, None] *
803 |           Vm, precision)
804 |       H = H.at[:, i].set(overlaps)
805 |       norm = jax.numpy.linalg.norm(Av)
806 |       Av = jax.numpy.reshape(Av, shape)
807 | 
808 |       # only normalize if norm is larger than threshold,
809 |       # otherwise return zero vector
810 |       Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None)
811 |       Vm, H = jax.lax.cond(
812 |           i < num_krylov_vecs - 1,
813 |           lambda x: (Vm.at[i + 1, :].set(Av.ravel()), H.at[i + 1, i].set(norm)),  # pylint: disable=line-too-long
814 |           lambda x: (x[0], x[1]),
815 |           (Vm, H, Av, i, norm))
816 | 
817 |       return [Vm, H, Av, norm, i + 1]
818 | 
819 |     def cond_fun(vals):
820 |       # Continue loop while iteration < num_krylov_vecs and norm > tol
821 |       norm, iteration = vals[3], vals[4]
822 |       counter_done = (iteration >= num_krylov_vecs)
823 |       norm_not_too_small = norm > tol
824 |       continue_iteration = jax.lax.cond(counter_done, lambda x: False,
825 |                                         lambda x: norm_not_too_small, None)
826 |       return continue_iteration
827 | 
828 |     initial_values = [Vm, H, v, Z, start]
829 |     final_values = jax.lax.while_loop(cond_fun, body, initial_values)
830 |     Vm, H, residual, norm, it = final_values
831 |     return Vm, H, residual, norm, it, norm < tol
832 | 
833 |   return _arnoldi_fact
834 | 
835 | # ######################################################
836 | # #######   NEW SORTING FUNCTIONS INSERTED HERE  #########
837 | # ######################################################
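# Editorial note (illustrative, not part of the original file): each
# `_*_sort` helper below returns the full index permutation ordering the
# Ritz values by the requested criterion, plus the `p` *least* wanted
# values, which become the QR shifts to be projected out. For example,
# with evals = [1, -3, 2] and p = 1, `_LM_sort(jax)(1, evals)` orders the
# indices by magnitude as [1, 2, 0] and returns shifts = [1.0], the
# smallest-magnitude value.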
838 | def _LR_sort(jax):
839 |   @functools.partial(jax.jit, static_argnums=(0,))
840 |   def sorter(
841 |       p: int,
842 |       evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
843 |     inds = jax.numpy.argsort(jax.numpy.real(evals), stable=True)[::-1]
844 |     shifts = evals[inds][-p:]
845 |     return shifts, inds
846 |   return sorter
847 | 
848 | def _SA_sort(jax):
849 |   @functools.partial(jax.jit, static_argnums=(0,))
850 |   def sorter(
851 |       p: int,
852 |       evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
853 |     inds = jax.numpy.argsort(evals, stable=True)
854 |     shifts = evals[inds][-p:]
855 |     return shifts, inds
856 |   return sorter
857 | 
858 | def _LA_sort(jax):
859 |   @functools.partial(jax.jit, static_argnums=(0,))
860 |   def sorter(
861 |       p: int,
862 |       evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
863 |     inds = jax.numpy.argsort(evals, stable=True)[::-1]
864 |     shifts = evals[inds][-p:]
865 |     return shifts, inds
866 |   return sorter
867 | 
868 | def _LM_sort(jax):
869 |   @functools.partial(jax.jit, static_argnums=(0,))
870 |   def sorter(
871 |       p: int,
872 |       evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
873 |     inds = jax.numpy.argsort(jax.numpy.abs(evals), stable=True)[::-1]
874 |     shifts = evals[inds][-p:]
875 |     return shifts, inds
876 |   return sorter
877 | 
878 | # ####################################################
879 | # ####################################################
880 | 
881 | def _shifted_QR(jax):
882 |   @functools.partial(jax.jit, static_argnums=(4,))
883 |   def shifted_QR(
884 |       Vm: jax.Array, Hm: jax.Array, fm: jax.Array,
885 |       shifts: jax.Array,
886 |       numeig: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
887 |     # compress arnoldi factorization
888 |     q = jax.numpy.zeros(Hm.shape[0], dtype=Hm.dtype)
889 |     q = q.at[-1].set(1.0)
890 | 
891 |     def body(i, vals):
892 |       Vm, Hm, q = vals
893 |       shift = shifts[i] * jax.numpy.eye(Hm.shape[0], dtype=Hm.dtype)
894 |       Qj, R = jax.numpy.linalg.qr(Hm - shift)
895 |       Hm = R @ Qj + shift
896 |       Vm = Qj.T @ Vm
897 |       q = q @ Qj
898 |       return Vm, Hm, q
899 | 
900 |     Vm, Hm, q = jax.lax.fori_loop(0, shifts.shape[0], body,
901 |                                   (Vm, Hm, q))
902 |     fk = Vm[numeig, :] * Hm[numeig, numeig - 1] + fm * q[numeig - 1]
903 |     return Vm, Hm, fk
904 |   return shifted_QR
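# Editorial note (illustrative): `shifted_QR` performs the implicit
# restart. Each sweep `Hm - shift = Qj @ R; Hm <- R @ Qj + shift` is a
# similarity transform of the projected matrix that drives the unwanted
# Ritz value `shift` out of the leading block; accumulating `Qj` into the
# basis (`Vm <- Qj.T @ Vm`) keeps the factorization consistent, and `fk`
# is the residual of the compressed, `numeig`-step factorization.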
905 | 
906 | def _get_vectors(jax):
907 |   @functools.partial(jax.jit, static_argnums=(3,))
908 |   def get_vectors(Vm: jax.Array, unitary: jax.Array,
909 |                   inds: jax.Array, numeig: int) -> jax.Array:
910 | 
911 |     def body_vector(i, states):
912 |       dim = unitary.shape[1]
913 |       n, m = jax.numpy.divmod(i, dim)
914 |       states = states.at[n, :].set(states[n, :] + Vm[m, :] * unitary[m, inds[n]])
915 |       return states
916 | 
917 |     state_vectors = jax.numpy.zeros([numeig, Vm.shape[1]], dtype=Vm.dtype)
918 |     state_vectors = jax.lax.fori_loop(0, numeig * Vm.shape[0], body_vector,
919 |                                       state_vectors)
920 |     state_norms = jax.numpy.linalg.norm(state_vectors, axis=1)
921 |     state_vectors = state_vectors / state_norms[:, None]
922 |     return state_vectors
923 | 
924 |   return get_vectors
925 | 
926 | def _check_eigvals_convergence_eigh(jax):
927 |   @functools.partial(jax.jit, static_argnums=(3,))
928 |   def check_eigvals_convergence(beta_m: float, Hm: jax.Array,
929 |                                 Hm_norm: float,
930 |                                 tol: float) -> bool:
931 |     eigvals, eigvecs = jax.numpy.linalg.eigh(Hm)
932 |     # TODO (mganahl): confirm that this is a valid matrix norm
933 |     thresh = jax.numpy.maximum(
934 |         jax.numpy.finfo(eigvals.dtype).eps * Hm_norm,
935 |         jax.numpy.abs(eigvals) * tol)
936 |     vals = jax.numpy.abs(eigvecs[-1, :])
937 |     return jax.numpy.all(beta_m * vals < thresh)
938 | 
939 |   return check_eigvals_convergence
940 | 
941 | def _check_eigvals_convergence_eig(jax):
942 |   @functools.partial(jax.jit, static_argnums=(2, 3))
943 |   def check_eigvals_convergence(beta_m: float, Hm: jax.Array,
944 |                                 tol: float, numeig: int) -> bool:
945 |     eigvals, eigvecs = cpu_eig(Hm)
946 |     # TODO (mganahl): confirm that this is a valid matrix norm
947 |     Hm_norm = jax.numpy.linalg.norm(Hm)
948 |     thresh = jax.numpy.maximum(
949 |         jax.numpy.finfo(eigvals.dtype).eps * Hm_norm,
950 |         jax.numpy.abs(eigvals[:numeig]) * tol)
951 |     vals = jax.numpy.abs(eigvecs[numeig - 1, :numeig])
952 |     return jax.numpy.all(beta_m * vals < thresh)
953 | 
954 |   return check_eigvals_convergence
955 | 
956 | def _implicitly_restarted_arnoldi(jax: types.ModuleType) -> Callable:
957 |   """
958 |   Helper function to generate a jitted function to do an
959 |   implicitly restarted arnoldi factorization of `matvec`. The
960 |   returned routine finds the dominant `numeig`
961 |   eigenvector-eigenvalue pairs of `matvec`
962 |   by alternating between compression and re-expansion of an initial
963 |   `num_krylov_vecs`-step Arnoldi factorization.
964 | 
965 |   Note: The caller has to ensure that the dtype of the return value
966 |   of `matvec` matches the dtype of the initial state. Otherwise jax
967 |   will raise a TypeError.
968 | 
969 |   The returned function has the following call signature:
970 |   Args:
971 |     matvec: A callable representing the linear operator.
972 |     args: Arguments to `matvec`. `matvec` is called with
973 |       `matvec(x, *args)` with `x` the input array on which
974 |       `matvec` should act.
975 |     initial_state: A starting vector for the iteration.
976 |     num_krylov_vecs: Number of krylov vectors of the arnoldi factorization.
977 |     numeig: The number of desired eigenvector-eigenvalue pairs.
978 |     which: Which eigenvalues to target. Currently supported:
979 |       `which = 'LR'` and `which = 'LM'`.
980 |     tol: Convergence flag. If the norm of a krylov vector drops below `tol`
980 |       the iteration is terminated.
981 |     maxiter: Maximum number of (outer) iteration steps.
982 |   Returns:
983 |     eta, U: Two lists containing eigenvalues and eigenvectors.
984 | 
985 |   Args:
986 |     jax: The jax module.
987 |   Returns:
988 |     Callable: A function performing an implicitly restarted
989 |       Arnoldi factorization
990 |   """
991 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
992 | 
993 |   arnoldi_fact = _generate_arnoldi_factorization(jax)
994 | 
995 | 
996 |   @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8))
997 |   def implicitly_restarted_arnoldi_method(
998 |       matvec: Callable, args: List, initial_state: jax.Array,
999 |       num_krylov_vecs: int, numeig: int, which: Text, tol: float, maxiter: int,
1000 |       precision: JaxPrecisionType
1001 |   ) -> Tuple[jax.Array, List[jax.Array], int]:
1002 |     """
1003 |     Implicitly restarted arnoldi factorization of `matvec`. The routine
1004 |     finds the dominant `numeig` eigenvector-eigenvalue pairs of `matvec`
1005 |     by alternating between compression and re-expansion of an initial
1006 |     `num_krylov_vecs`-step Arnoldi factorization.
1007 | 
1008 |     Note: The caller has to ensure that the dtype of the return value
1009 |     of `matvec` matches the dtype of the initial state. Otherwise jax
1010 |     will raise a TypeError.
1011 | 
1012 |     NOTE: Under certain circumstances, the routine can return spurious
1013 |     eigenvalues 0.0: if the Arnoldi iteration terminated early
1014 |     (after numits < num_krylov_vecs iterations)
1015 |     and numeig > numits, then spurious 0.0 eigenvalues will be returned.
1016 | 
1017 |     Args:
1018 |       matvec: A callable representing the linear operator.
1019 |       args: Arguments to `matvec`. `matvec` is called with
1020 |         `matvec(x, *args)` with `x` the input array on which
1021 |         `matvec` should act.
1022 |       initial_state: A starting vector for the iteration.
1023 |       num_krylov_vecs: Number of Krylov vectors of the Arnoldi factorization.
1024 |       numeig: The number of desired eigenvector-eigenvalue pairs.
1025 |       which: Which eigenvalues to target.
1026 |         Currently supported: `which = 'LR'` (largest real part) and `which = 'LM'` (largest magnitude).
1027 |       tol: Convergence threshold. If the norm of a Krylov vector drops below
1028 |         `tol` the iteration is terminated.
1029 |       maxiter: Maximum number of (outer) iteration steps.
1030 |       precision: jax.lax.Precision used within lax operations.
1031 | 
1032 |     Returns:
1033 |       jax.Array: Eigenvalues
1034 |       List: Eigenvectors
1035 |       int: Number of inner Krylov iterations of the last Arnoldi
1036 |         factorization.
1037 |     """
1038 |     shape = initial_state.shape
1039 |     dtype = initial_state.dtype
1040 | 
1041 |     dim = np.prod(shape).astype(np.int32)
1042 |     num_expand = num_krylov_vecs - numeig
1043 |     if not numeig <= num_krylov_vecs <= dim:
1044 |       raise ValueError(f"num_krylov_vecs must satisfy numeig <="
1045 |                        f" num_krylov_vecs <= dim, got"
1046 |                        f" numeig = {numeig}, num_krylov_vecs = "
1047 |                        f"{num_krylov_vecs}, dim = {dim}.")
1048 |     if numeig > dim:
1049 |       raise ValueError(f"number of requested eigenvalues numeig = {numeig} "
1050 |                        f"is larger than the dimension of the operator "
1051 |                        f"dim = {dim}")
1052 | 
1053 |     # initialize arrays
1054 |     Vm = jax.numpy.zeros(
1055 |         (num_krylov_vecs, jax.numpy.ravel(initial_state).shape[0]), dtype=dtype)
1056 |     Hm = jax.numpy.zeros((num_krylov_vecs, num_krylov_vecs), dtype=dtype)
1057 |     # perform initial Arnoldi factorization
1058 |     Vm, Hm, residual, norm, numits, ar_converged = arnoldi_fact(
1059 |         matvec, args, initial_state, Vm, Hm, 0, num_krylov_vecs, tol, precision)
1060 |     fm = residual.ravel() * norm
1061 | 
1062 |     # generate needed functions
1063 |     shifted_QR = _shifted_QR(jax)
1064 |     check_eigvals_convergence = _check_eigvals_convergence_eig(jax)
1065 |     get_vectors = _get_vectors(jax)
1066 | 
1067 |     # sort_fun returns `num_expand` least relevant eigenvalues
1068 |     # (those to be projected out)
1069 |     if which == 'LR':
1070 |       sort_fun = jax.tree_util.Partial(_LR_sort(jax), num_expand)
1071 |     elif which == 'LM':
1072 |       sort_fun = jax.tree_util.Partial(_LM_sort(jax), num_expand)
1073 |     else:
1074 |       raise ValueError(f"which = {which} not implemented")
1075 | 
1076 |     it = 1  # we already did one Arnoldi factorization
1077 |     if maxiter > 1:
1078 |       # cast arrays to correct complex dtype
1079 |       if Vm.dtype == np.float64:
1080 |         dtype = np.complex128
1081 |       elif Vm.dtype == np.float32:
1082 |         dtype = np.complex64
1083 |       elif Vm.dtype == np.complex128:
1084 |         dtype = Vm.dtype
1085 |       elif Vm.dtype == np.complex64:
1086 |         dtype = Vm.dtype
1087 |       else:
1088 |         raise TypeError(f'dtype {Vm.dtype} not supported')
1089 | 
1090 |       Vm = Vm.astype(dtype)
1091 |       Hm = Hm.astype(dtype)
1092 |       fm = fm.astype(dtype)
1093 | 
1094 |     def outer_loop(carry):
1095 |       Hm, Vm, fm, it, numits, ar_converged, _, _ = carry
1096 |       evals, _ = cpu_eig(Hm)
1097 |       shifts, _ = sort_fun(evals)
1098 |       # perform shifted QR iterations to compress the Arnoldi factorization
1099 |       # Note that ||fk|| typically decreases as one iterates the outer loop,
1100 |       # indicating that IRAM converges.
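      # (Informally: each shifted QR sweep implicitly applies one factor of
      # the polynomial prod_i (A - shift_i) to the restart vector, filtering
      # out the unwanted Ritz directions; the leading numeig columns then
      # form a compressed Arnoldi factorization that is re-expanded below.)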
1101 |       # ||fk|| corresponds to \beta_m in the references cited in _implicitly_restarted_lanczos below
1102 |       Vk, Hk, fk = shifted_QR(Vm, Hm, fm, shifts, numeig)
1103 |       # reset matrices
1104 |       beta_k = jax.numpy.linalg.norm(fk)
1105 |       converged = check_eigvals_convergence(beta_k, Hk, tol, numeig)
1106 |       Vk = Vk.at[numeig:, :].set(0.0)
1107 |       Hk = Hk.at[numeig:, :].set(0.0)
1108 |       Hk = Hk.at[:, numeig:].set(0.0)
1109 |       def do_arnoldi(vals):
1110 |         Vk, Hk, fk, _, _, _, _ = vals
1111 |         # restart
1112 |         Vm, Hm, residual, norm, numits, ar_converged = arnoldi_fact(
1113 |             matvec, args, jax.numpy.reshape(fk, shape), Vk, Hk, numeig,
1114 |             num_krylov_vecs, tol, precision)
1115 |         fm = residual.ravel() * norm
1116 |         return [Vm, Hm, fm, norm, numits, ar_converged, False]
1117 | 
1118 |       def cond_arnoldi(vals):
1119 |         return vals[6]
1120 | 
1121 |       res = jax.lax.while_loop(cond_arnoldi, do_arnoldi, [
1122 |           Vk, Hk, fk,
1123 |           jax.numpy.linalg.norm(fk), numeig, False,
1124 |           jax.numpy.logical_not(converged)
1125 |       ])
1126 | 
1127 |       Vm, Hm, fm, norm, numits, ar_converged = res[0:6]
1128 |       out_vars = [
1129 |           Hm, Vm, fm, it + 1, numits, ar_converged, converged, norm
1130 |       ]
1131 |       return out_vars
1132 | 
1133 |     def cond_fun(carry):
1134 |       it, ar_converged, converged = (carry[3], carry[5],
1135 |                                      carry[6])
1136 |       return jax.lax.cond(
1137 |           it < maxiter, lambda x: x, lambda x: False,
1138 |           jax.numpy.logical_not(jax.numpy.logical_or(converged, ar_converged)))
1139 | 
1140 |     converged = False
1141 |     carry = [Hm, Vm, fm, it, numits, ar_converged, converged, norm]
1142 |     res = jax.lax.while_loop(cond_fun, outer_loop, carry)
1143 |     Hm, Vm = res[0], res[1]
1144 |     numits, converged = res[4], res[6]
1145 |     # if `converged`, `norm` is below the convergence threshold;
1146 |     # set the subdiagonal entry to 0.0 in this case to prevent
1147 |     # `jnp.linalg.eig` from finding a spurious eigenvalue of order `norm`.
1148 |     Hm = Hm.at[numits, numits - 1].set(
1149 |         jax.lax.cond(converged, lambda x: Hm.dtype.type(0.0), lambda x: x,
1150 |                      Hm[numits, numits - 1]))
1151 | 
1152 |     # if the Arnoldi factorization stopped early (after `numits` iterations),
1153 |     # before exhausting the allowed size of the Krylov subspace
1154 |     # (i.e. `numits` < `num_krylov_vecs`), set elements
1155 |     # at positions m, n with m, n >= `numits` to 0.0.
1156 | 
1157 |     # FIXME (mganahl): under certain circumstances, the routine can still
1158 |     # return spurious 0 eigenvalues: if Arnoldi terminated early
1159 |     # (after numits < num_krylov_vecs iterations)
1160 |     # and numeig > numits, then spurious 0.0 eigenvalues will be returned
1161 | 
1162 |     Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * (
1163 |         numits > jax.numpy.arange(num_krylov_vecs))[None, :]
1164 |     eigvals, U = cpu_eig(Hm)
1165 |     inds = sort_fun(eigvals)[1][:numeig]
1166 |     vectors = get_vectors(Vm, U, inds, numeig)
1167 |     return eigvals[inds], [
1168 |         jax.numpy.reshape(vectors[n, :], shape)
1169 |         for n in range(numeig)
1170 |     ], numits
1171 | 
1172 |   return implicitly_restarted_arnoldi_method
1173 | 
1174 | 
1175 | def _implicitly_restarted_lanczos(jax: types.ModuleType) -> Callable:
1176 |   """
1177 |   Helper function to generate a jitted function that performs an
1178 |   implicitly restarted Lanczos factorization of `matvec`. The
1179 |   returned routine finds the `numeig` eigenvector-eigenvalue
1180 |   pairs of `matvec` selected by `which`
1181 |   by alternating between compression and re-expansion of an initial
1182 |   `num_krylov_vecs`-step Lanczos factorization.
1183 | 
1184 |   Note: The caller has to ensure that the dtype of the return value
1185 |   of `matvec` matches the dtype of the initial state. Otherwise jax
1186 |   will raise a TypeError.
1187 | 
1188 |   The function signature of the returned function is:
1189 |   Args:
1190 |     matvec: A callable representing the linear operator.
1191 |     args: Arguments to `matvec`. `matvec` is called with
1192 |       `matvec(x, *args)` with `x` the input array on which
1193 |       `matvec` should act.
1194 |     initial_state: A starting vector for the iteration.
1195 |     num_krylov_vecs: Number of Krylov vectors of the Lanczos factorization.
1196 |     numeig: The number of desired eigenvector-eigenvalue pairs.
1197 |     which: Which eigenvalues to target. Currently supported: `which = 'LA'`
1198 |       (largest algebraic), `which = 'SA'` (smallest algebraic) or `which = 'LM'` (largest magnitude).
1199 |     tol: Convergence threshold. If the norm of a Krylov vector drops below
1200 |       `tol` the iteration is terminated.
1201 |     maxiter: Maximum number of (outer) iteration steps.
1202 |   Returns:
1203 |     eta, U: Two lists containing eigenvalues and eigenvectors.
1204 | 
1205 |   Args:
1206 |     jax: The jax module.
1207 |   Returns:
1208 |     Callable: A function performing an implicitly restarted
1209 |       Lanczos factorization
1210 |   """
1211 |   JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
1212 |   lanczos_fact = _generate_lanczos_factorization(jax)
1213 | 
1214 |   @functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8))
1215 |   def implicitly_restarted_lanczos_method(
1216 |       matvec: Callable, args: List, initial_state: jax.Array,
1217 |       num_krylov_vecs: int, numeig: int, which: Text, tol: float, maxiter: int,
1218 |       precision: JaxPrecisionType
1219 |   ) -> Tuple[jax.Array, List[jax.Array], int]:
1220 |     """
1221 |     Implicitly restarted Lanczos factorization of `matvec`. The routine
1222 |     finds the `numeig` eigenvector-eigenvalue pairs of `matvec` selected by `which`
1223 |     by alternating between compression and re-expansion of an initial
1224 |     `num_krylov_vecs`-step Lanczos factorization.
1225 | 
1226 |     Note: The caller has to ensure that the dtype of the return value
1227 |     of `matvec` matches the dtype of the initial state. Otherwise jax
1228 |     will raise a TypeError.
1229 | 
1230 |     NOTE: Under certain circumstances, the routine can return spurious
1231 |     eigenvalues 0.0: if the Lanczos iteration terminated early
1232 |     (after numits < num_krylov_vecs iterations)
1233 |     and numeig > numits, then spurious 0.0 eigenvalues will be returned.
1234 | 
1235 |     References:
1236 |       http://emis.impa.br/EMIS/journals/ETNA/vol.2.1994/pp1-21.dir/pp1-21.pdf
1237 |       http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter11.pdf
1238 | 
1239 |     Args:
1240 |       matvec: A callable representing the linear operator.
1241 |       args: Arguments to `matvec`. `matvec` is called with
1242 |         `matvec(x, *args)` with `x` the input array on which
1243 |         `matvec` should act.
1244 |       initial_state: A starting vector for the iteration.
1245 |       num_krylov_vecs: Number of Krylov vectors of the Lanczos factorization.
1246 |       numeig: The number of desired eigenvector-eigenvalue pairs.
1247 |       which: Which eigenvalues to target. Currently supported:
1248 |         `which = 'LA'`, `which = 'SA'` or `which = 'LM'`.
1249 |       tol: Convergence threshold. If the norm of a Krylov vector drops below
1250 |         `tol` the iteration is terminated.
1251 |       maxiter: Maximum number of (outer) iteration steps.
1252 |       precision: jax.lax.Precision used within lax operations.
1253 | 
1254 |     Returns:
1255 |       jax.Array: Eigenvalues
1256 |       List: Eigenvectors
1257 |       int: Number of inner Krylov iterations of the last Lanczos
1258 |         factorization.
1259 |     """
1259 | """ 1260 | shape = initial_state.shape 1261 | dtype = initial_state.dtype 1262 | 1263 | dim = np.prod(shape).astype(np.int32) 1264 | num_expand = num_krylov_vecs - numeig 1265 | #note: the second part of the cond is for testing purposes 1266 | if num_krylov_vecs <= numeig < dim: 1267 | raise ValueError(f"num_krylov_vecs must be between numeig <" 1268 | f" num_krylov_vecs <= dim = {dim}," 1269 | f" num_krylov_vecs = {num_krylov_vecs}") 1270 | if numeig > dim: 1271 | raise ValueError(f"number of requested eigenvalues numeig = {numeig} " 1272 | f"is larger than the dimension of the operator " 1273 | f"dim = {dim}") 1274 | 1275 | # initialize arrays 1276 | Vm = jax.numpy.zeros( 1277 | (num_krylov_vecs, jax.numpy.ravel(initial_state).shape[0]), dtype=dtype) 1278 | alphas = jax.numpy.zeros(num_krylov_vecs, dtype=dtype) 1279 | betas = jax.numpy.zeros(num_krylov_vecs - 1, dtype=dtype) 1280 | 1281 | # perform initial lanczos factorization 1282 | Vm, alphas, betas, residual, norm, numits, ar_converged = lanczos_fact( 1283 | matvec, args, initial_state, Vm, alphas, betas, 0, num_krylov_vecs, tol, 1284 | precision) 1285 | fm = residual.ravel() * norm 1286 | # generate needed functions 1287 | shifted_QR = _shifted_QR(jax) 1288 | check_eigvals_convergence = _check_eigvals_convergence_eigh(jax) 1289 | get_vectors = _get_vectors(jax) 1290 | 1291 | # sort_fun returns `num_expand` least relevant eigenvalues 1292 | # (those to be projected out) 1293 | if which == 'LA': 1294 | sort_fun = jax.tree_util.Partial(_LA_sort(jax), num_expand) 1295 | elif which == 'SA': 1296 | sort_fun = jax.tree_util.Partial(_SA_sort(jax), num_expand) 1297 | elif which == 'LM': 1298 | sort_fun = jax.tree_util.Partial(_LM_sort(jax), num_expand) 1299 | else: 1300 | raise ValueError(f"which = {which} not implemented") 1301 | 1302 | it = 1 # we already did one lanczos factorization 1303 | def outer_loop(carry): 1304 | alphas, betas, Vm, fm, it, numits, ar_converged, _, _, = carry 1305 | # pack into alphas and betas into tridiagonal matrix 1306 | Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( 1307 | betas.conj(), 1) 1308 | evals, _ = jax.numpy.linalg.eigh(Hm) 1309 | shifts, _ = sort_fun(evals) 1310 | # perform shifted QR iterations to compress lanczos factorization 1311 | # Note that ||fk|| typically decreases as one iterates the outer loop 1312 | # indicating that iram converges. 
1313 | # ||fk|| = \beta_m in reference above 1314 | Vk, Hk, fk = shifted_QR(Vm, Hm, fm, shifts, numeig) 1315 | # extract new alphas and betas 1316 | alphas = jax.numpy.diag(Hk) 1317 | betas = jax.numpy.diag(Hk, -1) 1318 | alphas = alphas.at[numeig:].set(0.0) 1319 | betas = betas.at[numeig-1:].set(0.0) 1320 | 1321 | beta_k = jax.numpy.linalg.norm(fk) 1322 | Hktest = Hk[:numeig, :numeig] 1323 | matnorm = jax.numpy.linalg.norm(Hktest) 1324 | converged = check_eigvals_convergence(beta_k, Hktest, matnorm, tol) 1325 | 1326 | 1327 | def do_lanczos(vals): 1328 | Vk, alphas, betas, fk, _, _, _, _ = vals 1329 | # restart 1330 | Vm, alphas, betas, residual, norm, numits, ar_converged = lanczos_fact( 1331 | matvec, args, jax.numpy.reshape(fk, shape), Vk, alphas, betas, 1332 | numeig, num_krylov_vecs, tol, precision) 1333 | fm = residual.ravel() * norm 1334 | return [Vm, alphas, betas, fm, norm, numits, ar_converged, False] 1335 | 1336 | def cond_lanczos(vals): 1337 | return vals[7] 1338 | 1339 | res = jax.lax.while_loop(cond_lanczos, do_lanczos, [ 1340 | Vk, alphas, betas, fk, 1341 | jax.numpy.linalg.norm(fk), numeig, False, 1342 | jax.numpy.logical_not(converged) 1343 | ]) 1344 | 1345 | Vm, alphas, betas, fm, norm, numits, ar_converged = res[0:7] 1346 | 1347 | out_vars = [ 1348 | alphas, betas, Vm, fm, it + 1, numits, ar_converged, converged, norm 1349 | ] 1350 | return out_vars 1351 | 1352 | def cond_fun(carry): 1353 | it, ar_converged, converged = carry[4], carry[6], carry[7] 1354 | return jax.lax.cond( 1355 | it < maxiter, lambda x: x, lambda x: False, 1356 | jax.numpy.logical_not(jax.numpy.logical_or(converged, ar_converged))) 1357 | 1358 | converged = False 1359 | carry = [alphas, betas, Vm, fm, it, numits, ar_converged, converged, norm] 1360 | res = jax.lax.while_loop(cond_fun, outer_loop, carry) 1361 | alphas, betas, Vm = res[0], res[1], res[2] 1362 | numits, ar_converged, converged = res[5], res[6], res[7] 1363 | Hm = jax.numpy.diag(alphas) + jax.numpy.diag(betas, -1) + jax.numpy.diag( 1364 | betas.conj(), 1) 1365 | # FIXME (mganahl): under certain circumstances, the routine can still 1366 | # return spurious 0 eigenvalues: if lanczos terminated early 1367 | # (after numits < num_krylov_vecs iterations) 1368 | # and numeig > numits, then spurious 0.0 eigenvalues will be returned 1369 | Hm = (numits > jax.numpy.arange(num_krylov_vecs))[:, None] * Hm * ( 1370 | numits > jax.numpy.arange(num_krylov_vecs))[None, :] 1371 | 1372 | eigvals, U = jax.numpy.linalg.eigh(Hm) 1373 | inds = sort_fun(eigvals)[1][:numeig] 1374 | vectors = get_vectors(Vm, U, inds, numeig) 1375 | return eigvals[inds], [ 1376 | jax.numpy.reshape(vectors[n, :], shape) for n in range(numeig) 1377 | ], numits 1378 | 1379 | return implicitly_restarted_lanczos_method 1380 | 1381 | 1382 | def gmres_wrapper(jax: types.ModuleType): 1383 | """ 1384 | Allows Jax (the module) to be passed in as an argument rather than imported, 1385 | since doing the latter breaks the build. In addition, instantiates certain 1386 | of the enclosed functions as concrete objects within a Dict, allowing them to 1387 | be cached. This avoids spurious recompilations that would otherwise be 1388 | triggered by attempts to pass callables into Jitted functions. 1389 | 1390 | The important function here is functions["gmres_m"], which implements 1391 | GMRES. The other functions are exposed only for testing. 1392 | 1393 | Args: 1394 | ---- 1395 | jax: The imported Jax module. 
1396 | 1397 | Returns: 1398 | ------- 1399 | functions: A namedtuple of functions: 1400 | functions.gmres_m = gmres_m 1401 | functions.gmres_residual = gmres_residual 1402 | functions.gmres_krylov = gmres_krylov 1403 | functions.gs_step = _gs_step 1404 | functions.kth_arnoldi_step = kth_arnoldi_step 1405 | functions.givens_rotation = givens_rotation 1406 | """ 1407 | jnp = jax.numpy 1408 | JaxPrecisionType = type(jax.lax.Precision.DEFAULT) 1409 | def gmres_m( 1410 | A_mv: Callable, A_args: Sequence, b: jax.Array, x0: jax.Array, 1411 | tol: float, atol: float, num_krylov_vectors: int, maxiter: int, 1412 | precision: JaxPrecisionType) -> Tuple[jax.Array, float, int, bool]: 1413 | """ 1414 | Solve A x = b for x using the m-restarted GMRES method. This is 1415 | intended to be called via jax_backend.gmres. 1416 | 1417 | Given a linear mapping with (n x n) matrix representation 1418 | A = A_mv(*A_args) gmres_m solves 1419 | Ax = b (1) 1420 | where x and b are length-n vectors, using the method of 1421 | Generalized Minimum RESiduals with M iterations per restart (GMRES_M). 1422 | 1423 | Args: 1424 | A_mv: A function v0 = A_mv(v, *A_args) where v0 and v have the same shape. 1425 | A_args: A list of positional arguments to A_mv. 1426 | b: The b in A @ x = b. 1427 | x0: Initial guess solution. 1428 | tol, atol: Solution tolerance to achieve, 1429 | norm(residual) <= max(tol * norm(b), atol). 1430 | tol is also used to set the threshold at which the Arnoldi factorization 1431 | terminates. 1432 | num_krylov_vectors: Size of the Krylov space to build at each restart. 1433 | maxiter: The Krylov space will be repeatedly rebuilt up to this many 1434 | times. 1435 | Returns: 1436 | x: The approximate solution. 1437 | beta: Norm of the residual at termination. 1438 | n_iter: Number of iterations at termination. 1439 | converged: Whether the desired tolerance was achieved. 1440 | """ 1441 | num_krylov_vectors = min(num_krylov_vectors, b.size) 1442 | x = x0 1443 | b_norm = jnp.linalg.norm(b) 1444 | tol = max(tol * b_norm, atol) 1445 | for n_iter in range(maxiter): 1446 | done, beta, x = gmres(A_mv, A_args, b, x, num_krylov_vectors, x0, tol, 1447 | b_norm, precision) 1448 | if done: 1449 | break 1450 | return x, beta, n_iter, done 1451 | 1452 | def gmres(A_mv: Callable, A_args: Sequence, b: jax.Array, 1453 | x: jax.Array, num_krylov_vectors: int, x0: jax.Array, 1454 | tol: float, b_norm: float, 1455 | precision: JaxPrecisionType) -> Tuple[bool, float, jax.Array]: 1456 | """ 1457 | A single restart of GMRES. 1458 | 1459 | Args: 1460 | A_mv: A function `v0 = A_mv(v, *A_args)` where `v0` and 1461 | `v` have the same shape. 1462 | A_args: A list of positional arguments to A_mv. 1463 | b: The `b` in `A @ x = b`. 1464 | x: Initial guess solution. 1465 | tol: Solution tolerance to achieve, 1466 | num_krylov_vectors : Size of the Krylov space to build. 1467 | Returns: 1468 | done: Whether convergence was achieved. 1469 | beta: Magnitude of residual (i.e. the error estimate). 1470 | x: The approximate solution. 
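      Note: `done` is derived from the loop counter below: exiting with
      k < num_krylov_vectors - 1 means the Krylov build stopped because the
      error estimate dropped below `tol`, not because the space was exhausted.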
1471 | """ 1472 | r, beta = gmres_residual(A_mv, A_args, b, x) 1473 | k, V, R, beta_vec = gmres_krylov(A_mv, A_args, num_krylov_vectors, 1474 | x0, r, beta, tol, b_norm, precision) 1475 | x = gmres_update(k, V, R, beta_vec, x0) 1476 | done = k < num_krylov_vectors - 1 1477 | return done, beta, x 1478 | 1479 | @jax.jit 1480 | def gmres_residual(A_mv: Callable, A_args: Sequence, b: jax.Array, 1481 | x: jax.Array) -> Tuple[jax.Array, float]: 1482 | """ 1483 | Computes the residual vector r and its norm, beta, which is minimized by 1484 | GMRES. 1485 | 1486 | Args: 1487 | A_mv: A function v0 = A_mv(v, *A_args) where v0 and 1488 | v have the same shape. 1489 | A_args: A list of positional arguments to A_mv. 1490 | b: The b in A @ x = b. 1491 | x: Initial guess solution. 1492 | Returns: 1493 | r: The residual vector. 1494 | beta: Its magnitude. 1495 | """ 1496 | r = b - A_mv(x, *A_args) 1497 | beta = jnp.linalg.norm(r) 1498 | return r, beta 1499 | 1500 | def gmres_update(k: int, V: jax.Array, R: jax.Array, 1501 | beta_vec: jax.Array, 1502 | x0: jax.Array) -> jax.Array: 1503 | """ 1504 | Updates the solution in response to the information computed by the 1505 | main GMRES loop. 1506 | 1507 | Args: 1508 | k: The final iteration which was reached by GMRES before convergence. 1509 | V: The Arnoldi matrix of Krylov vectors. 1510 | R: The R factor in H = QR where H is the Arnoldi overlap matrix. 1511 | beta_vec: Stores the Givens factors used to map H into QR. 1512 | x0: The initial guess solution. 1513 | Returns: 1514 | x: The updated solution. 1515 | """ 1516 | q = min(k, R.shape[1]) 1517 | y = jax.scipy.linalg.solve_triangular(R[:q, :q], beta_vec[:q]) 1518 | x = x0 + V[:, :q] @ y 1519 | return x 1520 | 1521 | @functools.partial(jax.jit, static_argnums=(2, 8)) 1522 | def gmres_krylov( 1523 | A_mv: Callable, A_args: Sequence, n_kry: int, x0: jax.Array, 1524 | r: jax.Array, beta: float, tol: float, b_norm: float, 1525 | precision: JaxPrecisionType 1526 | ) -> Tuple[int, jax.Array, jax.Array, jax.Array]: 1527 | """ 1528 | Builds the Arnoldi decomposition of (A, v), where v is the normalized 1529 | residual of the current solution estimate. The decomposition is 1530 | returned as V, R, where V is the usual matrix of Krylov vectors and 1531 | R is the upper triangular matrix in H = QR, with H the usual matrix 1532 | of overlaps. 1533 | 1534 | Args: 1535 | A_mv: A function `v0 = A_mv(v, *A_args)` where `v0` and 1536 | `v` have the same shape. 1537 | A_args: A list of positional arguments to A_mv. 1538 | n_kry: Size of the Krylov space to build; this is called 1539 | num_krylov_vectors in higher level code. 1540 | x0: Guess solution. 1541 | r: Residual vector. 1542 | beta: Magnitude of r. 1543 | tol: Solution tolerance to achieve. 1544 | b_norm: Magnitude of b in Ax = b. 1545 | Returns: 1546 | k: Counts the number of iterations before convergence. 1547 | V: The Arnoldi matrix of Krylov vectors. 1548 | R: From H = QR where H is the Arnoldi matrix of overlaps. 1549 | beta_vec: Stores Q implicitly as Givens factors. 1550 | """ 1551 | n = r.size 1552 | err = beta 1553 | v = r / beta 1554 | 1555 | # These will store the Givens rotations used to update the QR decompositions 1556 | # of the Arnoldi matrices. 
1557 |       # cos : givens[0, :]
1558 |       # sine: givens[1, :]
1559 |       givens = jnp.zeros((2, n_kry), dtype=x0.dtype)
1560 |       beta_vec = jnp.zeros((n_kry + 1), dtype=x0.dtype)
1561 |       beta_vec = beta_vec.at[0].set(beta)
1562 |       V = jnp.zeros((n, n_kry + 1), dtype=x0.dtype)
1563 |       V = V.at[:, 0].set(v)
1564 |       R = jnp.zeros((n_kry + 1, n_kry), dtype=x0.dtype)
1565 | 
1566 |       # The variable data for the carry. Each iteration modifies these
1567 |       # values and feeds the results to the next iteration.
1568 |       k = 0
1569 |       gmres_variables = (k, V, R, beta_vec, err,  # < The actual output we need.
1570 |                          givens)  # < Modified between iterations.
1571 |       gmres_constants = (tol, A_mv, A_args, b_norm, n_kry)
1572 |       gmres_carry = (gmres_variables, gmres_constants)
1573 |       # The carry is threaded through jax.lax.while_loop below; the constants
1574 |       # ride along unchanged while the variables are updated on every
1575 |       # iteration.
1576 | 
1577 |       def gmres_krylov_work(gmres_carry: GmresCarryType) -> GmresCarryType:
1578 |         """
1579 |         Performs a single iteration of gmres_krylov. See that function for a more
1580 |         detailed description.
1581 | 
1582 |         Args:
1583 |           gmres_carry: The gmres_carry from gmres_krylov.
1584 |         Returns:
1585 |           gmres_carry: The updated gmres_carry.
1586 |         """
1587 |         gmres_variables, gmres_constants = gmres_carry
1588 |         k, V, R, beta_vec, err, givens = gmres_variables
1589 |         tol, A_mv, A_args, b_norm, _ = gmres_constants
1590 | 
1591 |         V, H = kth_arnoldi_step(k, A_mv, A_args, V, R, tol, precision)
1592 |         R_col, givens = apply_givens_rotation(H[:, k], givens, k)
1593 |         R = R.at[:, k].set(R_col[:])
1594 | 
1595 |         # Update the residual vector.
1596 |         cs, sn = givens[:, k] * beta_vec[k]
1597 |         beta_vec = beta_vec.at[k].set(cs)
1598 |         beta_vec = beta_vec.at[k + 1].set(sn)
1599 |         err = jnp.abs(sn) / b_norm
1600 |         gmres_variables = (k + 1, V, R, beta_vec, err, givens)
1601 |         return (gmres_variables, gmres_constants)
1602 | 
1603 |       def gmres_krylov_loop_condition(gmres_carry: GmresCarryType) -> bool:
1604 |         """
1605 |         This function dictates whether the main GMRES while loop will proceed.
1606 |         It is equivalent to:
1607 |           if k < n_kry and err > tol:
1608 |             return True
1609 |           else:
1610 |             return False
1611 |         where k, n_kry, err, and tol are unpacked from gmres_carry.
1612 | 
1613 |         Args:
1614 |           gmres_carry: The gmres_carry from gmres_krylov.
1615 |         Returns:
1616 |           (bool): Whether to continue iterating.
1617 |         """
1618 |         gmres_variables, gmres_constants = gmres_carry
1619 |         tol = gmres_constants[0]
1620 |         k = gmres_variables[0]
1621 |         err = gmres_variables[4]
1622 |         n_kry = gmres_constants[4]
1623 | 
1624 |         def is_iterating(k, n_kry):
1625 |           return k < n_kry
1626 | 
1627 |         def not_converged(args):
1628 |           err, tol = args
1629 |           return err >= tol
1630 |         return jax.lax.cond(is_iterating(k, n_kry),  # Predicate.
1631 |                             not_converged,  # Called if True.
1632 |                             lambda x: False,  # Called if False.
1633 |                             (err, tol))  # Arguments to calls.
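      # The while_loop below iterates gmres_krylov_work until either the
      # Krylov space is exhausted (k == n_kry) or the scaled residual err
      # drops below tol; only (k, V, R, beta_vec) are returned to the caller.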
1634 | 
1635 |       gmres_carry = jax.lax.while_loop(gmres_krylov_loop_condition,
1636 |                                        gmres_krylov_work,
1637 |                                        gmres_carry)
1638 |       gmres_variables, gmres_constants = gmres_carry
1639 |       k, V, R, beta_vec, err, givens = gmres_variables
1640 |       return (k, V, R, beta_vec)
1641 | 
1642 |     VarType = Tuple[int, jax.Array, jax.Array, jax.Array,
1643 |                     float, jax.Array]
1644 |     ConstType = Tuple[float, Callable, Sequence, jax.Array, int]
1645 |     GmresCarryType = Tuple[VarType, ConstType]
1646 | 
1647 | 
1648 |     @functools.partial(jax.jit, static_argnums=(6,))
1649 |     def kth_arnoldi_step(
1650 |         k: int, A_mv: Callable, A_args: Sequence, V: jax.Array,
1651 |         H: jax.Array, tol: float,
1652 |         precision: JaxPrecisionType) -> Tuple[jax.Array, jax.Array]:
1653 |       """
1654 |       Performs the kth iteration of the Arnoldi reduction procedure.
1655 |       Args:
1656 |         k: The current iteration.
1657 |         A_mv, A_args: A function A_mv(v, *A_args) performing a linear
1658 |           transformation on v.
1659 |         V: A matrix of size (n, K + 1), K > k, such that the columns of
1660 |           V[:, :k+1] store Krylov vectors and V[:, k+1] is all zeroes.
1661 |         H: A matrix of size (K + 1, K), K > k, with H[:, k] all zeroes.
1662 |       Returns:
1663 |         V, H: With their k'th columns respectively filled in by a new
1664 |           orthogonalized Krylov vector and new overlaps.
1665 |       """
1666 | 
1667 |       def _gs_step(
1668 |           r: jax.Array,
1669 |           v_i: jax.Array) -> Tuple[jax.Array, jax.Array]:
1670 |         """
1671 |         Performs one iteration of the stabilized Gram-Schmidt procedure, with
1672 |         r to be orthogonalized against {v} = {v_0, v_1, ...}.
1673 | 
1674 |         Args:
1675 |           r: The new vector which is not in the initially orthonormal set.
1676 |           v_i: The i'th vector in that set.
1677 |         Returns:
1678 |           r_i: The updated r which is now orthogonal to v_i.
1679 |           h_i: The overlap of r with v_i.
1680 |         """
1681 |         h_i = jnp.vdot(v_i, r, precision=precision)
1682 |         r_i = r - h_i * v_i
1683 |         return r_i, h_i
1684 | 
1685 |       v = A_mv(V[:, k], *A_args)
1686 |       v_new, H_k = jax.lax.scan(_gs_step, init=v, xs=V.T)
1687 |       v_norm = jnp.linalg.norm(v_new)
1688 |       # Normalize v_new unless it is (numerically) the zero vector; the
1689 |       # division happens inside the cond to avoid dividing by a tiny norm.
1690 |       r_new = jax.lax.cond(v_norm > tol,
1691 |                            lambda x: x[0] / x[1],
1692 |                            lambda x: 0. * x[0],
1693 |                            (v_new, v_norm)
1694 |                            )
1695 |       H = H.at[:, k].set(H_k)
1696 |       H = H.at[k + 1, k].set(v_norm)
1697 |       V = V.at[:, k + 1].set(r_new)
1698 |       return V, H
1699 | 
1700 |     ####################################################################
1701 |     # GIVENS ROTATIONS
1702 |     ####################################################################
1703 |     @jax.jit
1704 |     def apply_rotations(H_col: jax.Array, givens: jax.Array,
1705 |                         k: int) -> jax.Array:
1706 |       """
1707 |       Successively applies each of the rotations stored in givens to H_col.
1708 | 
1709 |       Args:
1710 |         H_col : The vector to be rotated.
1711 |         givens: 2 x K, K > k matrix of rotation factors.
1712 |         k : Iteration number.
1713 |       Returns:
1714 |         H_col : The rotated vector.
1715 | 
1715 | """ 1716 | rotation_carry = (H_col, 0, k, givens) 1717 | 1718 | def loop_condition(carry): 1719 | i = carry[1] 1720 | k = carry[2] 1721 | return jax.lax.cond(i < k, lambda x: True, lambda x: False, 0) 1722 | 1723 | def apply_ith_rotation(carry): 1724 | H_col, i, k, givens = carry 1725 | cs = givens[0, i] 1726 | sn = givens[1, i] 1727 | H_i = cs * H_col[i] - sn * H_col[i + 1] 1728 | H_ip1 = sn * H_col[i] + cs * H_col[i + 1] 1729 | H_col = H_col.at[i].set(H_i) 1730 | H_col = H_col.at[i + 1].set(H_ip1) 1731 | return (H_col, i + 1, k, givens) 1732 | 1733 | rotation_carry = jax.lax.while_loop(loop_condition, 1734 | apply_ith_rotation, 1735 | rotation_carry) 1736 | H_col = rotation_carry[0] 1737 | return H_col 1738 | 1739 | @jax.jit 1740 | def apply_givens_rotation(H_col: jax.Array, givens: jax.Array, 1741 | k: int) -> Tuple[jax.Array, jax.Array]: 1742 | """ 1743 | Applies the Givens rotations stored in the vectors cs and sn to the vector 1744 | H_col. Then constructs a new Givens rotation that eliminates H_col's 1745 | k'th element, yielding the corresponding column of the R in H's QR 1746 | decomposition. Returns the new column of R along with the new Givens 1747 | factors. 1748 | 1749 | Args: 1750 | H_col : The column of H to be rotated. 1751 | givens: A matrix representing the cosine and sine factors of the 1752 | previous GMRES Givens rotations, in that order 1753 | (i.e. givens[0, :] -> the cos factor). 1754 | k : Iteration number. 1755 | Returns: 1756 | R_col : The column of R obtained by transforming H_col. 1757 | givens_k: The new elements of givens that zeroed out the k+1'th element 1758 | of H_col. 1759 | """ 1760 | # This call successively applies each of the 1761 | # Givens rotations stored in givens[:, :k] to H_col. 1762 | H_col = apply_rotations(H_col, givens, k) 1763 | 1764 | cs_k, sn_k = givens_rotation(H_col[k], H_col[k + 1]) 1765 | givens = givens.at[0, k].set(cs_k) 1766 | givens = givens.at[1, k].set(sn_k) 1767 | 1768 | r_k = cs_k * H_col[k] - sn_k * H_col[k + 1] 1769 | R_col = H_col.at[k].set(r_k) 1770 | R_col = R_col.at[k + 1].set(0.) 1771 | return R_col, givens 1772 | 1773 | @jax.jit 1774 | def givens_rotation(v1: float, v2: float) -> Tuple[float, float]: 1775 | """ 1776 | Given scalars v1 and v2, computes cs = cos(theta) and sn = sin(theta) 1777 | so that [cs -sn] @ [v1] = [r] 1778 | [sn cs] [v2] [0] 1779 | Args: 1780 | v1, v2: The scalars. 1781 | Returns: 1782 | cs, sn: The rotation factors. 
1783 | """ 1784 | t = jnp.sqrt(v1**2 + v2**2) 1785 | cs = v1 / t 1786 | sn = -v2 / t 1787 | return cs, sn 1788 | 1789 | fnames = [ 1790 | "gmres_m", "gmres_residual", "gmres_krylov", 1791 | "kth_arnoldi_step", "givens_rotation" 1792 | ] 1793 | functions = [ 1794 | gmres_m, gmres_residual, gmres_krylov, kth_arnoldi_step, 1795 | givens_rotation 1796 | ] 1797 | 1798 | class Functions: 1799 | 1800 | def __init__(self, fun_dict): 1801 | self.dict = fun_dict 1802 | 1803 | def __getattr__(self, name): 1804 | return self.dict[name] 1805 | 1806 | return Functions(dict(zip(fnames, functions))) 1807 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jaxtn.solver.jaxeigs import eigs, eigsh 4 | from jax import config 5 | config.update("jax_enable_x64", True) 6 | 7 | if __name__ == "__main__": 8 | m = 10 9 | A = jax.random.uniform(jax.random.PRNGKey(42),(m,m)) 10 | b = jax.random.uniform(jax.random.PRNGKey(41),(m,)) 11 | def mapA(x): return A@x 12 | res = eigs(mapA, initial_state = b, numeig=1, num_krylov_vecs = 5) 13 | print(res[0],res[1][0]) 14 | 15 | A @ res[1][0] / res[1][0] 16 | 17 | 18 | m = 10 19 | A = jax.random.uniform(jax.random.PRNGKey(42),(m,m)) 20 | b = jax.random.uniform(jax.random.PRNGKey(41),(m,)) 21 | def mapA(x): return (A+A.T.conj())@x 22 | res = eigsh(mapA, initial_state = b, numeig=1, num_krylov_vecs = 5) 23 | print(res[0],res[1][0]) 24 | 25 | (A+A.T.conj()) @ res[1][0] / res[1][0] 26 | 27 | --------------------------------------------------------------------------------