├── .github └── workflows │ └── deploy_docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── _overrides │ └── partials │ │ └── source.html ├── _static │ ├── custom_css.css │ ├── favicon.png │ └── mathjax.js ├── api │ ├── estimators.md │ └── samplers.md ├── authors.md ├── conduct.md ├── index.md ├── license.md ├── requirements.txt ├── scripts │ └── extension.py └── templates │ └── python │ └── material │ ├── class.html.jinja │ └── signature.html.jinja ├── mkdocs.yml ├── pyproject.toml ├── src ├── py.typed └── traceax │ ├── __init__.py │ ├── _estimators.py │ └── _samplers.py └── tests ├── __init__.py ├── conftest.py ├── helpers.py └── test_trace.py /.github/workflows/deploy_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | # These permissions are required to deploy to GitHub Pages 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | jobs: 15 | # build docs and upload 16 | build: 17 | strategy: 18 | matrix: 19 | python-version: [ 3.11 ] 20 | os: [ ubuntu-latest ] 21 | runs-on: ${{ matrix.os }} 22 | steps: 23 | - name: Checkout code 24 | uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | python -m pip install '.[docs]' 35 | 36 | - name: Build docs 37 | run: | 38 | mkdocs build 39 | 40 | - name: Configure GitHub Pages 41 | uses: actions/configure-pages@v5 42 | 43 | - name: Upload docs 44 | uses: actions/upload-pages-artifact@v3 45 | with: 46 | path: site # where `mkdocs build` puts the built site 47 | 48 | # deployment job 49 | deploy: 50 | environment: 51 | name: github-pages 52 | url: ${{ steps.deployment.outputs.pages_url }} 53 | runs-on: ubuntu-latest # this job defines no matrix, so matrix.os is not available here 54 | needs:
build 55 | steps: 56 | - name: Deploy to GitHub Pages 57 | id: deployment # This id is referenced by the environment url (steps.deployment.outputs.page_url) 58 | uses: actions/deploy-pages@v4 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Versioning is handled by tags 10 | src/traceax/_version.py 11 | 12 | # PyCharm 13 | .idea/* 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^(docs/conf.py|tests/testdata/.*)' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.6.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | rev: v0.3.5 22 | hooks: 23 | - id: ruff # linter 24 | types_or: [ python, pyi, jupyter ] 25 | args: [--fix] 26 | - id: ruff-format # formatter 27 | types_or: [ python, pyi, jupyter ] 28 | 29 | - repo: https://github.com/RobertCraigie/pyright-python 30 | rev: v1.1.358 31 | hooks: 32 | - id: pyright 33 | additional_dependencies: [ "equinox", "jax", "lineax", "optimistix", "pytest" ] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 MancusoLab 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Documentation-webpage](https://img.shields.io/badge/Docs-Available-brightgreen)](https://mancusolab.github.io/traceax/) 2 | [![PyPI-Server](https://img.shields.io/pypi/v/traceax.svg)](https://pypi.org/project/traceax/) 3 | [![Github](https://img.shields.io/github/stars/mancusolab/traceax?style=social)](https://github.com/mancusolab/traceax) 4 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 5 | [![Project generated with Hatch](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch) 6 | 7 | # Traceax 8 | ``traceax`` is a Python library to perform stochastic trace estimation for linear operators. Namely, 9 | given a square linear operator $\mathbf{A}$, ``traceax`` provides flexible routines that estimate, 10 | 11 | $$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$ 12 | 13 | using only matrix-vector products.
``traceax`` is heavily inspired by 14 | [lineax](https://github.com/patrick-kidger/lineax) as well as 15 | [XTrace](https://github.com/eepperly/XTrace). 16 | 17 | 18 | 19 | [**Installation**](#installation) 20 | | [**Example**](#get-started-with-example) 21 | | [**Documentation**](#documentation) 22 | | [**Notes**](#notes) 23 | | [**Support**](#support) 24 | | [**Other Software**](#other-software) 25 | 26 | ------------------ 27 | 28 | ## Installation 29 | 30 | Users can clone the latest version of the repository and install it with `pip`: 31 | 32 | ``` bash 33 | git clone https://github.com/mancusolab/traceax.git 34 | cd traceax 35 | pip install . 36 | ``` 37 | 38 | ## Get Started with Example 39 | 40 | ```python 41 | import jax.numpy as jnp 42 | import jax.random as rdm 43 | import lineax as lx 44 | 45 | import traceax as tx 46 | 47 | # simulate a simple symmetric matrix with exponential eigenvalue decay 48 | seed = 0 49 | N = 1000 50 | key = rdm.PRNGKey(seed) 51 | key, xkey = rdm.split(key) 52 | 53 | X = rdm.normal(xkey, (N, N)) 54 | Q, R = jnp.linalg.qr(X) 55 | U = jnp.power(0.7, jnp.arange(N)) 56 | A = (Q * U) @ Q.T 57 | 58 | # should be numerically close 59 | print(jnp.trace(A)) # 3.3333323 60 | print(jnp.sum(U)) # 3.3333335 61 | 62 | # set up the linear operator 63 | operator = lx.MatrixLinearOperator(A) 64 | 65 | # number of matrix-vector products 66 | k = 25 67 | 68 | # split key for estimators 69 | key, key1, key2, key3, key4 = rdm.split(key, 5) 70 | 71 | # Hutchinson estimator; default samples Rademacher {-1,+1} 72 | hutch = tx.HutchinsonEstimator() 73 | print(hutch.estimate(key1, operator, k)) # (Array(3.4099615, dtype=float32), {}) 74 | 75 | # Hutch++ estimator; default samples Rademacher {-1,+1} 76 | hpp = tx.HutchPlusPlusEstimator() 77 | print(hpp.estimate(key2, operator, k)) # (Array(3.3033807, dtype=float32), {}) 78 | 79 | # XTrace estimator; default samples uniformly on n-Sphere 80 | xt = tx.XTraceEstimator() 81 | print(xt.estimate(key3, operator, k)) # (Array(3.3271673,
dtype=float32), {'std.err': Array(0.01717775, dtype=float32)}) 82 | 83 | # XNysTrace estimator; improved performance for NSD/PSD trace estimates 84 | operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag) 85 | nt = tx.XNysTraceEstimator() 86 | print(nt.estimate(key4, operator, k)) # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)}) 87 | ``` 88 | 89 | ## Documentation 90 | Documentation is available [here](https://mancusolab.github.io/traceax/). 91 | 92 | 93 | ## Notes 94 | 95 | - `traceax` uses [JAX](https://github.com/google/jax) with [Just In 96 | Time](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) 97 | compilation to achieve high-speed computation. However, there are 98 | some [issues](https://github.com/google/jax/issues/5501) for JAX 99 | on Apple M1 chips. To solve this, users should set up conda using 100 | [miniforge](https://github.com/conda-forge/miniforge), and then 101 | install `traceax` using `pip` in the desired environment. 102 | 103 | 104 | ## Support 105 | 106 | Please report any bugs or feature requests in the [Issue 107 | Tracker](https://github.com/mancusolab/traceax/issues). If users have 108 | any questions or comments, please contact Linda Serafin () or 109 | Nicholas Mancuso (). 110 | 111 | ## Other Software 112 | 113 | Feel free to use other software developed by [Mancuso 114 | Lab](https://www.mancusolab.com/): 115 | 116 | - [SuShiE](https://github.com/mancusolab/sushie): a Bayesian 117 | fine-mapping framework for molecular QTL data across multiple 118 | ancestries. 119 | - [MA-FOCUS](https://github.com/mancusolab/ma-focus): a Bayesian 120 | fine-mapping framework using 121 | [TWAS](https://www.nature.com/articles/ng.3506) statistics across 122 | multiple ancestries to identify the causal genes for complex traits.
123 | - [SuSiE-PCA](https://github.com/mancusolab/susiepca): a scalable 124 | Bayesian variable selection technique for sparse principal component 125 | analysis. 126 | - [twas_sim](https://github.com/mancusolab/twas_sim): a Python 127 | tool to simulate [TWAS](https://www.nature.com/articles/ng.3506) 128 | statistics. 129 | - [FactorGo](https://github.com/mancusolab/factorgo): a scalable 130 | variational factor analysis model that learns pleiotropic factors 131 | from GWAS summary statistics. 132 | - [HAMSTA](https://github.com/tszfungc/hamsta): a Python tool to 133 | estimate heritability explained by local ancestry data from 134 | admixture mapping summary statistics. 135 | 136 | ------------------------------------------------------------------------ 137 | 138 | ``traceax`` is distributed under the terms of the 139 | [Apache-2.0 license](https://spdx.org/licenses/Apache-2.0.html). 140 | 141 | 142 | ------------------------------------------------------------------------ 143 | 144 | This project has been set up using Hatch. For details and usage 145 | information on Hatch, see [https://github.com/pypa/hatch](https://github.com/pypa/hatch). 146 | -------------------------------------------------------------------------------- /docs/_overrides/partials/source.html: -------------------------------------------------------------------------------- 1 | {% import "partials/language.html" as lang with context %} 2 | 3 |
4 | {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} 5 | {% include ".icons/" ~ icon ~ ".svg" %} 6 |
7 |
8 | {{ config.repo_name }} 9 |
10 |
11 | {% if config.theme.twitter_url %} 12 | 13 |
14 | {% include ".icons/fontawesome/brands/twitter.svg" %} 15 |
16 |
17 | {{ config.theme.twitter_name }} 18 |
19 |
20 | {% endif %} 21 | -------------------------------------------------------------------------------- /docs/_static/custom_css.css: -------------------------------------------------------------------------------- 1 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ 2 | html { 3 | scroll-padding-top: 50px; 4 | } 5 | 6 | /* Fit the Twitter handle alongside the GitHub one in the top right. */ 7 | 8 | div.md-header__source { 9 | width: revert; 10 | max-width: revert; 11 | } 12 | 13 | a.md-source { 14 | display: inline-block; 15 | } 16 | 17 | .md-source__repository { 18 | max-width: 100%; 19 | } 20 | 21 | /* Emphasise sections of nav on left hand side */ 22 | 23 | nav.md-nav { 24 | padding-left: 5px; 25 | } 26 | 27 | nav.md-nav--secondary { 28 | border-left: revert !important; 29 | } 30 | 31 | .md-nav__title { 32 | font-size: 0.9rem; 33 | } 34 | 35 | .md-nav__item--section > .md-nav__link { 36 | font-size: 0.9rem; 37 | } 38 | 39 | /* Indent autogenerated documentation */ 40 | 41 | div.doc-contents { 42 | padding-left: 25px; 43 | border-left: 4px solid rgba(230, 230, 230); 44 | } 45 | 46 | /* Increase visibility of splitters "---" */ 47 | 48 | [data-md-color-scheme="default"] .md-typeset hr { 49 | border-bottom-color: rgb(0, 0, 0); 50 | border-bottom-width: 1pt; 51 | } 52 | 53 | [data-md-color-scheme="slate"] .md-typeset hr { 54 | border-bottom-color: rgb(230, 230, 230); 55 | } 56 | 57 | /* More space at the bottom of the page */ 58 | 59 | .md-main__inner { 60 | margin-bottom: 1.5rem; 61 | } 62 | 63 | /* Remove prev/next footer buttons */ 64 | 65 | .md-footer__inner { 66 | display: none; 67 | } 68 | 69 | /* Change font sizes */ 70 | 71 | html { 72 | /* Increase font size for overall webpage (110% of the browser default) */ 73 | font-size: 110%; 74 | } 75 | 76 | .md-typeset .admonition { 77 | /* Increase font size in admonitions */ 78 | font-size: 100% !important; 79 | } 80 | 81 | .md-typeset details { 82 | /* Increase font size in details */ 83 | font-size: 100%
!important; 84 | } 85 | 86 | .md-typeset h1 { 87 | font-size: 1.6rem; 88 | } 89 | 90 | .md-typeset h2 { 91 | font-size: 1.5rem; 92 | } 93 | 94 | .md-typeset h3 { 95 | font-size: 1.3rem; 96 | } 97 | 98 | .md-typeset h4 { 99 | font-size: 1.1rem; 100 | font-weight: 700; 101 | } 102 | 103 | .md-typeset h5 { 104 | font-size: 0.9rem; 105 | font-weight: 700; 106 | } 107 | 108 | .md-typeset h6 { 109 | font-size: 0.8rem; 110 | font-weight: 700; 111 | } 112 | 113 | /* Change default colours for tags */ 114 | 115 | [data-md-color-scheme="default"] { 116 | --md-typeset-a-color: rgb(0, 189, 164) !important; 117 | } 118 | [data-md-color-scheme="slate"] { 119 | --md-hue: 232; 120 | --md-typeset-a-color: rgb(0, 189, 164) !important; 121 | 122 | /* lighten background */ 123 | --md-default-bg-color: hsla(var(--md-hue),15%,21%,1); 124 | --md-default-bg-color--light: hsla(var(--md-hue),15%,21%,0.54); 125 | --md-default-bg-color--lighter: hsla(var(--md-hue),15%,21%,0.26); 126 | --md-default-bg-color--lightest: hsla(var(--md-hue),15%,21%,0.07); 127 | 128 | --md-code-fg-color: hsla(var(--md-hue),18%,86%,1); 129 | --md-code-bg-color: hsla(var(--md-hue),10%,10%,1); 130 | /*--md-code-bg-color: rgb(25,25,33);*/ 131 | --md-code-hl-color: rgba(66,135,255,0.15); 132 | } 133 | 134 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where 135 | one item ends and another begins. 
*/ 136 | 137 | [data-md-color-scheme="default"] { 138 | --doc-heading-color: #DDD; 139 | --doc-heading-border-color: #CCC; 140 | --doc-heading-color-alt: #F0F0F0; 141 | } 142 | [data-md-color-scheme="slate"] { 143 | --doc-heading-color: rgb(25,25,33); 144 | --doc-heading-border-color: rgb(25,25,33); 145 | --doc-heading-color-alt: rgb(33,33,44); 146 | } 147 | 148 | h4.doc-heading { 149 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ 150 | background-color: var(--doc-heading-color); 151 | border: solid var(--doc-heading-border-color); 152 | border-width: 1.5pt; 153 | border-radius: 2pt; 154 | padding: 0pt 5pt 2pt 5pt; 155 | } 156 | h5.doc-heading, h6.heading { 157 | background-color: var(--doc-heading-color-alt); 158 | border-radius: 2pt; 159 | padding: 0pt 5pt 2pt 5pt; 160 | text-transform: none !important; 161 | } 162 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/traceax/8a7080f693799634ea4f9c0e620e09c872e80695/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/_static/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/api/estimators.md: -------------------------------------------------------------------------------- 1 | # Stochastic Trace Estimators 2 | 3 | Given a square linear operator $\mathbf{A}$, the `trace` 
of $\mathbf{A}$ is defined as, 4 | 5 | $$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii}.$$ 6 | 7 | When $\mathbf{A}$ is represented in memory as an $n \times n$ matrix, computing 8 | the `trace` is straightforward, only requiring $O(n)$ time to sum along the 9 | diagonal. However, in practice, $\mathbf{A}$ can be the result of many operations 10 | and explicit calculation and representation of $\mathbf{A}$ may be prohibitive. 11 | 12 | Given this, we may represent $\mathbf{A}$ as a linear operator, which can be viewed 13 | as a lazy representation of $\mathbf{A}$ that only tracks the underlying operations 14 | to calculate its final result. As such, matrix-vector products between $\mathbf{A}$ and 15 | a vector $\omega$ can be obtained by lazily evaluating the chain of underlying 16 | composition operations with intermediate matrix-vector products 17 | (e.g., [lineax](https://github.com/patrick-kidger/lineax)). 18 | 19 | There is a rich history of *stochastic* trace estimation for matrices where 20 | one can estimate the `trace` of $\mathbf{A}$ using multiple matrix-vector products 21 | followed by averaging. To see this, observe that 22 | 23 | $$\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A}),$$ 24 | 25 | where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 26 | Averaging $\omega^T \mathbf{A} \omega$ over independent draws of $\omega$ therefore gives an unbiased estimate of $\text{trace}(\mathbf{A})$; this is known as the Girard-Hutchinson estimator. There have been multiple 27 | advancements in stochastic `trace` estimation. Here, `traceax` aims to provide an 28 | easy-to-use API for stochastic trace estimation that leverages the flexibility of 29 | [lineax](https://github.com/patrick-kidger/lineax) linear operators together with 30 | differentiable and performant [JAX](https://github.com/google/jax)-based numerics. 31 | 32 | ???
abstract "`traceax.AbstractTraceEstimator`" 33 | 34 | ::: traceax.AbstractTraceEstimator 35 | options: 36 | show_bases: false 37 | members: 38 | - estimate 39 | - __call__ 40 | 41 | ::: traceax.HutchinsonEstimator 42 | options: 43 | members: 44 | - __init__ 45 | 46 | --- 47 | 48 | ::: traceax.HutchPlusPlusEstimator 49 | options: 50 | members: 51 | - __init__ 52 | 53 | --- 54 | 55 | ::: traceax.XTraceEstimator 56 | options: 57 | members: 58 | - __init__ 59 | --- 60 | 61 | ::: traceax.XNysTraceEstimator 62 | options: 63 | members: 64 | - __init__ 65 | -------------------------------------------------------------------------------- /docs/api/samplers.md: -------------------------------------------------------------------------------- 1 | # Stochastic Samplers 2 | 3 | `traceax` uses a flexible approach to define how random samples are generated within 4 | [`traceax.AbstractTraceEstimator`][] instances. While this typically wraps a single 5 | JAX random call, the interfaces of the individual randomization procedures differ, 6 | which makes calling them uniformly cumbersome. As such, we provide a 7 | simple abstract class, [`traceax.AbstractSampler`][], which is itself an 8 | [`Equinox`](https://docs.kidger.site/equinox/) module. 9 | 10 | ???
abstract "`traceax.AbstractSampler`" 11 | ::: traceax.AbstractSampler 12 | options: 13 | show_bases: false 14 | members: 15 | - __call__ 16 | 17 | 18 | ::: traceax.NormalSampler 19 | options: 20 | members: 21 | - __init__ 22 | 23 | --- 24 | 25 | ::: traceax.SphereSampler 26 | options: 27 | members: 28 | - __init__ 29 | 30 | --- 31 | 32 | ::: traceax.RademacherSampler 33 | options: 34 | members: 35 | - __init__ 36 | -------------------------------------------------------------------------------- /docs/authors.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | * Linda Serafin 4 | * Nicholas Mancuso 5 | -------------------------------------------------------------------------------- /docs/conduct.md: -------------------------------------------------------------------------------- 1 | # MancusoLab Open Source Community Guidelines 2 | 3 | At MancusoLab, we recognize and celebrate the creativity and collaboration of open 4 | source contributors and the diversity of skills, experiences, cultures, and 5 | opinions they bring to the projects and communities they participate in. 6 | 7 | Every one of our open source projects and communities are inclusive 8 | environments, based on treating all individuals respectfully, regardless of 9 | gender identity and expression, sexual orientation, disabilities, 10 | neurodiversity, physical appearance, body size, ethnicity, nationality, race, 11 | age, religion, or similar personal characteristic. 12 | 13 | We value diverse opinions, but we value respectful behavior more. 14 | 15 | Respectful behavior includes: 16 | 17 | * Being considerate, kind, constructive, and helpful. 18 | * Not engaging in demeaning, discriminatory, harassing, hateful, sexualized, or 19 | physically threatening behavior, speech, and imagery. 20 | * Not engaging in unwanted physical contact. 
21 | 22 | 23 | ## Resolve peacefully 24 | We do not believe that all conflict is necessarily bad; healthy debate and 25 | disagreement often yields positive results. However, it is never okay to be 26 | disrespectful. 27 | 28 | If you see someone behaving disrespectfully, you are encouraged to address the 29 | behavior directly with those involved. Many issues can be resolved quickly and 30 | easily, and this gives people more control over the outcome of their dispute. 31 | If you are unable to resolve the matter for any reason, or if the behavior is 32 | threatening or harassing, report it. We are dedicated to providing an 33 | environment where participants feel welcome and safe. 34 | 35 | ## Reporting problems 36 | Some MancusoLab open source projects may adopt a project-specific code of conduct. 37 | In those cases, a MancusoLab trainee (or Nick) will be identified as the Project Steward, 38 | who will receive and handle reports of code of conduct violations. In the event 39 | that a project hasn’t identified a Project Steward, you can report problems by 40 | emailing Nicholas.Mancuso@med.usc.edu. 41 | 42 | We will investigate every complaint, but you may not receive a direct response. 43 | We will use our discretion in determining when and how to follow up on reported 44 | incidents, which may range from not taking action to permanent expulsion from 45 | the project and project-sponsored spaces. We will notify the accused of the 46 | report and provide them an opportunity to discuss it before any action is 47 | taken. The identity of the reporter will be omitted from the details of the 48 | report supplied to the accused. In potentially harmful situations, such as 49 | ongoing harassment or threats to anyone's safety, we may take action without 50 | notice. 
51 | 52 | This document was adapted from the 53 | [IndieWeb Code of Conduct](https://indieweb.org/code-of-conduct) 54 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | [![Documentation-webpage](https://img.shields.io/badge/Docs-Available-brightgreen)](https://mancusolab.github.io/traceax/) 2 | [![PyPI-Server](https://img.shields.io/pypi/v/traceax.svg)](https://pypi.org/project/traceax/) 3 | [![Github](https://img.shields.io/github/stars/mancusolab/traceax?style=social)](https://github.com/mancusolab/traceax) 4 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 5 | [![Project generated with Hatch](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch) 6 | 7 | # Traceax 8 | ``traceax`` is a Python library to perform stochastic trace estimation for linear operators. Namely, 9 | given a square linear operator $\mathbf{A}$, ``traceax`` provides flexible routines that estimate, 10 | 11 | $$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$ 12 | 13 | using only matrix-vector products. ``traceax`` is heavily inspired by 14 | [lineax](https://github.com/patrick-kidger/lineax) as well as 15 | [XTrace](https://github.com/eepperly/XTrace). 16 | 17 | 18 | 19 | [**Installation**](#installation) 20 | | [**Example**](#get-started-with-example) 21 | | [**Notes**](#notes) 22 | | [**Support**](#support) 23 | | [**Other Software**](#other-software) 24 | 25 | ------------------ 26 | 27 | ## Installation 28 | 29 | Users can clone the latest version of the repository and install it with `pip`: 30 | 31 | ``` bash 32 | git clone https://github.com/mancusolab/traceax.git 33 | cd traceax 34 | pip install .
35 | ``` 36 | 37 | ## Get Started with Example 38 | 39 | ```python 40 | import jax.numpy as jnp 41 | import jax.random as rdm 42 | import lineax as lx 43 | 44 | import traceax as tx 45 | 46 | # simulate simple symmetric matrix with exponential eigenvalue decay 47 | seed = 0 48 | N = 1000 49 | key = rdm.PRNGKey(seed) 50 | key, xkey = rdm.split(key) 51 | 52 | X = rdm.normal(xkey, (N, N)) 53 | Q, R = jnp.linalg.qr(X) 54 | U = jnp.power(0.7, jnp.arange(N)) 55 | A = (Q * U) @ Q.T 56 | 57 | # should be numerically close 58 | print(jnp.trace(A)) # 3.3333323 59 | print(jnp.sum(U)) # 3.3333335 60 | 61 | # setup linear operator 62 | operator = lx.MatrixLinearOperator(A) 63 | 64 | # number of matrix vector operators 65 | k = 25 66 | 67 | # split key for estimators 68 | key, key1, key2, key3, key4 = rdm.split(key, 5) 69 | 70 | # Hutchinson estimator; default samples Rademacher {-1,+1} 71 | hutch = tx.HutchinsonEstimator() 72 | print(hutch.estimate(key1, operator, k)) # (Array(3.4099615, dtype=float32), {}) 73 | 74 | # Hutch++ estimator; default samples Rademacher {-1,+1} 75 | hpp = tx.HutchPlusPlusEstimator() 76 | print(hpp.estimate(key2, operator, k)) # (Array(3.3033807, dtype=float32), {}) 77 | 78 | # XTrace estimator; default samples uniformly on n-Sphere 79 | xt = tx.XTraceEstimator() 80 | print(xt.estimate(key3, operator, k)) # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)}) 81 | 82 | # XNysTrace estimator; Improved performance for NSD/PSD trace estimates 83 | operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag) 84 | nt = tx.XNysTraceEstimator() 85 | print(nt.estimate(key4, operator, k)) # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)}) 86 | ``` 87 | 88 | ## Notes 89 | 90 | - `traceax` uses [JAX](https://github.com/google/jax) with [Just In 91 | Time](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) 92 | compilation to achieve high-speed computation. 
However, there are 93 | some [issues](https://github.com/google/jax/issues/5501) for JAX 94 | on Macs with Apple M1 chips. To solve this, users need to set up conda using 95 | [miniforge](https://github.com/conda-forge/miniforge), and then 96 | install `traceax` using `pip` in the desired environment. 97 | 98 | 99 | ## Support 100 | 101 | Please report any bugs or feature requests in the [Issue 102 | Tracker](https://github.com/mancusolab/traceax/issues). If users have 103 | any questions or comments, please contact Linda Serafin ([lserafin@usc.edu](mailto:lserafin@usc.edu)) or 104 | Nicholas Mancuso ([nmancuso@usc.edu](mailto:nmancuso@usc.edu)). 105 | 106 | ## Other Software 107 | 108 | Feel free to use other software developed by [Mancuso 109 | Lab](https://www.mancusolab.com/): 110 | 111 | - [SuShiE](https://github.com/mancusolab/sushie): a Bayesian 112 | fine-mapping framework for molecular QTL data across multiple 113 | ancestries. 114 | - [MA-FOCUS](https://github.com/mancusolab/ma-focus): a Bayesian 115 | fine-mapping framework using 116 | [TWAS](https://www.nature.com/articles/ng.3506) statistics across 117 | multiple ancestries to identify the causal genes for complex traits. 118 | - [SuSiE-PCA](https://github.com/mancusolab/susiepca): a scalable 119 | Bayesian variable selection technique for sparse principal component 120 | analysis. 121 | - [twas_sim](https://github.com/mancusolab/twas_sim): Python 122 | software to simulate [TWAS](https://www.nature.com/articles/ng.3506) 123 | statistics. 124 | - [FactorGo](https://github.com/mancusolab/factorgo): a scalable 125 | variational factor analysis model that learns pleiotropic factors 126 | from GWAS summary statistics. 127 | - [HAMSTA](https://github.com/tszfungc/hamsta): Python software to 128 | estimate heritability explained by local ancestry data from 129 | admixture mapping summary statistics.
130 | 131 | ------------------------------------------------------------------------ 132 | 133 | ``traceax`` is distributed under the terms of the 134 | [Apache-2.0 license](https://spdx.org/licenses/Apache-2.0.html). 135 | 136 | 137 | ------------------------------------------------------------------------ 138 | 139 | This project has been set up using Hatch. For details and usage 140 | information on Hatch, see [github.com/pypa/hatch](https://github.com/pypa/hatch). 141 | -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 MancusoLab 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | equinox # traceax uses this, but we also preload for easier doc building 2 | jax[cpu] # core dependencies 3 | jaxtyping # traceax uses this, but we also preload for easier doc building 4 | jinja2 # necessary for mkdocs and mkdocs ecosystem 5 | lineax # traceax uses this, but we also preload for easier doc building 6 | mkdocs # Main documentation generator. 7 | mkdocs-material # Theme 8 | mkdocs_include_exclude_files # Allow for customising which files get included 9 | mkdocstrings[python] # Autogenerate documentation from docstrings. 10 | mknotebooks # Turn Jupyter Lab notebooks into webpages. 11 | pydantic # data models 12 | pygments # syntax highlighting 13 | pymdown-extensions # Markdown extensions e.g. to handle LaTeX. 
14 | # Install latest version of our dependencies 15 | typing_extensions # traceax uses this, but we also preload for easier doc building 16 | -------------------------------------------------------------------------------- /docs/scripts/extension.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | 4 | from griffe import ( 5 | Class, 6 | Docstring, 7 | dynamic_import, 8 | ExprCall, 9 | Extension, 10 | Function, 11 | get_logger, 12 | Inspector, 13 | Object, 14 | ObjectNode, 15 | Parameter, 16 | Visitor, 17 | ) 18 | 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | def _get_dynamic_docstring(obj: Object, name: str) -> Docstring: 24 | # import object to get its evaluated docstring 25 | try: 26 | runtime_obj = dynamic_import(obj.path) 27 | init_docstring = getattr(runtime_obj, name).__doc__ 28 | except ImportError: 29 | logger.debug(f"Could not get dynamic docstring for {obj.path}") 30 | return 31 | except AttributeError: 32 | logger.debug(f"Object {obj.path} does not have a __doc__ attribute") 33 | return 34 | 35 | if init_docstring is None: 36 | return None 37 | 38 | # update the object instance with the evaluated docstring 39 | init_docstring = inspect.cleandoc(init_docstring) 40 | 41 | return Docstring(init_docstring, parent=obj) 42 | 43 | 44 | class DynamicDocstrings(Extension): 45 | def __init__(self, paths: list[str] | None = None) -> None: 46 | self.module_paths = paths 47 | 48 | def on_class_members( 49 | self, 50 | *, 51 | node: ast.AST | ObjectNode, 52 | cls: Class, 53 | agent: Visitor | Inspector, 54 | **kwargs, 55 | ) -> None: 56 | logger.debug(f"Inspecting class member {cls.path}") 57 | if isinstance(node, ObjectNode): 58 | return # skip runtime objects, their docstrings are already right 59 | if self.module_paths and cls.parent is None or cls.parent.path not in self.module_paths: 60 | return # skip objects that were not selected 61 | 62 | # pull class attributes as parameters for 
the __init__ function... 63 | parameters = [] 64 | for attr in cls.members.values(): 65 | if attr.is_attribute: 66 | if attr.value is not None: 67 | # import pdb; pdb.set_trace() 68 | if type(attr.value) is ExprCall and len(attr.value.arguments) > 0: 69 | for arg in attr.value.arguments: 70 | if arg.name == "default": 71 | param = Parameter( 72 | name=attr.name, default=arg.value.name, annotation=attr.annotation, kind=attr.kind 73 | ) 74 | else: 75 | param = Parameter( 76 | name=attr.name, default=attr.value, annotation=attr.annotation, kind=attr.kind 77 | ) 78 | else: 79 | param = Parameter(name=attr.name, annotation=attr.annotation, kind=attr.kind) 80 | parameters.append(param) 81 | 82 | # such a huge hack to pull in inherited attributes 83 | cls.members["__init__"] = Function( 84 | name="__init__", parameters=parameters, docstring=_get_dynamic_docstring(cls, "__init__") 85 | ) 86 | # add docs for __call__ only if it was explicitly defined (e.g., ExpFam, but not concrete subclasses) 87 | if "__call__" in cls.members: 88 | cls.members["__call__"].docstring = _get_dynamic_docstring(cls, "__call__") 89 | 90 | return 91 | -------------------------------------------------------------------------------- /docs/templates/python/material/class.html.jinja: -------------------------------------------------------------------------------- 1 | {% extends "_base/class.html" %} 2 | {% block heading scoped %} 3 | {% if config.show_bases and class.bases %} 4 | {{ class_name }}( 5 | {%- for expression in class.bases -%} 6 | {% if loop.last %} 7 | {% include "expression.html" with context %}) 8 | {% else %} 9 | {% include "expression.html" with context %}, 10 | {% endif %} 11 | {% endfor %} 12 | {% else %} 13 | {{ class_name }} 14 | {% endif %} 15 | 16 | {% endblock heading %} 17 | {% block bases scoped %} 18 | {% if config.show_bases and class.bases %} 19 | {% endif %} 20 | {% endblock bases %} 21 | -------------------------------------------------------------------------------- 
/docs/templates/python/material/signature.html.jinja: -------------------------------------------------------------------------------- 1 | {%- if config.show_signature -%} 2 | {{ log.debug("Rendering signature") }} 3 | {%- with -%} 4 | 5 | {%- set ns = namespace( 6 | has_pos_only=False, 7 | render_pos_only_separator=True, 8 | render_kw_only_separator=True, 9 | annotation="", 10 | equal="=", 11 | ) 12 | -%} 13 | 14 | ( 15 | {%- for parameter in function.parameters -%} 16 | {{ log.debug("Rendering parameter " + parameter.name) }} 17 | {%- if parameter.name not in ("cls") or loop.index0 > 0 or not (function.parent and function.parent.is_class) -%} 18 | {%- if parameter.kind.value == "positional-only" -%} 19 | {%- set ns.has_pos_only = True -%} 20 | {%- else -%} 21 | {%- if ns.has_pos_only and ns.render_pos_only_separator -%} 22 | {%- set ns.render_pos_only_separator = False %}/, {% endif -%} 23 | {%- if parameter.kind.value == "keyword-only" -%} 24 | {%- if ns.render_kw_only_separator -%} 25 | {%- set ns.render_kw_only_separator = False %}*, {% endif -%} 26 | {%- endif -%} 27 | {%- endif -%} 28 | 29 | {%- if config.show_signature_annotations and parameter.annotation is not none -%} 30 | {%- set ns.equal = " = " -%} 31 | {%- if config.separate_signature and config.signature_crossrefs -%} 32 | {%- with expression = parameter.annotation -%} 33 | {%- set ns.annotation -%}: {% include "expression.html" with context %}{%- endset -%} 34 | {%- endwith -%} 35 | {%- else -%} 36 | {%- set ns.annotation = ": " + parameter.annotation|safe -%} 37 | {%- endif -%} 38 | {%- else -%} 39 | {%- set ns.equal = "=" -%} 40 | {%- set ns.annotation = "" -%} 41 | {%- endif -%} 42 | 43 | {%- if parameter.default is not none and parameter.kind.value != "variadic positional" and parameter.kind.value != "variadic keyword" -%} 44 | {%- set default = ns.equal + parameter.default|safe -%} 45 | {%- endif -%} 46 | 47 | {%- if parameter.kind.value == "variadic positional" -%} 48 | {%- set 
ns.render_kw_only_separator = False -%} 49 | {%- endif -%} 50 | 51 | {% if parameter.kind.value == "variadic positional" %}*{% elif parameter.kind.value == "variadic keyword" %}**{% endif -%} 52 | {{ parameter.name }}{{ ns.annotation }}{{ default }} 53 | {%- if not loop.last %}, {% endif -%} 54 | 55 | {%- endif -%} 56 | {%- endfor -%} 57 | ) 58 | {%- if config.show_signature_annotations 59 | and function.annotation 60 | and not (config.merge_init_into_class and function.name == "__init__" ) 61 | %} -> {% if config.separate_signature and config.signature_crossrefs -%} 62 | {%- with expression = function.annotation %}{% include "expression.html" with context %}{%- endwith -%} 63 | {%- else -%} 64 | {{ function.annotation|safe }} 65 | {%- endif -%} 66 | {%- endif -%} 67 | 68 | {%- endwith -%} 69 | {%- endif -%} 70 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | theme: # pulled from optimistix 2 | name: material 3 | features: 4 | - navigation.sections # Sections are included in the navigation on the left. 5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. 6 | - header.autohide # header disappears as you scroll 7 | palette: 8 | # Light mode / dark mode 9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as 10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. 
11 | - scheme: default 12 | primary: white 13 | accent: amber 14 | toggle: 15 | icon: material/weather-night 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: amber 20 | toggle: 21 | icon: material/weather-sunny 22 | name: Switch to light mode 23 | icon: 24 | repo: fontawesome/brands/github # GitHub logo in top right 25 | logo: "material/sigma" # traceax logo in top left 26 | favicon: "_static/favicon.png" 27 | custom_dir: "docs/_overrides" # Overriding part of the HTML 28 | 29 | 30 | site_name: Traceax 31 | site_description: The documentation for the traceax software library. 32 | site_url: https://mancusolab.github.io/traceax 33 | 34 | repo_url: https://github.com/mancusolab/traceax 35 | repo_name: mancusolab/traceax 36 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate 37 | 38 | strict: true # Don't allow warnings during the build process 39 | 40 | extra_javascript: 41 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ 42 | - _static/mathjax.js 43 | # polyfill.io intentionally not loaded: the domain was hijacked in a 2024 supply-chain attack, and MathJax 3 does not need it 44 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 45 | 46 | extra_css: 47 | - _static/custom_css.css 48 | 49 | markdown_extensions: 50 | - pymdownx.arithmatex: # Render LaTeX via MathJax 51 | generic: true 52 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. 53 | - pymdownx.details # Allowing hidden expandable regions denoted by ???
54 | - pymdownx.snippets: # Include one Markdown file into another 55 | base_path: docs 56 | - admonition 57 | - toc: 58 | permalink: "¤" # Adds a clickable permalink to each section heading 59 | toc_depth: 4 60 | 61 | plugins: 62 | - search # default search plugin; needs manually re-enabling when using any other plugins 63 | - autorefs # Cross-links to headings 64 | - include_exclude_files: 65 | exclude: 66 | - "_overrides" 67 | - hippogriffe: 68 | extra_public_objects: 69 | - jax.Array 70 | - lineax.AbstractLinearOperator 71 | - mkdocstrings: 72 | default_handler: python 73 | enable_inventory: true 74 | custom_templates: docs/templates 75 | handlers: 76 | python: 77 | paths: [src] 78 | import: 79 | - https://docs.kidger.site/equinox/objects.inv 80 | - https://docs.kidger.site/jaxtyping/objects.inv 81 | options: 82 | extensions: 83 | - docs/scripts/extension.py:DynamicDocstrings: 84 | paths: [ traceax._estimators, traceax._samplers ] 85 | # general options 86 | force_inspection: true 87 | show_bases: true 88 | show_source: false 89 | # heading options 90 | heading_level: 4 91 | show_root_heading: true 92 | show_root_full_path: true 93 | # members options 94 | inherited_members: true 95 | members_order: source 96 | filters: 97 | - "!^_" 98 | - "^__init__$" 99 | # docstring options 100 | show_if_no_docstring: true 101 | # signature/type annotation options 102 | separate_signature: false 103 | annotations_path: brief 104 | show_signature_annotations: true 105 | 106 | nav: 107 | - 'index.md' 108 | - API: 109 | - 'api/estimators.md' 110 | - 'api/samplers.md' 111 | - Misc: 112 | - 'authors.md' 113 | - 'license.md' 114 | - 'conduct.md' 115 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "traceax" 7 | dynamic = 
["version"] 8 | description = "Stochastic trace estimation in JAX, Lineax, and Equinox" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = {text = "Apache2.0"} 12 | keywords = [ 13 | "jax", 14 | "trace-estimation", 15 | "statistics", 16 | "machine-learning", 17 | ] 18 | authors = [ 19 | { name = "Linda Serafin", email = "lserafin@usc.edu" }, 20 | { name = "Nicholas Mancuso", email = "nmancuso@usc.edu" }, 21 | ] 22 | classifiers = [ 23 | "Development Status :: 4 - Beta", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Developers", 26 | "License :: OSI Approved :: Apache Software License", 27 | "Natural Language :: English", 28 | "Topic :: Scientific/Engineering", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | "Topic :: Scientific/Engineering :: Information Analysis", 31 | "Topic :: Scientific/Engineering :: Mathematics", 32 | "Programming Language :: Python", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3.10", 35 | "Programming Language :: Python :: 3.11", 36 | "Programming Language :: Python :: 3.12", 37 | "Programming Language :: Python :: Implementation :: CPython", 38 | "Programming Language :: Python :: Implementation :: PyPy", 39 | ] 40 | dependencies = [ 41 | "jax>=0.4.13", 42 | "jaxtyping>=0.2.20", 43 | "equinox>=0.11.3", 44 | "lineax>=0.0.4", 45 | "typing_extensions>=4.5.0", 46 | ] 47 | 48 | [project.optional-dependencies] 49 | docs = [ 50 | "hippogriffe", 51 | "mkdocs", 52 | "mkdocs-include-exclude-files", 53 | "mkdocs-material", 54 | "mkdocstrings>=0.29.0", 55 | "mkdocstrings-python", 56 | "pymdown-extensions", 57 | ] 58 | 59 | [project.urls] 60 | Documentation = "https://github.com/mancusolab/traceax#readme" 61 | Issues = "https://github.com/mancusolab/traceax/issues" 62 | Source = "https://github.com/mancusolab/traceax" 63 | 64 | [tool.hatch.version] 65 | source = "vcs" 66 | 67 | [tool.hatch.build.hooks.vcs] 68 | version-file = 
"src/traceax/_version.py" 69 | 70 | [tool.hatch.envs.default] 71 | dependencies = [ 72 | "coverage[toml]>=6.5", 73 | "pytest", 74 | "pytest-cov", 75 | ] 76 | [tool.hatch.envs.default.scripts] 77 | test = "pytest {args:tests}" 78 | test-cov = "coverage run -m pytest {args:tests}" 79 | cov-report = [ 80 | "- coverage combine", 81 | "coverage report", 82 | ] 83 | cov = [ 84 | "test-cov", 85 | "cov-report", 86 | ] 87 | 88 | [[tool.hatch.envs.all.matrix]] 89 | python = ["3.9", "3.10", "3.11", "3.12"] 90 | 91 | [tool.hatch.envs.lint] 92 | detached = true 93 | dependencies = [ 94 | "mypy>=1.0.0", 95 | "ruff>=0.2.2", 96 | ] 97 | [tool.hatch.envs.lint.scripts] 98 | typing = "mypy --install-types --non-interactive {args:src/traceax tests}" 99 | style = [ 100 | "ruff check {args:.}", 101 | ] 102 | fmt = [ 103 | "ruff check --fix {args:.}", 104 | "style", 105 | ] 106 | all = [ 107 | "style", 108 | "typing", 109 | ] 110 | 111 | [tool.ruff] 112 | line-length = 120 113 | 114 | [tool.ruff.lint] 115 | fixable = ["ALL"] 116 | select = ["E", "F", "I001"] 117 | ignore = [ 118 | # Allow non-abstract empty methods in abstract base classes 119 | "B027", 120 | # Allow boolean positional values in function calls, like `dict.get(...
True)` 121 | "FBT003", 122 | # Ignore checks for possible passwords 123 | "S105", "S106", "S107", 124 | # Ignore complexity 125 | "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", 126 | # gets confused on empty strings for array types (e.g., Bool[Array, ""]) 127 | "F722" 128 | ] 129 | ignore-init-module-imports = true 130 | 131 | [tool.ruff.lint.isort] 132 | known-first-party = ["traceax"] 133 | combine-as-imports = true 134 | lines-after-imports = 2 135 | lines-between-types = 1 136 | known-local-folder = ["src"] 137 | section-order = ["future", "standard-library", "third-party", "jax-ecosystem", "first-party", "local-folder"] 138 | extra-standard-library = ["typing_extensions"] 139 | order-by-type = false 140 | 141 | [tool.ruff.lint.isort.sections] 142 | jax-ecosystem = ["equinox", "jax", "jaxtyping", "lineax"] 143 | 144 | [tool.ruff.lint.per-file-ignores] 145 | # Tests can use magic values, assertions, and relative imports 146 | "tests/**/*" = ["PLR2004", "S101", "TID252"] 147 | 148 | [tool.pyright] 149 | reportIncompatibleMethodOverride = true 150 | reportIncompatibleVariableOverride = false # Incompatible with eqx.AbstractVar 151 | include = ["src/traceax", "tests"] 152 | exclude = ["docs"] 153 | 154 | [tool.coverage.run] 155 | source_pkgs = ["traceax", "tests"] 156 | branch = true 157 | parallel = true 158 | omit = [] 159 | 160 | [tool.coverage.paths] 161 | traceax = ["src/traceax", "*/traceax/src/traceax"] 162 | tests = ["tests", "*/traceax/tests"] 163 | 164 | [tool.coverage.report] 165 | exclude_lines = [ 166 | "no cov", 167 | "if __name__ == .__main__.:", 168 | "if TYPE_CHECKING:", 169 | ] 170 | -------------------------------------------------------------------------------- /src/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/traceax/8a7080f693799634ea4f9c0e620e09c872e80695/src/py.typed -------------------------------------------------------------------------------- 
/src/traceax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 16 | 17 | from ._estimators import ( 18 | AbstractTraceEstimator as AbstractTraceEstimator, 19 | HutchinsonEstimator as HutchinsonEstimator, 20 | HutchPlusPlusEstimator as HutchPlusPlusEstimator, 21 | XNysTraceEstimator as XNysTraceEstimator, 22 | XTraceEstimator as XTraceEstimator, 23 | ) 24 | from ._samplers import ( 25 | AbstractSampler as AbstractSampler, 26 | NormalSampler as NormalSampler, 27 | RademacherSampler as RademacherSampler, 28 | SphereSampler as SphereSampler, 29 | ) 30 | 31 | 32 | try: 33 | # Installed distribution name (pyproject `name`); change here if it ever differs from the package name 34 | dist_name = __name__ 35 | __version__ = version(dist_name) 36 | except PackageNotFoundError: # pragma: no cover 37 | __version__ = "unknown" 38 | finally: 39 | del version, PackageNotFoundError 40 | -------------------------------------------------------------------------------- /src/traceax/_estimators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import abstractmethod 16 | from typing import Any 17 | 18 | import equinox as eqx 19 | import jax 20 | import jax.numpy as jnp 21 | import jax.scipy as jsp 22 | 23 | from equinox import AbstractVar 24 | from jax.numpy.linalg import norm 25 | from jaxtyping import Array, PRNGKeyArray 26 | from lineax import AbstractLinearOperator, is_negative_semidefinite, is_positive_semidefinite 27 | 28 | from ._samplers import AbstractSampler, RademacherSampler, SphereSampler 29 | 30 | 31 | def _check_shapes(operator: AbstractLinearOperator, k: int) -> tuple[int, int]: 32 | n_in = operator.in_size() 33 | n_out = operator.out_size() 34 | if n_in != n_out: 35 | raise ValueError(f"Trace estimation requires square linear operator. Found {(n_out, n_in)}.") 36 | 37 | if k < 1: 38 | raise ValueError(f"Trace estimation requires positive number of matvecs. Found {k}.") 39 | 40 | return n_in, k 41 | 42 | 43 | def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: 44 | return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) 45 | 46 | 47 | class AbstractTraceEstimator(eqx.Module, strict=True): 48 | r"""Abstract base class for all trace estimators.""" 49 | 50 | sampler: AbstractVar[AbstractSampler] 51 | 52 | @abstractmethod 53 | def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 54 | """Estimate the trace of `operator`. 55 | 56 | !!! Example 57 | 58 | ```python 59 | key = jax.random.PRNGKey(...) 60 | operator = lx.MatrixLinearOperator(...) 
61 | hutch = tx.HutchinsonEstimator() 62 | result = hutch.compute(key, operator, k=10) 63 | # or 64 | result = hutch(key, operator, k=10) 65 | ``` 66 | 67 | **Arguments:** 68 | 69 | - `key`: the PRNG key used as the random key for sampling. 70 | - `operator`: the (square) linear operator for which the trace is to be estimated. 71 | - `k`: the number of matrix vector operations to perform for trace estimation. 72 | 73 | **Returns:** 74 | 75 | A two-tuple of: 76 | 77 | - The trace estimate. 78 | - A dictionary of any extra statistics above the trace, e.g., the standard error. 79 | """ 80 | ... 81 | 82 | def __call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 83 | """An alias for `estimate`.""" 84 | return self.estimate(key, operator, k) 85 | 86 | 87 | class HutchinsonEstimator(AbstractTraceEstimator): 88 | r"""Girard-Hutchinson Trace Estimator: 89 | 90 | $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, 91 | where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 92 | 93 | """ 94 | 95 | sampler: AbstractSampler = RademacherSampler() 96 | 97 | def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 98 | n, k = _check_shapes(operator, k) 99 | # sample from proposed distribution 100 | samples = self.sampler(key, n, k) 101 | 102 | # project to k-dim space 103 | projected = jax.vmap(operator.mv, (1,), 1)(samples) 104 | 105 | # take the mean across estimates 106 | trace_est = jnp.sum(projected * samples) / k 107 | 108 | return trace_est, {} 109 | 110 | 111 | HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** 112 | 113 | - `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. 
114 | """ 115 | 116 | 117 | class HutchPlusPlusEstimator(AbstractTraceEstimator): 118 | r"""Hutch++ Trace Estimator: 119 | 120 | Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_ 121 | to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for 122 | $\Omega = [\omega_1, \dotsc, \omega_k]$. 123 | 124 | Hutch++ improves upon Girard-Hutchinson estimator by including the trace of the residuals. Namely, 125 | Hutch++ estimates $\text{trace}(\mathbf{A})$ as 126 | $\text{trace}(\hat{\mathbf{A}}) - \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$. 127 | 128 | As with the Girard-Hutchinson estimator, it requires 129 | $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 130 | 131 | """ 132 | 133 | sampler: AbstractSampler = RademacherSampler() 134 | 135 | def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 136 | # generate an n, k matrix X 137 | n, k = _check_shapes(operator, k) 138 | m = k // 3 139 | 140 | # some operators work fine with matrices in mv, some dont; this ensures they all do 141 | mv = jax.vmap(operator.mv, (1,), 1) 142 | 143 | # split X into 2 Xs; X1 and X2, where X1 has shape 2m, where m = k/3 144 | samples = self.sampler(key, n, 2 * m) 145 | X1 = samples[:, :m] 146 | X2 = samples[:, m:] 147 | 148 | Y = mv(X1) 149 | 150 | # compute Q, _ = QR(Y) (orthogonal matrix) 151 | Q, _ = jnp.linalg.qr(Y) 152 | 153 | # compute G = X2 - Q @ (Q.T @ X2) 154 | G = X2 - Q @ (Q.T @ X2) 155 | 156 | # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k 157 | AQ = mv(Q) 158 | AG = mv(G) 159 | trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1]) 160 | 161 | return trace_est, {} 162 | 163 | 164 | HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** 165 | 166 | - `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. 
167 | """ 168 | 169 | 170 | class XTraceEstimator(AbstractTraceEstimator): 171 | r"""XTrace Trace Estimator: 172 | 173 | Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_ 174 | to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for 175 | $\Omega = [\omega_1, \dotsc, \omega_k]$. 176 | 177 | XTrace improves upon Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors, 178 | to construct a symmetric estimation function with lower variance. 179 | 180 | Additionally, the *improved* XTrace algorithm (i.e. `improved = True`), ensures that test-vectors 181 | are orthogonalized against the low rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and 182 | renormalized. This improved XTrace approach may provide better empirical results compared with 183 | the non-orthogonalized version. 184 | 185 | As with the Girard-Hutchinson estimator, it requires 186 | $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
187 | 188 | """ 189 | 190 | sampler: AbstractSampler = SphereSampler() 191 | improved: bool = True 192 | 193 | def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 194 | n, k = _check_shapes(operator, k) 195 | m = k // 2 196 | 197 | # some operators work fine with matrices in mv, some dont; this ensures they all do 198 | mv = jax.vmap(operator.mv, (1,), 1) 199 | 200 | samples = self.sampler(key, n, m) 201 | Y = mv(samples) 202 | Q, R = jnp.linalg.qr(Y) 203 | 204 | # solve and rescale 205 | S = jnp.linalg.inv(R).T 206 | s = norm(S, axis=0) 207 | S = S / s 208 | 209 | # working variables 210 | Z = mv(Q) 211 | H = Q.T @ Z 212 | W = Q.T @ samples 213 | T = Z.T @ samples 214 | HW = H @ W 215 | 216 | SW_d = jnp.sum(S * W, axis=0) 217 | TW_d = jnp.sum(T * W, axis=0) 218 | SHS_d = jnp.sum(S * (H @ S), axis=0) 219 | WHW_d = jnp.sum(W * HW, axis=0) 220 | 221 | term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0) 222 | term2 = (jnp.abs(SW_d) ** 2) * SHS_d 223 | term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) 224 | 225 | if self.improved: 226 | scale = _get_scale(W, SW_d, n, m) 227 | else: 228 | scale = 1 229 | 230 | estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale 231 | trace_est = jnp.mean(estimates) 232 | std_err = jnp.std(estimates) / jnp.sqrt(m) 233 | 234 | return trace_est, {"std.err": std_err} 235 | 236 | 237 | XTraceEstimator.__init__.__doc__ = r"""**Arguments:** 238 | 239 | - `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. 240 | - `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples. 241 | Default is `True` (see Notes). 
242 | """ 243 | 244 | 245 | class XNysTraceEstimator(AbstractTraceEstimator): 246 | r"""XNysTrace Trace Estimator: 247 | 248 | XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by 249 | performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation), 250 | rather than a randomized SVD (i.e., random projection followed by QR decomposition). 251 | 252 | Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures 253 | that test-vectors are orthogonalized against the low rank approximation and renormalized. 254 | This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version. 255 | 256 | As with the Girard-Hutchinson estimator, it requires 257 | $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 258 | 259 | """ 260 | 261 | sampler: AbstractSampler = SphereSampler() 262 | improved: bool = True 263 | 264 | def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: 265 | is_nsd = is_negative_semidefinite(operator) 266 | if not (is_positive_semidefinite(operator) | is_nsd): 267 | raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") 268 | if is_nsd: 269 | operator = -operator 270 | 271 | n, k = _check_shapes(operator, k) 272 | 273 | # some operators work fine with matrices in mv, some dont; this ensures they all do 274 | mv = jax.vmap(operator.mv, (1,), 1) 275 | 276 | samples = self.sampler(key, n, k) 277 | Y = mv(samples) 278 | 279 | # shift for numerical issues 280 | nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n) 281 | Y = Y + samples * nu 282 | Q, R = jnp.linalg.qr(Y) 283 | 284 | # compute and symmetrize H, then take cholesky factor 285 | H = samples.T @ Y 286 | C = jnp.linalg.cholesky(0.5 * (H + H.T)).T 287 | B = 
jsp.linalg.solve_triangular(C.T, R.T, lower=True).T 288 | 289 | # if improved == True 290 | Qs, Rs = jnp.linalg.qr(samples) 291 | Ws = Qs.T @ samples 292 | 293 | # solve and rescale 294 | if self.improved: 295 | S = jnp.linalg.inv(Rs).T 296 | s = norm(S, axis=0) 297 | S = S / s 298 | scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k) 299 | else: 300 | scale = 1 301 | 302 | W = Q.T @ samples 303 | S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H))) 304 | dSW = jnp.sum(S * W, axis=0) 305 | 306 | estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n 307 | trace_est = jnp.mean(estimates) 308 | std_err = jnp.std(estimates) / jnp.sqrt(k) 309 | trace_est = jnp.where(is_nsd, -trace_est, trace_est) 310 | 311 | return trace_est, {"std.err": std_err} 312 | 313 | 314 | XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** 315 | 316 | - `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. 317 | - `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples. 318 | Default is `True` (see Notes). 319 | """ 320 | -------------------------------------------------------------------------------- /src/traceax/_samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from abc import abstractmethod 16 | 17 | import equinox as eqx 18 | import jax.numpy as jnp 19 | import jax.random as rdm 20 | 21 | from equinox import AbstractVar 22 | from jax.dtypes import canonicalize_dtype 23 | from jax.numpy import issubdtype 24 | from jaxtyping import Array, DTypeLike, Inexact, Num, PRNGKeyArray 25 | 26 | 27 | class AbstractSampler(eqx.Module, strict=True): 28 | """Abstract base class for all samplers.""" 29 | 30 | dtype: AbstractVar[DTypeLike] 31 | 32 | @abstractmethod 33 | def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]: 34 | r"""Sample random variates from the underlying distribution as an $n \times k$ 35 | matrix. 36 | 37 | !!! Example 38 | 39 | ```python 40 | sampler = tr.RademacherSampler() 41 | samples = sampler(key, n, k) 42 | ``` 43 | 44 | Each sampler accepts a `dtype` (i.e. `float`, `complex`, `int`) argument upon initialization, 45 | with sensible default values. This makes it possible to sample from more general spaces (e.g., 46 | complex Normal test-vectors). 47 | 48 | !!! Example 49 | 50 | ```python 51 | sampler = tr.NormalSampler(complex) 52 | samples = sampler(key, n, k) 53 | ``` 54 | 55 | **Arguments:** 56 | 57 | - `key`: a jax PRNG key used as the random key. 58 | - `n`: the size of the leading dimension. 59 | - `k`: the size of the trailing dimension. 60 | 61 | **Returns**: 62 | 63 | An Array of random samples. 64 | """ 65 | ... 66 | 67 | 68 | class NormalSampler(AbstractSampler, strict=True): 69 | r"""Standard normal distribution sampler. 70 | 71 | Generates samples $X_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$. 72 | 73 | !!! Note 74 | Supports float and complex-valued types. 75 | """ 76 | 77 | dtype: DTypeLike = eqx.field(converter=canonicalize_dtype, default=float) 78 | 79 | def __check_init__(self): 80 | if not issubdtype(self.dtype, jnp.inexact): 81 | raise ValueError(f"NormalSampler requires float or complex dtype. 
Found {self.dtype}.") 82 | 83 | def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: 84 | return rdm.normal(key, (n, k), self.dtype) 85 | 86 | 87 | NormalSampler.__init__.__doc__ = r"""**Arguments:** 88 | 89 | - `dtype`: numeric representation for sampled test-vectors. Default is `float`. 90 | """ 91 | 92 | 93 | class SphereSampler(AbstractSampler, strict=True): 94 | r"""Sphere distribution sampler. 95 | 96 | Generates samples $X_1, \dotsc, X_n$ uniformly distributed on the surface of a 97 | $k$ dimensional sphere (i.e. $k-1$-sphere) with radius $\sqrt{n}$. Internally, 98 | this operates by sampling standard normal variates, and then rescaling such that 99 | each $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$. 100 | 101 | !!! Note 102 | Supports float and complex-valued types. 103 | """ 104 | 105 | dtype: DTypeLike = eqx.field(converter=canonicalize_dtype, default=float) 106 | 107 | def __check_init__(self): 108 | if not issubdtype(self.dtype, jnp.inexact): 109 | raise ValueError(f"SphereSampler requires float or complex dtype. Found {self.dtype}.") 110 | 111 | def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: 112 | samples = rdm.normal(key, (n, k), self.dtype) 113 | return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) 114 | 115 | 116 | SphereSampler.__init__.__doc__ = r"""**Arguments:** 117 | 118 | - `dtype`: numeric representation for sampled test-vectors. Default is `float`. 119 | """ 120 | 121 | 122 | class RademacherSampler(AbstractSampler, strict=True): 123 | r"""Rademacher distribution sampler. 124 | 125 | Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$. 126 | 127 | !!! Note 128 | Supports integer, float, and complex-valued types. 
129 | """ 130 | 131 | dtype: DTypeLike = eqx.field(converter=canonicalize_dtype, default=int) 132 | 133 | def __check_init__(self): 134 | if not issubdtype(self.dtype, jnp.number): 135 | raise ValueError(f"RademacherSampler requires numeric dtype. Found {self.dtype}.") 136 | 137 | def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]: 138 | return rdm.rademacher(key, (n, k), self.dtype) 139 | 140 | 141 | RademacherSampler.__init__.__doc__ = r"""**Arguments:** 142 | 143 | - `dtype`: numeric representation for sampled test-vectors. Default is `int`. 144 | """ 145 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | 17 | import equinox.internal as eqxi 18 | import jax 19 | 20 | 21 | jax.config.update("jax_enable_x64", True) # 64-bit mode, so the jnp.float64 test cases really run in float64 22 | 23 | 24 | @pytest.fixture 25 | def getkey(): 26 | return eqxi.GetKey() # tests call getkey() repeatedly to obtain PRNG keys 27 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import jax.numpy as jnp 16 | import jax.random as rdm 17 | import lineax as lx 18 | 19 | 20 | def has_tag(tags, tag): # True if `tag` is `tags` itself or is contained in a tuple of tags 21 | return tag is tags or (isinstance(tags, tuple) and tag in tags) 22 | 23 | 24 | def construct_matrix(getkey, tags, size, dtype): # random (size, size) matrix, adjusted in turn for each tag 25 | matrix = rdm.normal(getkey(), (size, size), dtype=dtype) 26 | if has_tag(tags, lx.diagonal_tag): 27 | matrix = jnp.diag(jnp.diag(matrix)) 28 | if has_tag(tags, lx.symmetric_tag): 29 | matrix = matrix + matrix.T # M + M^T is symmetric 30 | if has_tag(tags, lx.lower_triangular_tag): 31 | matrix = jnp.tril(matrix) 32 | if has_tag(tags, lx.upper_triangular_tag): 33 | matrix = jnp.triu(matrix) 34 | if has_tag(tags, lx.unit_diagonal_tag): 35 | matrix = matrix.at[jnp.arange(size), jnp.arange(size)].set(1) 36 | if has_tag(tags, lx.tridiagonal_tag): 37 | diagonal = jnp.diag(jnp.diag(matrix)) 38 | upper_diagonal = jnp.diag(jnp.diag(matrix, k=1), k=1) 39 | lower_diagonal = jnp.diag(jnp.diag(matrix, k=-1), k=-1) 40 | matrix = lower_diagonal + diagonal + upper_diagonal 41 | if has_tag(tags, lx.positive_semidefinite_tag): 42 | matrix = matrix @ matrix.T.conj() # M M^H is positive semidefinite 43 | if has_tag(tags, lx.negative_semidefinite_tag): 44 | matrix = -matrix @ matrix.T.conj() # (-M) M^H == -(M M^H), which is negative semidefinite 45 | 46 | return matrix 47 | -------------------------------------------------------------------------------- /tests/test_trace.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 MancusoLab. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | 17 | import jax.numpy as jnp 18 | import lineax as lx 19 | 20 | import traceax as tr 21 | 22 | from .helpers import ( 23 | construct_matrix, 24 | ) 25 | 26 | 27 | @pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) 28 | @pytest.mark.parametrize("k", (5, 10, 50)) 29 | @pytest.mark.parametrize( 30 | "tags", 31 | ( 32 | None, 33 | lx.diagonal_tag, 34 | lx.symmetric_tag, 35 | lx.lower_triangular_tag, 36 | lx.upper_triangular_tag, 37 | lx.tridiagonal_tag, 38 | lx.unit_diagonal_tag, 39 | ), 40 | ) 41 | @pytest.mark.parametrize("size", (5, 50, 500)) 42 | @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) 43 | def test_matrix_linop(getkey, estimator, k, tags, size, dtype): # smoke test: the estimate must be finite 44 | k = min(k, size) # never request more matvecs than the operator dimension 45 | matrix = construct_matrix(getkey, tags, size, dtype) 46 | operator = lx.MatrixLinearOperator(matrix, tags=tags) 47 | result = estimator.estimate(getkey(), operator, k) 48 | 49 | assert result is not None 50 | assert result[0] is not None 51 | assert jnp.isfinite(result[0]) 52 | 53 | 54 | @pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) 55 | @pytest.mark.parametrize("k", (5, 10, 50)) 56 | @pytest.mark.parametrize("size", (5, 50, 500)) 57 | @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) 58 | def test_diag_linop(getkey, estimator, k, size, dtype): # same smoke test on a DiagonalLinearOperator 59 | k = min(k, size) # never request more matvecs than the operator dimension 60 | matrix = construct_matrix(getkey, lx.diagonal_tag, size, dtype) 61 | operator = lx.DiagonalLinearOperator(jnp.diag(matrix)) 62 | result = estimator.estimate(getkey(), operator, k) 63 | 64 | assert result is not None 65 | assert result[0] is not None 66 | assert jnp.isfinite(result[0]) 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(),
tr.XTraceEstimator(), tr.XNysTraceEstimator()) 71 | ) 72 | @pytest.mark.parametrize("k", (5, 10, 50)) 73 | @pytest.mark.parametrize("tags", (lx.positive_semidefinite_tag, lx.negative_semidefinite_tag)) 74 | @pytest.mark.parametrize("size", (5, 50, 500)) 75 | @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) 76 | def test_nsd_psd_matrix_linop(getkey, estimator, k, tags, size, dtype): # also covers XNysTrace (needs N/PSD) 77 | k = min(k, size) # never request more matvecs than the operator dimension 78 | matrix = construct_matrix(getkey, tags, size, dtype) 79 | operator = lx.MatrixLinearOperator(matrix, tags=tags) 80 | result = estimator.estimate(getkey(), operator, k) 81 | 82 | assert result is not None 83 | assert result[0] is not None 84 | assert jnp.isfinite(result[0]) 85 | --------------------------------------------------------------------------------