├── .DS_Store ├── .github └── workflows │ └── version-bump.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── .DS_Store ├── Makefile ├── _static │ ├── .DS_Store │ ├── custom.css │ └── logo.png ├── _templates │ ├── header.html │ └── sidebar_toc.html ├── bibliography.bib ├── conf.py ├── contributing.md ├── examples ├── favicon.ico ├── index.rst ├── make.bat ├── reference │ ├── kernel │ │ ├── index.rst │ │ └── kernel.rst │ └── nn │ │ ├── functional.rst │ │ ├── index.rst │ │ ├── linalg.rst │ │ ├── nn.rst │ │ └── stats.rst └── requirements.txt ├── examples ├── detecting_independence │ ├── .gitignore │ └── detecting_independence.ipynb └── lorenz63 │ ├── .gitignore │ └── main.py ├── linear_operator_learning ├── __init__.py ├── kernel │ ├── __init__.py │ ├── linalg.py │ ├── regressors.py │ ├── structs.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── functional.py │ ├── linalg.py │ ├── modules │ │ ├── __init__.py │ │ ├── ema_covariance.py │ │ ├── loss.py │ │ ├── mlp.py │ │ ├── resnet.py │ │ └── simnorm.py │ ├── regressors.py │ ├── stats.py │ └── structs.py └── py.typed ├── logo.png ├── logo.svg ├── pyproject.toml └── uv.lock /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/version-bump.yml: -------------------------------------------------------------------------------- 1 | name: Version Bump 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | # References: 9 | # https://docs.astral.sh/uv/guides/integration/github/ 10 | # https://stackoverflow.com/a/58393457/10344434 11 | # https://github.com/mbarkhau/bumpver?tab=readme-ov-file#reference 12 | jobs: 13 | bump-version: 14 | name: Bump package version 15 | if: 
"!contains(github.event.head_commit.message, 'Bump version')" 16 | runs-on: ubuntu-latest 17 | 18 | permissions: 19 | contents: write 20 | id-token: write 21 | 22 | steps: 23 | - name: Checkout repository 24 | uses: actions/checkout@v4 25 | 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v5 28 | 29 | - name: Set up Python 30 | run: uv python install 31 | 32 | - name: Install the project 33 | run: uv sync --all-extras --dev 34 | 35 | - name: Bump version 36 | run: uv run bumpver update --patch 37 | 38 | - name: Bump uv lock 39 | run: uv lock --upgrade-package linear-operator-learning 40 | 41 | - name: Commit 42 | run: | 43 | git config --global user.name 'Alek Frohlich' 44 | git config --global user.email 'alekfrohlich@users.noreply.github.com' 45 | git commit -am "Bump version" 46 | git push 47 | 48 | - name: Build and Publish to PyPI 49 | run: | 50 | uv build 51 | uv publish --trusted-publishing always 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | .idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.9.4 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | # Run the formatter. 11 | - id: ruff-format 12 | types_or: [ python, pyi ] 13 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.11" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optionally, but recommended, 18 | # declare the Python requirements required to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: docs/requirements.txt 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 CSML @ IIT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | SVG Image 3 |

4 | 5 | ## Install 6 | 7 | To install this package as a dependency, run: 8 | 9 | ```bash 10 | pip install linear-operator-learning 11 | ``` 12 | 13 | ## Development 14 | 15 | To develop this project, please set up the [`uv` project manager](https://astral.sh/uv) by running the following commands: 16 | 17 | ```bash 18 | curl -LsSf https://astral.sh/uv/install.sh | sh 19 | git clone git@github.com:CSML-IIT-UCL/linear_operator_learning.git 20 | cd linear_operator_learning 21 | uv sync --dev 22 | uv run pre-commit install 23 | ``` 24 | 25 | ### Optional 26 | Set up your IDE to automatically apply the `ruff` styling. 27 | - [VS Code](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) 28 | - [PyCharm](https://plugins.jetbrains.com/plugin/20574-ruff) 29 | 30 | ## Development principles 31 | 32 | Please adhere to the following principles while contributing to the project: 33 | 34 | 1. Adopt a functional style of programming. Avoid abstractions (classes) at all costs. 35 | 2. To add a new feature, create a branch and when done open a Pull Request. You should _**not**_ approve your own PRs. 36 | 3. The package contains both `numpy` and `torch` based algorithms. Let's keep them separated. 37 | 4. The functions shouldn't change the `dtype` or device of the inputs (that is, keep a functional approach). 38 | 5. Try to complement your contributions with simple examples to be added in the `examples` folder. If you need some additional dependency add it to the `examples` dependency group as `uv add --group examples _your_dependency_`. 
-------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/.DS_Store -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/_static/.DS_Store -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | h2 { 2 | border-bottom-width: 0px !important; 3 | } 4 | 5 | img { 6 | border-radius: 10px !important; 7 | } 8 | 9 | :root { 10 | --mystnb-stdout-bg-color: theme(colors.background); 11 | --mystnb-source-bg-color: theme(colors.background); 12 | } 13 | 14 | div.cell_input { 15 | border-color: hsl(var(--border)) !important; 16 | border-width: 1px !important; 17 | border-radius: var(--radius) !important; 18 | } 19 | 20 | div.cell_output .output { 21 | border: none !important; 22 | } -------------------------------------------------------------------------------- /docs/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/_static/logo.png -------------------------------------------------------------------------------- /docs/_templates/header.html: -------------------------------------------------------------------------------- 1 | {% extends "!header.html" %} 2 | 3 | {% block header_logo %} 4 | 94 | 95 | 96 | {%- if logo_url %} 97 | Logo 98 | {%- endif -%} 99 | {%- if theme_logo_dark and not logo_url %} 100 | 102 | {%- endif -%} 103 | {%- if theme_logo_light and not logo_url %} 104 | Logo 106 | {%- endif -%} 107 | 108 | 109 | 110 | {% endblock header_logo %} 111 | 112 | {%- block header_right %} 113 |
114 | {%- if docsearch or hasdoc('search') %} 115 |
116 | 130 | 159 | 160 |
161 | {%- include "searchbox.html" %} 162 |
163 | 164 | 184 |
185 | {%- endif %} 186 | 187 | {%- block extra_header_link_icons %} 188 | 220 | {%- endblock extra_header_link_icons %} 221 |
222 | {%- endblock header_right %} -------------------------------------------------------------------------------- /docs/_templates/sidebar_toc.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/bibliography.bib: -------------------------------------------------------------------------------- 1 | 2 | @article{Kostic2022, 3 | title={Learning dynamical systems via koopman operator regression in reproducing kernel hilbert spaces}, 4 | author={Kostic, Vladimir and Novelli, Pietro and Maurer, Andreas and Ciliberto, Carlo and Rosasco, Lorenzo and Pontil, Massimiliano}, 5 | journal={Advances in Neural Information Processing Systems}, 6 | volume={35}, 7 | pages={4017--4031}, 8 | year={2022} 9 | } 10 | 11 | @inproceedings{he2016deep, 12 | title={Deep residual learning for image recognition}, 13 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 14 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 15 | pages={770--778}, 16 | year={2016} 17 | } 18 | 19 | @article{lavoie2022simplicial, 20 | title={Simplicial embeddings in self-supervised learning and downstream classification}, 21 | author={Lavoie, Samuel and Tsirigotis, Christos and Schwarzer, Max and Vani, Ankit and Noukhovitch, Michael and Kawaguchi, Kenji and Courville, Aaron}, 22 | journal={arXiv preprint arXiv:2204.00616}, 23 | year={2022} 24 | } 25 | 26 | @inproceedings{Kostic2024NCP, 27 | author = {Kostic, Vladimir and Pacreau, Gr\'{e}goire and Turri, Giacomo and Novelli, Pietro and Lounici, Karim and Pontil, Massimiliano}, 28 | booktitle = {Advances in Neural Information Processing Systems}, 29 | title = {Neural Conditional Probability for Uncertainty Quantification}, 30 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/705b97ecb07ae86524d438abac97a3e2-Paper-Conference.pdf}, 31 | year = {2024} 32 | } 33 | 34 
| 35 | @inproceedings{Kostic2023SpectralRates, 36 | title={Sharp Spectral Rates for Koopman Operator Learning}, 37 | author={Vladimir R Kostic and Karim Lounici and Pietro Novelli and Massimiliano Pontil}, 38 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 39 | year={2023}, 40 | url={https://openreview.net/forum?id=Lt3jqxsbVO} 41 | } 42 | 43 | @inproceedings{ 44 | Kostic2023DPNets, 45 | title={Learning invariant representations of time-homogeneous stochastic dynamical systems}, 46 | author={Vladimir R. Kostic and Pietro Novelli and Riccardo Grazzi and Karim Lounici and Massimiliano Pontil}, 47 | booktitle={The Twelfth International Conference on Learning Representations}, 48 | year={2024}, 49 | url={https://openreview.net/forum?id=twSnZwiOIm} 50 | } 51 | 52 | @inproceedings{Fathi2023, 53 | title={Course Correcting Koopman Representations}, 54 | author={Mahan Fathi and Clement Gehring and Jonathan Pilault and David Kanaa and Pierre-Luc Bacon and Ross Goroshin}, 55 | booktitle={The Twelfth International Conference on Learning Representations}, 56 | year={2024}, 57 | url={https://openreview.net/forum?id=A18gWgc5mi} 58 | } 59 | 60 | @inproceedings{Azencot2020CAE, 61 | title={Forecasting sequential data using consistent Koopman autoencoders}, 62 | author={Azencot, Omri and Erichson, N Benjamin and Lin, Vanessa and Mahoney, Michael}, 63 | booktitle={International Conference on Machine Learning}, 64 | pages={475--485}, 65 | year={2020}, 66 | organization={PMLR} 67 | } 68 | 69 | @article{Arbabi2017, 70 | doi = {10.1137/17m1125236}, 71 | url = {https://doi.org/10.1137/17m1125236}, 72 | year = {2017}, 73 | month = jan, 74 | publisher = {Society for Industrial {\&} Applied Mathematics ({SIAM})}, 75 | volume = {16}, 76 | number = {4}, 77 | pages = {2096--2126}, 78 | author = {Hassan Arbabi and Igor Mezi{\'{c}}}, 79 | title = {Ergodic Theory, Dynamic Mode Decomposition, and Computation of Spectral Properties of the Koopman Operator}, 80 | 
journal = {{SIAM} Journal on Applied Dynamical Systems} 81 | } 82 | 83 | @inproceedings{Meanti2023, 84 | title={Estimating Koopman operators with sketching to provably learn large scale dynamical systems}, 85 | author={Giacomo Meanti and Antoine Chatalic and Vladimir R Kostic and Pietro Novelli and Massimiliano Pontil and Lorenzo Rosasco}, 86 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 87 | year={2023}, 88 | url={https://openreview.net/forum?id=GItLpB1vhK} 89 | } 90 | 91 | @inproceedings{Morton2018, 92 | author = {Morton, Jeremy and Witherden, Freddie D. and Jameson, Antony and Kochenderfer, Mykel J.}, 93 | title = {Deep Dynamical Modeling and Control of Unsteady Fluid Flows}, 94 | year = {2018}, 95 | publisher = {Curran Associates Inc.}, 96 | address = {Red Hook, NY, USA}, 97 | abstract = {The design of flow control systems remains a challenge due to the nonlinear nature of the equations that govern fluid flow. However, recent advances in computational fluid dynamics (CFD) have enabled the simulation of complex fluid flows with high accuracy, opening the possibility of using learning-based approaches to facilitate controller design. We present a method for learning the forced and unforced dynamics of airflow over a cylinder directly from CFD data. The proposed approach, grounded in Koopman theory, is shown to produce stable dynamical models that can predict the time evolution of the cylinder system over extended time horizons. 
Finally, by performing model predictive control with the learned dynamical models, we are able to find a straightforward, interpretable control law for suppressing vortex shedding in the wake of the cylinder.}, 98 | booktitle = {Proceedings of the 32nd International Conference on Neural Information Processing Systems}, 99 | pages = {9278–9288}, 100 | numpages = {11}, 101 | location = {Montr\'{e}al, Canada}, 102 | series = {NIPS'18} 103 | } 104 | 105 | @article{Schmid2010, 106 | title = {Dynamic mode decomposition of numerical and experimental data}, 107 | volume = {656}, 108 | ISSN = {1469-7645}, 109 | url = {http://dx.doi.org/10.1017/s0022112010001217}, 110 | DOI = {10.1017/s0022112010001217}, 111 | journal = {Journal of Fluid Mechanics}, 112 | publisher = {Cambridge University Press (CUP)}, 113 | author = {Schmid, Peter J.}, 114 | year = {2010}, 115 | month = jul, 116 | pages = {5–28} 117 | } 118 | 119 | @article{Williams2015_KDMD, 120 | title = {A kernel-based method for data-driven koopman spectral analysis}, 121 | volume = {2}, 122 | ISSN = {2158-2505}, 123 | url = {http://dx.doi.org/10.3934/jcd.2015005}, 124 | DOI = {10.3934/jcd.2015005}, 125 | number = {2}, 126 | journal = {Journal of Computational Dynamics}, 127 | publisher = {American Institute of Mathematical Sciences (AIMS)}, 128 | author = {O. Williams, Matthew and W. Rowley, Clarence and G. Kevrekidis, Ioannis}, 129 | year = {2015}, 130 | pages = {247–265} 131 | } 132 | 133 | @article{Williams2015_EDMD, 134 | title = {A Data–Driven Approximation of the Koopman Operator: Extending Dynamic Mode Decomposition}, 135 | volume = {25}, 136 | ISSN = {1432-1467}, 137 | url = {http://dx.doi.org/10.1007/s00332-015-9258-5}, 138 | DOI = {10.1007/s00332-015-9258-5}, 139 | number = {6}, 140 | journal = {Journal of Nonlinear Science}, 141 | publisher = {Springer Science and Business Media LLC}, 142 | author = {Williams, Matthew O. and Kevrekidis, Ioannis G. 
and Rowley, Clarence W.}, 143 | year = {2015}, 144 | month = jun, 145 | pages = {1307–1346} 146 | } 147 | 148 | @article{Lusch2018, 149 | doi = {10.1038/s41467-018-07210-0}, 150 | url = {https://doi.org/10.1038/s41467-018-07210-0}, 151 | year = {2018}, 152 | month = nov, 153 | publisher = {Springer Science and Business Media {LLC}}, 154 | volume = {9}, 155 | number = {1}, 156 | author = {Bethany Lusch and J. Nathan Kutz and Steven L. Brunton}, 157 | title = {Deep learning for universal linear embeddings of nonlinear dynamics}, 158 | journal = {Nature Communications} 159 | } 160 | 161 | @article{Wu2019, 162 | doi = {10.1007/s00332-019-09567-y}, 163 | url = {https://doi.org/10.1007/s00332-019-09567-y}, 164 | year = {2019}, 165 | month = aug, 166 | publisher = {Springer Science and Business Media {LLC}}, 167 | volume = {30}, 168 | number = {1}, 169 | pages = {23--66}, 170 | author = {Hao Wu and Frank No{\'{e}}}, 171 | title = {Variational Approach for Learning Markov Processes from Time Series Data}, 172 | journal = {Journal of Nonlinear Science} 173 | } 174 | @article{Mardt2018, 175 | doi = {10.1038/s41467-017-02388-1}, 176 | url = {https://doi.org/10.1038/s41467-017-02388-1}, 177 | year = {2018}, 178 | month = jan, 179 | publisher = {Springer Science and Business Media {LLC}}, 180 | volume = {9}, 181 | number = {1}, 182 | author = {Andreas Mardt and Luca Pasquali and Hao Wu and Frank No{\'{e}}}, 183 | title = {{VAMPnets} for deep learning of molecular kinetics}, 184 | journal = {Nature Communications} 185 | } 186 | 187 | @misc{Turri2023, 188 | doi = {10.48550/ARXIV.2312.17348}, 189 | url = {https://arxiv.org/abs/2312.17348}, 190 | author = {Turri, Giacomo and Kostic, Vladimir and Novelli, Pietro and Pontil, Massimiliano}, 191 | keywords = {Machine Learning (cs.LG), Numerical Analysis (math.NA), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Mathematics, FOS: Mathematics}, 192 | title = {A randomized 
algorithm to solve reduced rank operator regression}, 193 | publisher = {arXiv}, 194 | year = {2023}, 195 | copyright = {arXiv.org perpetual, non-exclusive license} 196 | } 197 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | from sphinxawesome_theme.postprocess import Icons 9 | 10 | html_permalinks_icon = Icons.permalinks_icon # SVG as a string 11 | 12 | project = "Linear Operator Learning" 13 | copyright = "2025, Linear Operator Learning Team" 14 | author = "Linear Operator Learning Team" 15 | 16 | # -- General configuration --------------------------------------------------- 17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 18 | 19 | extensions = [ 20 | "sphinx.ext.autodoc", 21 | "sphinx.ext.autosummary", 22 | "sphinx.ext.napoleon", 23 | 'sphinx.ext.viewcode', 24 | "sphinxawesome_theme", 25 | "sphinxcontrib.bibtex", 26 | "sphinx_design", 27 | "myst_nb" 28 | ] 29 | 30 | myst_enable_extensions = ["amsmath", "dollarmath", "html_image"] 31 | 32 | templates_path = ["_templates"] 33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 34 | 35 | bibtex_bibfiles = ["bibliography.bib"] 36 | 37 | 38 | # -- Options for HTML output ------------------------------------------------- 39 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 40 | 41 | html_theme = "sphinxawesome_theme" 42 | autodoc_class_signature = "separated" 43 | autoclass_content = "class" 44 | 45 | 
autodoc_typehints = "description" 46 | autodoc_member_order = "groupwise" 47 | napoleon_preprocess_types = True 48 | napoleon_use_rtype = False 49 | autodoc_mock_imports = ["escnn", "escnn.group"] 50 | 51 | master_doc = "index" 52 | 53 | source_suffix = { 54 | ".rst": "restructuredtext", 55 | ".txt": "restructuredtext", 56 | ".md": "myst-nb", 57 | ".ipynb": "myst-nb", 58 | } 59 | 60 | # Favicon configuration 61 | # html_favicon = '_static/favicon.ico' 62 | 63 | # Configure syntax highlighting for Awesome Sphinx Theme 64 | pygments_style = "tango" 65 | pygments_style_dark = "material" 66 | 67 | # Additional theme configuration 68 | html_title = "Linear Operator Learning" 69 | html_theme_options = { 70 | "show_prev_next": False, 71 | "show_scrolltop": True, 72 | "extra_header_link_icons": { 73 | "GitHub": { 74 | "link": "https://github.com/CSML-IIT-UCL/linear_operator_learning", 75 | "icon": """""", 76 | }, 77 | }, 78 | "show_breadcrumbs": True, 79 | } 80 | 81 | html_sidebars = { 82 | "**": ["sidebar_main_nav_links.html", "sidebar_toc.html"] 83 | } 84 | nb_execution_mode = "off" 85 | 86 | html_css_files = ["custom.css"] 87 | 88 | html_favicon = "favicon.ico" 89 | html_static_path = ["_static"] 90 | templates_path = ["_templates"] -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | (contributing)= 2 | # Developer guide 3 | 4 | To develop this project, please setup the [`uv` project manager](https://astral.sh/uv) by running the following commands: 5 | 6 | ```bash 7 | curl -LsSf https://astral.sh/uv/install.sh | sh 8 | git clone git@github.com:CSML-IIT-UCL/linear_operator_learning.git 9 | cd linear_operator_learning 10 | uv sync --dev 11 | uv run pre-commit install 12 | ``` 13 | 14 | ### Optional 15 | Set up your IDE to automatically apply the `ruff` styling. 
16 | - [VS Code](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) 17 | - [PyCharm](https://plugins.jetbrains.com/plugin/20574-ruff) 18 | 19 | ## Development principles 20 | 21 | Please adhere to the following principles while contributing to the project: 22 | 23 | 1. Adopt a functional style of programming. Avoid abstractions (classes) at all costs. 24 | 2. To add a new feature, create a branch and when done open a Pull Request. You should _**not**_ approve your own PRs. 25 | 3. The package contains both `numpy` and `torch` based algorithms. Let's keep them separated. 26 | 4. The functions shouldn't change the `dtype` or device of the inputs (that is, keep a functional approach). 27 | 5. Try to complement your contributions with simple examples to be added in the `examples` folder. If you need some additional dependency add it to the `examples` dependency group as `uv add --group examples _your_dependency_`. -------------------------------------------------------------------------------- /docs/examples: -------------------------------------------------------------------------------- 1 | ../examples -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/favicon.ico -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | Linear Operator Learning 3 | ======================== 4 | 5 | Linear Operator Learning — LOL for short — is a package to learn linear operators, developed by the `Computational Statistics & Machine Learning lab `_ at the Italian Institute of Technology. 6 | 7 | | 8 | 9 | You can install it with :code:`pip` as 10 | 11 | .. 
code:: 12 | 13 | pip install linear-operator-learning 14 | 15 | | 16 | 17 | If you want to contribute to the project, please follow :ref:`these guidelines <contributing>`. 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Getting Started 22 | :hidden: 23 | 24 | Quickstart 25 | contributing.md 26 | 27 | .. toctree:: 28 | :maxdepth: 2 29 | :caption: Examples 30 | :hidden: 31 | 32 | Independence Testing 33 | 34 | .. toctree:: 35 | :maxdepth: 2 36 | :caption: API Reference 37 | :hidden: 38 | 39 | reference/kernel/index 40 | reference/nn/index -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/reference/kernel/index.rst: -------------------------------------------------------------------------------- 1 | .. _kernel_index: 2 | ============== 3 | Kernel Methods 4 | ============== 5 | 6 | Algorithms and utilities to learn linear operators via kernel methods. 7 | 8 | .. 
toctree:: 9 | 10 | kernel 11 | -------------------------------------------------------------------------------- /docs/reference/kernel/kernel.rst: -------------------------------------------------------------------------------- 1 | .. _kernel_reference: 2 | ============== 3 | :code:`kernel` 4 | ============== 5 | .. module:: linear_operator_learning.kernel 6 | 7 | .. rst-class:: lead 8 | 9 | Kernel Methods 10 | 11 | Table of Contents 12 | ----------------- 13 | 14 | - :ref:`Regressors ` 15 | - :ref:`Types ` 16 | - :ref:`Linear Algebra ` 17 | - :ref:`Utilities ` 18 | 19 | 20 | .. _kernel_regressors: 21 | Regressors 22 | ---------- 23 | 24 | Common 25 | ~~~~~~ 26 | 27 | .. autofunction:: linear_operator_learning.kernel.predict 28 | 29 | .. autofunction:: linear_operator_learning.kernel.eig 30 | 31 | .. autofunction:: linear_operator_learning.kernel.evaluate_eigenfunction 32 | 33 | .. _rrr: 34 | Reduced Rank 35 | ~~~~~~~~~~~~ 36 | .. autofunction:: linear_operator_learning.kernel.reduced_rank 37 | 38 | .. autofunction:: linear_operator_learning.kernel.nystroem_reduced_rank 39 | 40 | .. autofunction:: linear_operator_learning.kernel.rand_reduced_rank 41 | 42 | .. _pcr: 43 | Principal Component Regression 44 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 45 | .. autofunction:: linear_operator_learning.kernel.pcr 46 | 47 | .. autofunction:: linear_operator_learning.kernel.nystroem_pcr 48 | 49 | .. _kernel_types: 50 | Types 51 | ----- 52 | 53 | .. autoclass:: linear_operator_learning.kernel.structs.FitResult 54 | :members: 55 | 56 | .. autoclass:: linear_operator_learning.kernel.structs.EigResult 57 | :members: 58 | 59 | .. _kernel_linalg: 60 | Linear Algebra Utilities 61 | ------------------------ 62 | 63 | .. autofunction:: linear_operator_learning.kernel.linalg.weighted_norm 64 | 65 | .. autofunction:: linear_operator_learning.kernel.linalg.stable_topk 66 | 67 | .. autofunction:: linear_operator_learning.kernel.linalg.add_diagonal_ 68 | 69 | .. 
_kernel_utils: 70 | General Utilities 71 | ----------------- 72 | 73 | .. autofunction:: linear_operator_learning.kernel.utils.topk 74 | 75 | .. autofunction:: linear_operator_learning.kernel.utils.sanitize_complex_conjugates 76 | 77 | 78 | .. footbibliography:: -------------------------------------------------------------------------------- /docs/reference/nn/functional.rst: -------------------------------------------------------------------------------- 1 | .. _nn_functional: 2 | ===================== 3 | :code:`nn.functional` 4 | ===================== 5 | 6 | .. rst-class:: lead 7 | 8 | Functional implementations 9 | 10 | .. module:: linear_operator_learning.nn 11 | 12 | Table of Contents 13 | ~~~~~~~~~~~~~~~~~ 14 | 15 | - :ref:`Loss Functions ` 16 | - :ref:`Regularization Functions ` 17 | 18 | .. _nn_func_loss_fns: 19 | Loss Functions 20 | ~~~~~~~~~~~~~~ 21 | 22 | .. autofunction:: linear_operator_learning.nn.functional.l2_contrastive_loss 23 | 24 | .. autofunction:: linear_operator_learning.nn.functional.kl_contrastive_loss 25 | 26 | .. autofunction:: linear_operator_learning.nn.functional.vamp_loss 27 | 28 | .. autofunction:: linear_operator_learning.nn.functional.dp_loss 29 | 30 | .. _nn_func_reg_fns: 31 | Regularization Functions 32 | ~~~~~~~~~~~~~~~~~~~~~~~~ 33 | 34 | .. autofunction:: linear_operator_learning.nn.functional.orthonormal_fro_reg 35 | 36 | .. autofunction:: linear_operator_learning.nn.functional.orthonormal_logfro_reg 37 | 38 | .. footbibliography:: -------------------------------------------------------------------------------- /docs/reference/nn/index.rst: -------------------------------------------------------------------------------- 1 | .. _nn_index: 2 | =============== 3 | Neural Networks 4 | =============== 5 | 6 | Functions and modules to learn linear operators via neural networks. 7 | 8 | .. 
toctree:: 9 | 10 | nn 11 | functional 12 | stats 13 | linalg 14 | -------------------------------------------------------------------------------- /docs/reference/nn/linalg.rst: -------------------------------------------------------------------------------- 1 | .. _nn_linalg: 2 | ===================== 3 | :code:`nn.linalg` 4 | ===================== 5 | 6 | .. rst-class:: lead 7 | 8 | Linear Algebra Utilities 9 | 10 | .. module:: linear_operator_learning.nn 11 | 12 | .. autofunction:: linear_operator_learning.nn.linalg.sqrtmh 13 | 14 | .. footbibliography:: -------------------------------------------------------------------------------- /docs/reference/nn/nn.rst: -------------------------------------------------------------------------------- 1 | .. _nn: 2 | ========== 3 | :code:`nn` 4 | ========== 5 | 6 | .. rst-class:: lead 7 | 8 | Neural Network Modules 9 | 10 | .. module:: linear_operator_learning.nn 11 | 12 | Table of Contents 13 | ~~~~~~~~~~~~~~~~~ 14 | 15 | - :ref:`Regressors ` 16 | - :ref:`Loss Functions ` 17 | - :ref:`Modules ` 18 | 19 | .. _nn_regressors: 20 | Regressors (see also :ref:`kernel regressors `) 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | .. autofunction:: linear_operator_learning.nn.ridge_least_squares 24 | 25 | .. autofunction:: linear_operator_learning.nn.eig 26 | 27 | .. autofunction:: linear_operator_learning.nn.evaluate_eigenfunction 28 | 29 | .. _nn_loss_fns: 30 | Loss Functions 31 | ~~~~~~~~~~~~~~ 32 | 33 | .. autoclass:: linear_operator_learning.nn.L2ContrastiveLoss 34 | :members: 35 | :exclude-members: __init__, __new__ 36 | 37 | .. autoclass:: linear_operator_learning.nn.KLContrastiveLoss 38 | :members: 39 | :exclude-members: __init__, __new__ 40 | 41 | .. autoclass:: linear_operator_learning.nn.VampLoss 42 | :members: 43 | :exclude-members: __init__, __new__ 44 | 45 | .. autoclass:: linear_operator_learning.nn.DPLoss 46 | :members: 47 | :exclude-members: __init__, __new__ 48 | 49 | .. 
_nn_modules: 50 | Modules 51 | ~~~~~~~ 52 | 53 | .. autoclass:: linear_operator_learning.nn.MLP 54 | :members: 55 | :exclude-members: __init__, __new__, forward 56 | 57 | .. autoclass:: linear_operator_learning.nn.ResNet 58 | :members: 59 | :exclude-members: __init__, __new__, forward 60 | 61 | .. autoclass:: linear_operator_learning.nn.SimNorm 62 | :members: 63 | :exclude-members: __init__, __new__, forward 64 | 65 | .. autoclass:: linear_operator_learning.nn.EMACovariance 66 | :members: 67 | :exclude-members: __init__, __new__, forward 68 | 69 | .. footbibliography:: 70 | 71 | -------------------------------------------------------------------------------- /docs/reference/nn/stats.rst: -------------------------------------------------------------------------------- 1 | .. _nn_stats: 2 | ===================== 3 | :code:`nn.stats` 4 | ===================== 5 | 6 | .. rst-class:: lead 7 | 8 | Statistics Utilities 9 | 10 | .. module:: linear_operator_learning.nn 11 | 12 | .. autofunction:: linear_operator_learning.nn.stats.covariance 13 | 14 | .. autofunction:: linear_operator_learning.nn.stats.cov_norm_squared_unbiased 15 | 16 | .. autofunction:: linear_operator_learning.nn.stats.cross_cov_norm_squared_unbiased 17 | 18 | .. autofunction:: linear_operator_learning.nn.stats.whitening 19 | 20 | .. 
footbibliography:: -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==7.3.7 2 | sphinxawesome-theme==5.2.0 3 | sphinxcontrib-applehelp==2.0.0 4 | sphinxcontrib-devhelp==2.0.0 5 | sphinxcontrib-htmlhelp==2.1.0 6 | sphinxcontrib-jsmath==1.0.1 7 | sphinxcontrib-qthelp==2.0.0 8 | sphinxcontrib-serializinghtml==2.0.0 9 | sphinxcontrib-bibtex 10 | sphinx_design 11 | myst-parser 12 | myst-nb 13 | git+https://github.com/CSML-IIT-UCL/linear_operator_learning -------------------------------------------------------------------------------- /examples/detecting_independence/.gitignore: -------------------------------------------------------------------------------- 1 | runs/ -------------------------------------------------------------------------------- /examples/detecting_independence/detecting_independence.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Detecting Independence\n", 8 | "\n", 9 | "In *Neural Conditional Probability for Uncertainty Quantification* (Kostic et al., 2024), the authors claim that the (deflated) conditional expectation operator can be used to detect the independence of two random variables X and Y by verifying whether it is zero. Here, we show this equivalence in practice.\n", 10 | "\n", 11 | "## Dataset\n", 12 | "\n", 13 | "We consider the data model\n", 14 | "\n", 15 | "$$Y = tX + (1-t)X',$$\n", 16 | "\n", 17 | "where $X$ and $X'$ are independent standard Gaussians in $\\mathbb{R}$, and $t \\in [0,1]$ is an interpolating factor. This model allows us to explore both extreme cases ($t = 0$ for independence and $t = 1$ where $Y = X$) and the continuum in between, to assess the robustness of NCP in detecting independence."
18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import torch\n", 27 | "from torch.utils.data import DataLoader, TensorDataset, random_split\n", 28 | "\n", 29 | "\n", 30 | "def make_dataset(n_samples: int = 200, t: float = 0.0):\n", 31 | " \"\"\"Draw sample from data model Y = tX + (1-t)X_, where X and X_ are independent gaussians.\n", 32 | "\n", 33 | " If t = 0, then X and Y are independent. Otherwise, if t->1, X and Y become ever more dependent.\n", 34 | "\n", 35 | " Args:\n", 36 | " n_samples (int, optional): Number of samples. Defaults to 200.\n", 37 | " t (float, optional): Interpolation factor. Defaults to 0.0.\n", 38 | " \"\"\"\n", 39 | " X = torch.normal(mean=0, std=1, size=(n_samples, 1))\n", 40 | " X_ = torch.normal(mean=0, std=1, size=(n_samples, 1))\n", 41 | " Y = t * X + (1 - t) * X_\n", 42 | "\n", 43 | " ds = TensorDataset(X, Y)\n", 44 | "\n", 45 | " # Split data into train and val sets\n", 46 | " train_ds, val_ds = random_split(ds, lengths=[0.85, 0.15])\n", 47 | "\n", 48 | " return train_ds, val_ds" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Learning the conditional expectation operator\n", 56 | "\n", 57 | "Now, we go through the process of learning the conditional expectation operator $\\mathbb{E}_{Y \\mid X}: L^2_Y \\mapsto L^2_X$\n", 58 | "\n", 59 | "$$[\\mathbb{E}_{Y \\mid X}g](x) = \\mathbb{E}[g(Y) \\mid X = x],$$\n", 60 | "\n", 61 | "where $g \\in L^2_Y$. 
We begin by noting that, if $\\{u_i\\}_{i=1}^\\infty$ and $\\{v_j\\}_{j=1}^\\infty$ were orthonormal bases of $L^2_X$ and $L^2_Y$ [(Orthonormal basis wikipedia)](https://en.wikipedia.org/wiki/Orthonormal_basis), then we could see the conditional expectation operator as an (infinite) matrix $\\mathbf{E}$, where\n", 62 | "\n", 63 | "$$\\mathbf{E}_{ij} = \\langle u_i, \\mathbb{E}_{Y \\mid X}v_j \\rangle_{L^2_X} = \\mathbb{E}_X[u_i(X)[\\mathbb{E}_{Y \\mid X}v_j](X)] = \\mathbb{E}_{XY}[u_i(X)v_j(Y)].$$\n", 64 | "\n", 65 | "Hence, to learn the operator, we \"only\" need to learn the most important parts of $\\mathbf{E}$. The standard way to deal with such problems is to restrict oneself to finite subspaces of $L^2_X$ and $L^2_Y$ and then learn the (finite) matrix there. This corresponds to finding orthonormal functions $\\{u_i\\}_{i=1}^d$ and $\\{v_j\\}_{j=1}^d$ s.t.\n", 66 | "\n", 67 | "$$\\lVert \\mathbb{E}_{Y \\mid X} - \\mathbb{E}_{Y \\mid X}^d \\rVert$$\n", 68 | "\n", 69 | "is minimized, where $d \\in \\mathbb{N}$ is the dimension and $\\mathbb{E}_{Y \\mid X}^d$ is the truncated operator that acts on $span\\{v_j\\}_{j=1}^d$ and $span\\{u_i\\}_{i=1}^d$. The theoretical solution of this problem is given by the truncated (rank d) Singular Value Decomposition [(Low-rank matrix approximation wikipedia)](https://en.wikipedia.org/wiki/Low-rank_approximation), which also has the nice benefit of ordering the bases by their importance a la PCA, meaning that $u_1$ is more important than $u_2$, and so on and so forth.\n", 70 | "\n", 71 | "## A representation learning problem\n", 72 | "\n", 73 | "A key insight of Kostic et al. (2024) is that this problem corresponds to a representation learning problem, where the goal is to find latent variables $u,v \\in \\mathbb{R}^d$ that are\n", 74 | "\n", 75 | "1. [(Whitened, wikipedia)](https://en.wikipedia.org/wiki/Whitening_transformation): $\\mathbb{E}[u_i(X)u_j(X)] = \\mathbb{E}[v_i(Y)v_j(Y)] = \\delta_{ij}$; and\n", 76 | "2. 
Minimize the contrastive loss\n", 77 | "\n", 78 | "$$\\frac{1}{N(N-1)}\\sum_{i \\neq j}\\langle u_{i}, Sv_{j} \\rangle^2 - \\frac{2}{N}\\sum_{i=1}\\langle u_{i}, Sv_{i} \\rangle,$$\n", 79 | "\n", 80 | "where $S$ is the matrix of the conditional expectation operator on these subspaces/features, which can be learned end-to-end with backpropagation or estimated with running means." 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Representation learning in Torch" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 2, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "from torch.nn import Module\n", 97 | "import torch\n", 98 | "import math\n", 99 | "from torch import Tensor\n", 100 | "\n", 101 | "\n", 102 | "class _Matrix(Module):\n", 103 | " \"\"\"Module representing the matrix form of the truncated conditional expectation operator.\"\"\"\n", 104 | "\n", 105 | " def __init__(\n", 106 | " self,\n", 107 | " dim_u: int,\n", 108 | " dim_v: int,\n", 109 | " ) -> None:\n", 110 | " super().__init__()\n", 111 | " self.weights = torch.nn.Parameter(\n", 112 | " torch.normal(mean=0.0, std=2.0 / math.sqrt(dim_u * dim_v), size=(dim_u, dim_v))\n", 113 | " )\n", 114 | "\n", 115 | " def forward(self, v: Tensor) -> Tensor:\n", 116 | " \"\"\"Forward pass of the truncated conditional expectation operator's matrix (v -> Sv).\"\"\"\n", 117 | " # TODO: Unify Pietro, Giacomo and Dani's ideas on how to normalize\\symmetrize the operator.\n", 118 | " out = v @ self.weights.T\n", 119 | " return out\n", 120 | "\n", 121 | "\n", 122 | "class NCP(Module):\n", 123 | " \"\"\"Neural Conditional Probability in PyTorch.\n", 124 | "\n", 125 | " Args:\n", 126 | " embedding_x (Module): Neural embedding of x.\n", 127 | " embedding_dim_x (int): Latent dimension of x.\n", 128 | " embedding_y (Module): Neural embedding of y.\n", 129 | " embedding_dim_y (int): Latent dimension of y.\n", 130 | " \"\"\"\n", 131 | "\n", 132 | " def 
__init__(\n", 133 | " self,\n", 134 | " embedding_x: Module,\n", 135 | " embedding_y: Module,\n", 136 | " embedding_dim_x: int,\n", 137 | " embedding_dim_y: int,\n", 138 | " ) -> None:\n", 139 | " super().__init__()\n", 140 | " self.U = embedding_x\n", 141 | " self.V = embedding_y\n", 142 | "\n", 143 | " self.dim_u = embedding_dim_x\n", 144 | " self.dim_v = embedding_dim_y\n", 145 | "\n", 146 | " self.S = _Matrix(self.dim_u, self.dim_v)\n", 147 | "\n", 148 | " def forward(self, x: Tensor, y: Tensor) -> Tensor:\n", 149 | " \"\"\"Forward pass of NCP.\"\"\"\n", 150 | " u = self.U(x)\n", 151 | " v = self.V(y)\n", 152 | " Sv = self.S(v)\n", 153 | "\n", 154 | " return u, Sv" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "## Training NCP\n", 162 | "\n", 163 | "We now show how to train the NCP module above with the contrastive loss from `linear_operator_learning.nn` with orthonormality regularization and standard deep learning techniques." 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 3, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "from torch.optim import Optimizer\n", 173 | "\n", 174 | "def train(\n", 175 | " ncp: NCP,\n", 176 | " train_dataloader: DataLoader,\n", 177 | " device: str,\n", 178 | " loss_fn: callable,\n", 179 | " optimizer: Optimizer,\n", 180 | ") -> Tensor:\n", 181 | " \"\"\"Training logic of NCP.\"\"\"\n", 182 | " ncp.train()\n", 183 | " for batch, (x, y) in enumerate(train_dataloader):\n", 184 | " x, y = x.to(device), y.to(device)\n", 185 | "\n", 186 | " u, Sv = ncp(x, y)\n", 187 | " loss = loss_fn(u, Sv)\n", 188 | " loss.backward()\n", 189 | " optimizer.step()\n", 190 | " optimizer.zero_grad()" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 4, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "Using cpu device\n", 203 | "run_id = (0.0, 0)\n", 204
| "run_id = (0.1, 0)\n", 205 | "run_id = (0.2, 0)\n", 206 | "run_id = (0.3, 0)\n", 207 | "run_id = (0.4, 0)\n", 208 | "run_id = (0.5, 0)\n", 209 | "run_id = (0.6, 0)\n", 210 | "run_id = (0.7, 0)\n", 211 | "run_id = (0.8, 0)\n", 212 | "run_id = (0.9, 0)\n", 213 | "run_id = (1.0, 0)\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "import torch\n", 219 | "import linear_operator_learning as lol\n", 220 | "\n", 221 | "\n", 222 | "\n", 223 | "SEED = 1\n", 224 | "REPEATS = 1\n", 225 | "BATCH_SIZE = 256\n", 226 | "N_SAMPLES = 5000\n", 227 | "MLP_PARAMS = dict(\n", 228 | " output_shape=2,\n", 229 | " n_hidden=2,\n", 230 | " layer_size=32,\n", 231 | " activation=torch.nn.ELU,\n", 232 | " bias=False,\n", 233 | " iterative_whitening=False,\n", 234 | ")\n", 235 | "EPOCHS = 100\n", 236 | "WHITENING_N_SAMPLES = 2000\n", 237 | "\n", 238 | "torch.manual_seed(SEED)\n", 239 | "\n", 240 | "# device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else \"cpu\"\n", 241 | "device = \"cpu\"\n", 242 | "print(f\"Using {device} device\")\n", 243 | "\n", 244 | "results = dict()\n", 245 | "for t in torch.linspace(start=0, end=1, steps=11):\n", 246 | " for r in range(REPEATS):\n", 247 | " run_id = (round(t.item(), 2), r)\n", 248 | " print(f\"run_id = {run_id}\")\n", 249 | "\n", 250 | " # Load data_________________________________________________________________________________\n", 251 | " train_ds, val_ds = make_dataset(n_samples=N_SAMPLES, t=t.item())\n", 252 | "\n", 253 | " train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)\n", 254 | " val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)\n", 255 | "\n", 256 | " # Build NCP_________________________________________________________________________________\n", 257 | " ncp = NCP(\n", 258 | " embedding_x=lol.nn.MLP(input_shape=1, **MLP_PARAMS),\n", 259 | " embedding_dim_x=MLP_PARAMS[\"output_shape\"],\n", 260 | " embedding_y=lol.nn.MLP(input_shape=1, **MLP_PARAMS),\n", 261 | " 
embedding_dim_y=MLP_PARAMS[\"output_shape\"],\n", 262 | " ).to(device)\n", 263 | "\n", 264 | " # Train NCP_________________________________________________________________________________\n", 265 | " loss_fn = lol.nn.L2ContrastiveLoss()\n", 266 | " optimizer = torch.optim.Adam(ncp.parameters(), lr=5e-4)\n", 267 | "\n", 268 | " for epoch in range(EPOCHS):\n", 269 | " train(ncp, train_dl, device, loss_fn, optimizer)\n", 270 | "\n", 271 | " # Extract norm______________________________________________________________________________\n", 272 | " x = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1)).to(device)\n", 273 | " x_ = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1)).to(device)\n", 274 | " y = t * x + (1 - t) * x_\n", 275 | " u, Sv = ncp(x, y)\n", 276 | "\n", 277 | " _, _, svals, _, _ = lol.nn.stats.whitening(u, Sv)\n", 278 | " results[run_id] = svals.max().item()" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "## Plots" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 5, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "image/png": "", 296 | "text/plain": [ 297 | "
" 298 | ] 299 | }, 300 | "metadata": {}, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "import pandas as pd\n", 306 | "import seaborn as sns\n", 307 | "\n", 308 | "results_df = pd.DataFrame(\n", 309 | " data=[(t, r, norm) for ((t, r), norm) in results.items()],\n", 310 | " columns=[\"t\", \"r\", \"norm\"],\n", 311 | ")\n", 312 | "sns.pointplot(results_df, x=\"t\", y=\"norm\");" 313 | ] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": ".venv", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.11.11" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 2 337 | } 338 | -------------------------------------------------------------------------------- /examples/lorenz63/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | *.png -------------------------------------------------------------------------------- /examples/lorenz63/main.py: -------------------------------------------------------------------------------- 1 | """Lorenz 63 example.""" 2 | 3 | import functools 4 | import json 5 | import sys 6 | from collections import defaultdict 7 | from pathlib import Path 8 | from time import perf_counter 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import scipy.integrate 13 | from loguru import logger 14 | from scipy.spatial.distance import pdist 15 | from sklearn.gaussian_process.kernels import RBF 16 | 17 | import linear_operator_learning as lol 18 | 19 | # Configure logger 20 | logger.remove() 21 | logger.add( 22 | sys.stdout, # Log to standard output 23 | format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", 24 | 
colorize=True, # Enable colorization 25 | ) 26 | 27 | this_folder = Path(__file__).parent.resolve() 28 | 29 | 30 | # Adapted from https://realpython.com/python-timer/#creating-a-python-timer-decorator 31 | def timer(func): 32 | """A decorator that times the execution of a function. 33 | 34 | Args: 35 | func: Callable to be timed. 36 | 37 | Returns: 38 | Tuple containing the output of `func` and the time it took to execute. 39 | """ 40 | 41 | @functools.wraps(func) 42 | def wrapper_timer(*args, **kwargs): 43 | tic = perf_counter() 44 | value = func(*args, **kwargs) 45 | toc = perf_counter() 46 | elapsed_time = toc - tic 47 | return value, elapsed_time 48 | 49 | return wrapper_timer 50 | 51 | 52 | def L63(x: np.ndarray, t: int = 1, sigma=10, mu=28, beta=8 / 3, dt=0.01): 53 | """Simulates the Lorenz '63 system using given initial conditions and parameters. 54 | 55 | Args: 56 | x (np.ndarray): Initial state vector of the system. 57 | t (int, optional): Total number of time steps for the simulation. Default is 1. 58 | sigma (float, optional): Prandtl number. Default is 10. 59 | mu (float, optional): Rayleigh number minus 1. Default is 28. 60 | beta (float, optional): Geometric factor. Default is 8/3. 61 | dt (float, optional): Time step size. Default is 0.01. 62 | 63 | Returns: 64 | np.ndarray: Array containing the state of the system at each time step. 65 | """ 66 | M_lin = np.array([[-sigma, sigma, 0], [mu, 0, 0], [0, 0, -beta]]) 67 | 68 | def D(_, x): 69 | dx = M_lin @ x 70 | dx[1] -= x[2] * x[0] 71 | dx[2] += x[0] * x[1] 72 | return dx 73 | 74 | sim_time = dt * (t + 1) 75 | t_eval = np.linspace(0, sim_time, t + 1, endpoint=True) 76 | t_span = (0, t_eval[-1]) 77 | sol = scipy.integrate.solve_ivp(D, t_span, x, t_eval=t_eval, method="RK45") 78 | return sol.y.T 79 | 80 | 81 | def make_dataset(): 82 | """Generates a dataset of the Lorenz '63 system, with the given parameters. 83 | 84 | Returns: 85 | dict: A dictionary containing the train and test datasets. 
86 | """ 87 | train_samples = 10000 88 | test_samples = 100 89 | t = train_samples + 1000 + test_samples 90 | x = np.ones(3) 91 | raw_data = L63(x, t=t) 92 | # raw_data = LogisticMap(N=20).sample(X0 = np.ones(1), T=configs.train_samples + 1000 + configs.test_samples) 93 | mean = np.mean(raw_data, axis=0) 94 | norm = np.max(np.abs(raw_data), axis=0) 95 | # Data rescaling 96 | data = raw_data - mean 97 | data /= norm 98 | 99 | train_data = data[: train_samples + 1] 100 | test_data = data[-test_samples - 1 :] 101 | return {"train": train_data, "test": test_data} 102 | 103 | 104 | def center_selection(num_pts: int, num_centers: int, rng_seed: int | None = None): 105 | """Randomly selects a specified number of center indices from a given number of points. 106 | 107 | Args: 108 | num_pts (int): Total number of available points to select from. 109 | num_centers (int): Number of center indices to select. 110 | rng_seed (int | None, optional): Seed for the random number generator. Defaults to None. 111 | 112 | Returns: 113 | ndarray: Array of selected center indices. 114 | """ 115 | rng = np.random.default_rng(rng_seed) 116 | rand_indices = rng.choice(num_pts, num_centers) 117 | return rand_indices 118 | 119 | 120 | def main(): 121 | """Main entry point for the Lorenz 63 benchmarks. 122 | 123 | This script trains multiple reduced rank regression models with different 124 | training set sizes and plots the training times and root mean squared errors 125 | (rMSEs). 
126 | """ 127 | data = make_dataset() 128 | train_data = data["train"] 129 | test_data = data["test"] 130 | # Length scale of the kernel: median of the pairwise distances of the dataset 131 | data_pdist = pdist(train_data) 132 | kernel = RBF(length_scale=np.quantile(data_pdist, 0.5)) 133 | 134 | rank = 25 135 | num_centers = 250 136 | tikhonov_reg = 1e-6 137 | 138 | train_stops = np.logspace(3, 4, 10).astype(int) 139 | timings = defaultdict(list) 140 | rMSEs = defaultdict(list) 141 | 142 | for stop in train_stops: 143 | logger.info(f"######## {stop} Training points ########") 144 | assert stop < len(train_data) 145 | X = train_data[:stop] 146 | Y = train_data[1 : stop + 1] 147 | kernel_X = kernel(X) 148 | kernel_Y = kernel(Y) 149 | kernel_YX = kernel(Y, X) 150 | X_test = test_data[:-1] 151 | kernel_Xtest_X = kernel(X_test, X) 152 | Y_test = test_data[1:] 153 | 154 | # Vanilla Reduced Rank Regression 155 | fit_results, fit_time = timer(lol.kernel.reduced_rank)( 156 | kernel_X, kernel_Y, tikhonov_reg, rank 157 | ) 158 | Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y) 159 | rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2)) 160 | timings["Vanilla RRR"].append(fit_time) 161 | logger.info(f"Vanilla RRR {fit_time:.1e}s") 162 | rMSEs["Vanilla RRR"].append(rMSE.item()) 163 | fit_results, fit_time = timer(lol.kernel.pcr)(kernel_X, tikhonov_reg, rank) 164 | Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y) 165 | rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2)) 166 | timings["Vanilla PCR"].append(fit_time) 167 | logger.info(f"Vanilla PCR {fit_time:.1e}s") 168 | rMSEs["Vanilla PCR"].append(rMSE.item()) 169 | # Nystroem 170 | center_idxs = center_selection(len(X), num_centers, rng_seed=42) 171 | fit_results, fit_time = timer(lol.kernel.nystroem_reduced_rank)( 172 | kernel(X[center_idxs]), 173 | kernel(Y[center_idxs]), 174 | kernel_X[:, center_idxs], 175 | kernel_Y[:, center_idxs], 176 | tikhonov_reg, 177 | rank, 178 | ) 179 | Y_pred 
= lol.kernel.predict( 180 | 1, 181 | fit_results, 182 | kernel(Y[center_idxs], X[center_idxs]), 183 | kernel_Xtest_X[:, center_idxs], 184 | Y[center_idxs], 185 | ) 186 | rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2)) 187 | timings["Nystroem RRR"].append(fit_time) 188 | logger.info(f"Nystroem RRR {fit_time:.1e}s") 189 | rMSEs["Nystroem RRR"].append(rMSE.item()) 190 | 191 | center_idxs = center_selection(len(X), num_centers, rng_seed=42) 192 | fit_results, fit_time = timer(lol.kernel.nystroem_pcr)( 193 | kernel(X[center_idxs]), 194 | kernel(Y[center_idxs]), 195 | kernel_X[:, center_idxs], 196 | kernel_Y[:, center_idxs], 197 | tikhonov_reg, 198 | rank, 199 | ) 200 | Y_pred = lol.kernel.predict( 201 | 1, 202 | fit_results, 203 | kernel(Y[center_idxs], X[center_idxs]), 204 | kernel_Xtest_X[:, center_idxs], 205 | Y[center_idxs], 206 | ) 207 | rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2)) 208 | timings["Nystroem PCR"].append(fit_time) 209 | logger.info(f"Nystroem PCR {fit_time:.1e}s") 210 | rMSEs["Nystroem PCR"].append(rMSE.item()) 211 | # Randomized 212 | fit_results, fit_time = timer(lol.kernel.rand_reduced_rank)( 213 | kernel_X, kernel_Y, tikhonov_reg, rank 214 | ) 215 | Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y) 216 | rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2)) 217 | timings["Randomized RRR"].append(fit_time) 218 | logger.info(f"Randomized RRR {fit_time:.1e}s") 219 | rMSEs["Randomized RRR"].append(rMSE.item()) 220 | 221 | # Save results to json 222 | with open(this_folder / "timings.json", "w") as f: 223 | json.dump(timings, f) 224 | with open(this_folder / "rMSEs.json", "w") as f: 225 | json.dump(rMSEs, f) 226 | # Plot results 227 | fig, axes = plt.subplots(ncols=2, figsize=(7, 3), dpi=144) 228 | 229 | for name in [ 230 | "Vanilla RRR", 231 | "Vanilla PCR", 232 | "Nystroem RRR", 233 | "Nystroem PCR", 234 | "Randomized RRR", 235 | ]: 236 | axes[0].plot(train_stops, rMSEs[name], ".-", label=name) 237 | axes[1].plot(train_stops, 
timings[name], ".-", label=name) 238 | 239 | axes[0].set_title("rMSE") 240 | axes[1].set_title("Training time (s)") 241 | axes[1].legend(frameon=False, loc="upper left") 242 | axes[1].set_yscale("log") 243 | for ax in axes: 244 | ax.set_xscale("log") 245 | ax.set_xlabel("Training points") 246 | plt.tight_layout() 247 | # Save plot 248 | plt.savefig(this_folder / "reduced_rank_benchmarks.png") 249 | 250 | 251 | if __name__ == "__main__": 252 | main() 253 | -------------------------------------------------------------------------------- /linear_operator_learning/__init__.py: -------------------------------------------------------------------------------- 1 | """Main entry point for linear operator learning.""" 2 | 3 | import sys 4 | 5 | 6 | def __getattr__(attr): 7 | """Lazy-import submodules.""" 8 | submodules = { 9 | "kernel": "linear_operator_learning.kernel", 10 | "nn": "linear_operator_learning.nn", 11 | } 12 | 13 | if attr in submodules: 14 | module = __import__(submodules[attr], fromlist=[""]) 15 | setattr(sys.modules[__name__], attr, module) # Attach to the module namespace 16 | return module 17 | 18 | raise AttributeError(f"Unknown submodule {attr}") 19 | -------------------------------------------------------------------------------- /linear_operator_learning/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | """Kernel methods entry point.""" 2 | 3 | from linear_operator_learning.kernel.regressors import * # noqa 4 | -------------------------------------------------------------------------------- /linear_operator_learning/kernel/linalg.py: -------------------------------------------------------------------------------- 1 | """Linear algebra utilities for the `kernel` algorithms.""" 2 | 3 | from warnings import warn 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | from linear_operator_learning.kernel.utils import topk 9 | 10 | 11 | def add_diagonal_(M: ndarray, alpha: float): 12 | """Add 
alpha to the diagonal of M inplace. 13 | 14 | Args: 15 | M (ndarray): The matrix to modify inplace. 16 | alpha (float): The value to add to the diagonal of M. 17 | """ 18 | np.fill_diagonal(M, M.diagonal() + alpha) 19 | 20 | 21 | def stable_topk( 22 | vec: ndarray, 23 | k_max: int, 24 | rcond: float | None = None, 25 | ignore_warnings: bool = True, 26 | ): 27 | """Takes up to k_max indices of the top k_max values of vec. If the values are below rcond, they are discarded. 28 | 29 | Args: 30 | vec (ndarray): Vector to extract the top k indices from. 31 | k_max (int): Number of indices to extract. 32 | rcond (float, optional): Value below which the values are discarded. Defaults to None, in which case it is set according to the machine precision of vec's dtype. 33 | ignore_warnings (bool): If False, raise a warning when some elements are discarted for being below the requested numerical precision. 34 | 35 | """ 36 | if rcond is None: 37 | rcond = 10.0 * vec.shape[0] * np.finfo(vec.dtype).eps 38 | 39 | top_vec, top_idxs = topk(vec, k_max) 40 | 41 | if all(top_vec > rcond): 42 | return top_vec, top_idxs 43 | else: 44 | valid = top_vec > rcond 45 | # In the case of multiple occurrences of the maximum vec, the indices corresponding to the first occurrence are returned. 46 | first_invalid = np.argmax(np.logical_not(valid)) 47 | _first_discarded_val = np.max(np.abs(vec[first_invalid:])) 48 | 49 | if not ignore_warnings: 50 | warn( 51 | f"Warning: Discarted {k_max - vec.shape[0]} dimensions of the {k_max} requested due to numerical instability. Consider decreasing the k. The largest discarded value is: {_first_discarded_val:.3e}." 52 | ) 53 | return top_vec[valid], top_idxs[valid] 54 | 55 | 56 | def weighted_norm(A: ndarray, M: ndarray | None = None): 57 | r"""Weighted norm of the columns of A. 58 | 59 | Args: 60 | A (ndarray): 1D or 2D array. If 2D, the columns are treated as vectors. 61 | M (ndarray or LinearOperator, optional): Weigthing matrix. 
the norm of the vector :math:`a` is given by :math:`\langle a, Ma \rangle` . Defaults to None, corresponding to the Identity matrix. Warning: no checks are 62 | performed on M being a PSD operator. 63 | 64 | Returns: 65 | (ndarray or float): If ``A.ndim == 2`` returns 1D array of floats corresponding to the norms of 66 | the columns of A. Else return a float. 67 | """ 68 | assert A.ndim <= 2, "'A' must be a vector or a 2D array" 69 | if M is None: 70 | norm = np.linalg.norm(A, axis=0) 71 | else: 72 | _A = np.dot(M, A) 73 | _A_T = np.dot(M.T, A) 74 | norm = np.real(np.sum(0.5 * (np.conj(A) * _A + np.conj(A) * _A_T), axis=0)) 75 | rcond = 10.0 * A.shape[0] * np.finfo(A.dtype).eps 76 | norm = np.where(norm < rcond, 0.0, norm) 77 | return np.sqrt(norm) 78 | -------------------------------------------------------------------------------- /linear_operator_learning/kernel/regressors.py: -------------------------------------------------------------------------------- 1 | """Kernel-based regressors for linear operators.""" 2 | 3 | from math import sqrt 4 | from typing import Literal 5 | from warnings import warn 6 | 7 | import numpy as np 8 | import scipy.linalg 9 | from numpy import ndarray 10 | from scipy.sparse.linalg import eigs, eigsh 11 | 12 | from linear_operator_learning.kernel.linalg import ( 13 | add_diagonal_, 14 | stable_topk, 15 | weighted_norm, 16 | ) 17 | from linear_operator_learning.kernel.structs import EigResult, FitResult 18 | from linear_operator_learning.kernel.utils import sanitize_complex_conjugates 19 | 20 | __all__ = [ 21 | "predict", 22 | "eig", 23 | "evaluate_eigenfunction", 24 | "pcr", 25 | "nystroem_pcr", 26 | "reduced_rank", 27 | "nystroem_reduced_rank", 28 | "rand_reduced_rank", 29 | ] 30 | 31 | 32 | def eig( 33 | fit_result: FitResult, 34 | K_X: ndarray, # Kernel matrix of the input data 35 | K_YX: ndarray, # Kernel matrix between the output data and the input data 36 | ) -> EigResult: 37 | """Computes the eigendecomposition of a regressor. 
38 | 39 | Args: 40 | fit_result (FitResult): Fit result as defined in ``linear_operator_learning.kernel.structs``. 41 | K_X (ndarray): Kernel matrix of the input data. 42 | K_YX (ndarray): Kernel matrix between the output data and the input data. 43 | 44 | 45 | Shape: 46 | ``K_X``: :math:`(N, N)`, where :math:`N` is the sample size. 47 | 48 | ``K_YX``: :math:`(N, N)`, where :math:`N` is the sample size. 49 | 50 | Output: ``U, V`` of shape :math:`(N, R)`, ``svals`` of shape :math:`R` 51 | where :math:`N` is the sample size and :math:`R` is the rank of the regressor. 52 | """ 53 | # SUV.TZ -> V.T K_YX U (right ev = SUvr, left ev = ZVvl) 54 | U = fit_result["U"] 55 | V = fit_result["V"] 56 | r_dim = (K_X.shape[0]) ** (-1) 57 | 58 | W_YX = np.linalg.multi_dot([V.T, r_dim * K_YX, U]) 59 | W_X = np.linalg.multi_dot([U.T, r_dim * K_X, U]) 60 | 61 | values, vl, vr = scipy.linalg.eig(W_YX, left=True, right=True) # Left -> V, Right -> U 62 | values = sanitize_complex_conjugates(values) 63 | r_perm = np.argsort(values) 64 | vr = vr[:, r_perm] 65 | l_perm = np.argsort(values.conj()) 66 | vl = vl[:, l_perm] 67 | values = values[r_perm] 68 | 69 | rcond = 1000.0 * np.finfo(U.dtype).eps 70 | # Normalization in RKHS 71 | norm_r = weighted_norm(vr, W_X) 72 | norm_r = np.where(norm_r < rcond, np.inf, norm_r) 73 | vr = vr / norm_r 74 | 75 | # Bi-orthogonality of left eigenfunctions 76 | norm_l = np.diag(np.linalg.multi_dot([vl.T, W_YX, vr])) 77 | norm_l = np.where(np.abs(norm_l) < rcond, np.inf, norm_l) 78 | vl = vl / norm_l 79 | result: EigResult = {"values": values, "left": V @ vl, "right": U @ vr} 80 | return result 81 | 82 | 83 | def evaluate_eigenfunction( 84 | eig_result: EigResult, 85 | which: Literal["left", "right"], 86 | K_Xin_X_or_Y: ndarray, 87 | ): 88 | """Evaluates left or right eigenfunctions of a regressor. 
89 | 90 | Args: 91 | eig_result: EigResult object containing eigendecomposition results 92 | which: String indicating "left" or "right" eigenfunctions 93 | K_Xin_X_or_Y: Kernel matrix between initial conditions and input data (for right 94 | eigenfunctions) or output data (for left eigenfunctions) 95 | 96 | 97 | Shape: 98 | ``eig_result``: ``U, V`` of shape :math:`(N, R)`, ``svals`` of shape :math:`R` 99 | where :math:`N` is the sample size and :math:`R` is the rank of the regressor. 100 | 101 | ``K_Xin_X_or_Y``: :math:`(N_0, N)`, where :math:`N_0` is the number of inputs to 102 | predict and :math:`N` is the sample size. 103 | 104 | Output: :math:`(N_0, R)` 105 | """ 106 | vr_or_vl = eig_result[which] 107 | rsqrt_dim = (K_Xin_X_or_Y.shape[1]) ** (-0.5) 108 | return np.linalg.multi_dot([rsqrt_dim * K_Xin_X_or_Y, vr_or_vl]) 109 | 110 | 111 | def predict( 112 | num_steps: int, 113 | fit_result: FitResult, 114 | kernel_YX: ndarray, 115 | kernel_Xin_X: ndarray, 116 | obs_train_Y: ndarray, 117 | ) -> ndarray: 118 | """Predicts future states given initial values using a fitted regressor. 119 | 120 | Args: 121 | num_steps (int): Number of steps to predict forward (returns the last prediction) 122 | fit_result (FitResult): FitResult object containing fitted U and V matrices 123 | kernel_YX (ndarray): Kernel matrix between output data and input data (or inducing points for Nystroem) 124 | kernel_Xin_X (ndarray): Kernel matrix between initial conditions and input data (or inducing points for Nystroem) 125 | obs_train_Y (ndarray): Observable evaluated on output training data (or inducing points for Nystroem) 126 | 127 | Shape: 128 | ``kernel_YX``: :math:`(N, N)`, where :math:`N` is the number of training data, or inducing points for Nystroem. 129 | 130 | ``kernel_Xin_X``: :math:`(N_0, N)`, where :math:`N_0` is the number of inputs to predict. 131 | 132 | ``obs_train_Y``: :math:`(N, *)`, where :math:`*` is the shape of the observable. 133 | 134 | Output: :math:`(N, *)`. 
135 | """ 136 | # G = S UV.T Z 137 | # G^n = (SU)(V.T K_YX U)^(n-1)(V.T Z) 138 | U = fit_result["U"] 139 | V = fit_result["V"] 140 | npts = U.shape[0] 141 | K_dot_U = kernel_Xin_X @ U / sqrt(npts) 142 | V_dot_obs = V.T @ obs_train_Y / sqrt(npts) 143 | V_K_YX_U = np.linalg.multi_dot([V.T, kernel_YX, U]) / npts 144 | M = np.linalg.matrix_power(V_K_YX_U, num_steps - 1) 145 | return np.linalg.multi_dot([K_dot_U, M, V_dot_obs]) 146 | 147 | 148 | def pcr( 149 | kernel_X: ndarray, 150 | tikhonov_reg: float = 0.0, 151 | rank: int | None = None, 152 | svd_solver: Literal["arnoldi", "full"] = "arnoldi", 153 | ) -> FitResult: 154 | """Fits the Principal Components estimator. 155 | 156 | Args: 157 | kernel_X (ndarray): Kernel matrix of the input data. 158 | tikhonov_reg (float, optional): Tikhonov (ridge) regularization parameter. Defaults to 0.0. 159 | rank (int | None, optional): Rank of the estimator. Defaults to None. 160 | svd_solver (Literal[ "arnoldi", "full" ], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi". 161 | 162 | Shape: 163 | ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data. 
164 | """ 165 | npts = kernel_X.shape[0] 166 | add_diagonal_(kernel_X, npts * tikhonov_reg) 167 | if svd_solver == "arnoldi": 168 | _num_arnoldi_eigs = min(rank + 5, kernel_X.shape[0]) 169 | values, vectors = eigsh(kernel_X, k=_num_arnoldi_eigs) 170 | elif svd_solver == "full": 171 | values, vectors = scipy.linalg.eigh(kernel_X) 172 | else: 173 | raise ValueError(f"Unknown svd_solver {svd_solver}") 174 | add_diagonal_(kernel_X, -npts * tikhonov_reg) 175 | 176 | values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False) 177 | vectors = vectors[:, stable_values_idxs] 178 | Q = sqrt(npts) * vectors / np.sqrt(values) 179 | kernel_X_eigvalsh = np.sqrt(np.abs(values)) / npts 180 | result: FitResult = {"U": Q, "V": Q, "svals": kernel_X_eigvalsh} 181 | return result 182 | 183 | 184 | def nystroem_pcr( 185 | kernel_X: ndarray, # Kernel matrix of the input inducing points 186 | kernel_Y: ndarray, # Kernel matrix of the output inducing points 187 | kernel_Xnys: ndarray, # Kernel matrix between the input data and the input inducing points 188 | kernel_Ynys: ndarray, # Kernel matrix between the output data and the output inducing points 189 | tikhonov_reg: float = 0.0, # Tikhonov (ridge) regularization parameter (can be 0) 190 | rank: int | None = None, # Rank of the estimator 191 | svd_solver: Literal["arnoldi", "full"] = "arnoldi", 192 | ) -> FitResult: 193 | """Fits the Principal Components estimator using the Nyström method from :footcite:t:`Meanti2023`. 194 | 195 | Args: 196 | kernel_X (ndarray): Kernel matrix of the input inducing points. 197 | kernel_Y (ndarray): Kernel matrix of the output inducing points. 198 | kernel_Xnys (ndarray): Kernel matrix between the input data and the input inducing points. 199 | kernel_Ynys (ndarray): Kernel matrix between the output data and the output inducing points. 200 | tikhonov_reg (float, optional): Tikhonov (ridge) regularization parameter. Defaults to 0.0. 201 | rank (int | None, optional): Rank of the estimator. 
Defaults to None. 202 | svd_solver (Literal[ "arnoldi", "full" ], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi". 203 | 204 | Shape: 205 | ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data. 206 | 207 | ``kernel_Y``: :math:`(N, N)`. 208 | 209 | ``kernel_Xnys``: :math:`(N, M)`, where :math:`M` is the number of Nystroem centers (inducing points). 210 | 211 | ``kernel_Ynys``: :math:`(N, M)`. 212 | """ 213 | ncenters = kernel_X.shape[0] 214 | npts = kernel_Xnys.shape[0] 215 | eps = 1000 * np.finfo(kernel_X.dtype).eps 216 | reg = max(eps, tikhonov_reg) 217 | kernel_Xnys_sq = kernel_Xnys.T @ kernel_Xnys 218 | add_diagonal_(kernel_X, reg * ncenters) 219 | if svd_solver == "full": 220 | values, vectors = scipy.linalg.eigh( 221 | kernel_Xnys_sq, kernel_X 222 | ) # normalization leads to needing to invert evals 223 | elif svd_solver == "arnoldi": 224 | _oversampling = max(10, 4 * int(np.sqrt(rank))) 225 | _num_arnoldi_eigs = min(rank + _oversampling, ncenters) 226 | values, vectors = eigsh( 227 | kernel_Xnys_sq, 228 | M=kernel_X, 229 | k=_num_arnoldi_eigs, 230 | which="LM", 231 | ) 232 | else: 233 | raise ValueError(f"Unknown svd_solver {svd_solver}") 234 | add_diagonal_(kernel_X, -reg * ncenters) 235 | 236 | values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False) 237 | vectors = vectors[:, stable_values_idxs] 238 | 239 | U = sqrt(ncenters) * vectors / np.sqrt(values) 240 | V = np.linalg.multi_dot([kernel_Ynys.T, kernel_Xnys, vectors]) 241 | V = scipy.linalg.lstsq(kernel_Y, V)[0] 242 | V = sqrt(ncenters) * V / np.sqrt(values) 243 | 244 | kernel_X_eigvalsh = np.sqrt(np.abs(values)) / npts 245 | result: FitResult = {"U": U, "V": V, "svals": kernel_X_eigvalsh} 246 | return result 247 | 248 | 249 | def reduced_rank( 250 | kernel_X: ndarray, # Kernel matrix of the input data 251 | kernel_Y: ndarray, # Kernel matrix of the output data 252 | tikhonov_reg: float, # Tikhonov (ridge) regularization 
parameter, can be 0 253 | rank: int, # Rank of the estimator 254 | svd_solver: Literal["arnoldi", "full"] = "arnoldi", 255 | ) -> FitResult: 256 | """Fits the Reduced Rank estimator from :footcite:t:`Kostic2022`. 257 | 258 | Args: 259 | kernel_X (ndarray): Kernel matrix of the input data. 260 | kernel_Y (ndarray): Kernel matrix of the output data. 261 | tikhonov_reg (float): Tikhonov (ridge) regularization parameter. 262 | rank (int): Rank of the estimator. 263 | svd_solver (Literal[ "arnoldi", "full" ], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi". 264 | 265 | Shape: 266 | ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data. 267 | 268 | ``kernel_Y``: :math:`(N, N)`. 269 | """ 270 | # Number of data points 271 | npts = kernel_X.shape[0] 272 | eps = 1000.0 * np.finfo(kernel_X.dtype).eps 273 | penalty = max(eps, tikhonov_reg) * npts 274 | 275 | A = (kernel_Y / sqrt(npts)) @ (kernel_X / sqrt(npts)) 276 | add_diagonal_(kernel_X, penalty) 277 | # Find U via Generalized eigenvalue problem equivalent to the SVD. If K is ill-conditioned might be slow. 278 | # Prefer svd_solver == 'randomized' in such a case. 279 | if svd_solver == "arnoldi": 280 | # Adding a small buffer to the Arnoldi-computed eigenvalues. 281 | num_arnoldi_eigs = min(rank + 5, npts) 282 | values, vectors = eigs(A, k=num_arnoldi_eigs, M=kernel_X) 283 | elif svd_solver == "full": # 'full' 284 | values, vectors = scipy.linalg.eig(A, kernel_X, overwrite_a=True, overwrite_b=True) 285 | else: 286 | raise ValueError(f"Unknown svd_solver: {svd_solver}") 287 | # Remove the penalty from kernel_X (inplace) 288 | add_diagonal_(kernel_X, -penalty) 289 | 290 | values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False) 291 | vectors = vectors[:, stable_values_idxs] 292 | # Compare the filtered eigenvalues with the regularization strength, and warn if there are any eigenvalues that are smaller than the regularization strength. 
293 | if not np.all(np.abs(values) >= tikhonov_reg): 294 | warn( 295 | f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider redudcing the regularization strength to avoid overfitting." 296 | ) 297 | 298 | # Eigenvector normalization 299 | kernel_X_vecs = np.dot(kernel_X / sqrt(npts), vectors) 300 | vecs_norm = np.sqrt( 301 | np.sum( 302 | kernel_X_vecs**2 + tikhonov_reg * kernel_X_vecs * vectors * sqrt(npts), 303 | axis=0, 304 | ) 305 | ) 306 | 307 | norm_rcond = 1000.0 * np.finfo(values.dtype).eps 308 | values, stable_values_idxs = stable_topk(vecs_norm, rank, rcond=norm_rcond) 309 | U = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs] 310 | 311 | # Ordering the results 312 | V = kernel_X @ U 313 | svals = np.sqrt(np.abs(values)) 314 | result: FitResult = {"U": U.real, "V": V.real, "svals": svals} 315 | return result 316 | 317 | 318 | def nystroem_reduced_rank( 319 | kernel_X: ndarray, # Kernel matrix of the input inducing points 320 | kernel_Y: ndarray, # Kernel matrix of the output inducing points 321 | kernel_Xnys: ndarray, # Kernel matrix between the input data and the input inducing points 322 | kernel_Ynys: ndarray, # Kernel matrix between the output data and the output inducing points 323 | tikhonov_reg: float, # Tikhonov (ridge) regularization parameter 324 | rank: int, # Rank of the estimator 325 | svd_solver: Literal["arnoldi", "full"] = "arnoldi", 326 | ) -> FitResult: 327 | """Fits the Nyström Reduced Rank estimator from :footcite:t:`Meanti2023`. 328 | 329 | Args: 330 | kernel_X (ndarray): Kernel matrix of the input inducing points. 331 | kernel_Y (ndarray): Kernel matrix of the output inducing points. 332 | kernel_Xnys (ndarray): Kernel matrix between the input data and the input inducing points. 333 | kernel_Ynys (ndarray): Kernel matrix between the output data and the output inducing points. 
334 | tikhonov_reg (float): Tikhonov (ridge) regularization parameter. 335 | rank (int): Rank of the estimator. 336 | svd_solver (Literal[ "arnoldi", "full" ], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi". 337 | 338 | Shape: 339 | ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data. 340 | 341 | ``kernel_Y``: :math:`(N, N)`. 342 | 343 | ``kernel_Xnys``: :math:`(N, M)`, where :math:`M` is the number of Nystroem centers (inducing points). 344 | 345 | ``kernel_Ynys``: :math:`(N, M)`. 346 | """ 347 | num_points = kernel_Xnys.shape[0] 348 | num_centers = kernel_X.shape[0] 349 | 350 | eps = 1000 * np.finfo(kernel_X.dtype).eps * num_centers 351 | reg = max(eps, tikhonov_reg) 352 | 353 | # LHS of the generalized eigenvalue problem 354 | sqrt_Mn = sqrt(num_centers * num_points) 355 | kernel_YX_nys = (kernel_Ynys.T / sqrt_Mn) @ (kernel_Xnys / sqrt_Mn) 356 | 357 | _tmp_YX = scipy.linalg.lstsq(kernel_Y * (num_centers**-1), kernel_YX_nys)[0] 358 | kernel_XYX = kernel_YX_nys.T @ _tmp_YX 359 | 360 | # RHS of the generalized eigenvalue problem 361 | kernel_Xnys_sq = (kernel_Xnys.T / sqrt_Mn) @ (kernel_Xnys / sqrt_Mn) + reg * kernel_X * ( 362 | num_centers**-1 363 | ) 364 | 365 | add_diagonal_(kernel_Xnys_sq, eps) 366 | A = scipy.linalg.lstsq(kernel_Xnys_sq, kernel_XYX)[0] 367 | if svd_solver == "full": 368 | values, vectors = scipy.linalg.eigh( 369 | kernel_XYX, kernel_Xnys_sq 370 | ) # normalization leads to needing to invert evals 371 | elif svd_solver == "arnoldi": 372 | _oversampling = max(10, 4 * int(np.sqrt(rank))) 373 | _num_arnoldi_eigs = min(rank + _oversampling, num_centers) 374 | values, vectors = eigs(kernel_XYX, k=_num_arnoldi_eigs, M=kernel_Xnys_sq) 375 | else: 376 | raise ValueError(f"Unknown svd_solver {svd_solver}") 377 | add_diagonal_(kernel_Xnys_sq, -eps) 378 | 379 | values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False) 380 | vectors = vectors[:, stable_values_idxs] 381 | # 
Compare the filtered eigenvalues with the regularization strength, and warn if there are any eigenvalues that are smaller than the regularization strength. 382 | if not np.all(np.abs(values) >= tikhonov_reg): 383 | warn( 384 | f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider redudcing the regularization strength to avoid overfitting." 385 | ) 386 | # Eigenvector normalization 387 | vecs_norm = np.sqrt(np.abs(np.sum(vectors.conj() * (kernel_XYX @ vectors), axis=0))) 388 | norm_rcond = 1000.0 * np.finfo(values.dtype).eps 389 | values, stable_values_idxs = stable_topk(vecs_norm, rank, rcond=norm_rcond) 390 | vectors = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs] 391 | U = A @ vectors 392 | V = _tmp_YX @ vectors 393 | svals = np.sqrt(np.abs(values)) 394 | result: FitResult = {"U": U.real, "V": V.real, "svals": svals} 395 | return result 396 | 397 | 398 | def rand_reduced_rank( 399 | kernel_X: ndarray, 400 | kernel_Y: ndarray, 401 | tikhonov_reg: float, 402 | rank: int, 403 | n_oversamples: int = 5, 404 | optimal_sketching: bool = False, 405 | iterated_power: int = 1, 406 | rng_seed: int | None = None, 407 | precomputed_cholesky=None, 408 | ) -> FitResult: 409 | """Fits the Randomized Reduced Rank Estimator from :footcite:t:`Turri2023`. 410 | 411 | Args: 412 | kernel_X (ndarray): Kernel matrix of the input data 413 | kernel_Y (ndarray): Kernel matrix of the output data 414 | tikhonov_reg (float): Tikhonov (ridge) regularization parameter 415 | rank (int): Rank of the estimator 416 | n_oversamples (int, optional): Number of Oversamples. Defaults to 5. 417 | optimal_sketching (bool, optional): Whether to use optimal sketching (slower but more accurate) or not.. Defaults to False. 418 | iterated_power (int, optional): Number of iterations of the power method. Defaults to 1. 
419 | rng_seed (int | None, optional): Random Number Generators seed. Defaults to None. 420 | precomputed_cholesky (optional): Precomputed Cholesky decomposition. Should be the output of cho_factor evaluated on the regularized kernel matrix.. Defaults to None. 421 | 422 | Shape: 423 | ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data. 424 | 425 | ``kernel_Y``: :math:`(N, N)`. 426 | """ 427 | rng = np.random.default_rng(rng_seed) 428 | npts = kernel_X.shape[0] 429 | 430 | penalty = npts * tikhonov_reg 431 | add_diagonal_(kernel_X, penalty) 432 | if precomputed_cholesky is None: 433 | cholesky_decomposition = scipy.linalg.cho_factor(kernel_X) 434 | else: 435 | cholesky_decomposition = precomputed_cholesky 436 | add_diagonal_(kernel_X, -penalty) 437 | 438 | sketch_dimension = rank + n_oversamples 439 | 440 | if optimal_sketching: 441 | cov = kernel_Y / npts 442 | sketch = rng.multivariate_normal( 443 | np.zeros(npts, dtype=kernel_Y.dtype), cov, size=sketch_dimension 444 | ).T 445 | else: 446 | sketch = rng.standard_normal(size=(npts, sketch_dimension)) 447 | 448 | for _ in range(iterated_power): 449 | # Powered randomized rangefinder 450 | sketch = (kernel_Y / npts) @ ( 451 | sketch - penalty * scipy.linalg.cho_solve(cholesky_decomposition, sketch) 452 | ) 453 | sketch, _ = scipy.linalg.qr(sketch, mode="economic") # QR re-orthogonalization 454 | 455 | kernel_X_sketch = scipy.linalg.cho_solve(cholesky_decomposition, sketch) 456 | _M = sketch - penalty * kernel_X_sketch 457 | 458 | F_0 = sketch.T @ sketch - penalty * (sketch.T @ kernel_X_sketch) # Symmetric 459 | F_0 = 0.5 * (F_0 + F_0.T) 460 | F_1 = _M.T @ ((kernel_Y @ _M) / npts) 461 | 462 | values, vectors = scipy.linalg.eig(scipy.linalg.lstsq(F_0, F_1)[0]) 463 | values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False) 464 | vectors = vectors[:, stable_values_idxs] 465 | 466 | # Remove elements in the kernel of F_0 467 | relative_norm_sq = np.abs( 468 | 
np.sum(vectors.conj() * (F_0 @ vectors), axis=0) / np.linalg.norm(vectors, axis=0) ** 2 469 | ) 470 | norm_rcond = 1000.0 * np.finfo(values.dtype).eps 471 | values, stable_values_idxs = stable_topk(relative_norm_sq, rank, rcond=norm_rcond) 472 | vectors = vectors[:, stable_values_idxs] 473 | 474 | vecs_norms = (np.sum(vectors.conj() * (F_0 @ vectors), axis=0).real) ** 0.5 475 | vectors = vectors / vecs_norms 476 | 477 | U = sqrt(npts) * kernel_X_sketch @ vectors 478 | V = sqrt(npts) * _M @ vectors 479 | svals = np.sqrt(values) 480 | result: FitResult = {"U": U, "V": V, "svals": svals} 481 | return result 482 | -------------------------------------------------------------------------------- /linear_operator_learning/kernel/structs.py: -------------------------------------------------------------------------------- 1 | """Structs used by the `kernel` algorithms.""" 2 | 3 | from typing import TypedDict 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | 9 | class FitResult(TypedDict): 10 | """Return type for kernel regressors.""" 11 | 12 | U: ndarray 13 | V: ndarray 14 | svals: ndarray | None 15 | 16 | 17 | class EigResult(TypedDict): 18 | """Return type for eigenvalue decompositions of kernel regressors.""" 19 | 20 | values: ndarray 21 | left: ndarray | None 22 | right: ndarray 23 | -------------------------------------------------------------------------------- /linear_operator_learning/kernel/utils.py: -------------------------------------------------------------------------------- 1 | """Generic Utilities.""" 2 | 3 | from math import sqrt 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | from scipy.spatial.distance import pdist 8 | 9 | 10 | def topk(vec: ndarray, k: int): 11 | """Get the top k values from a Numpy array. 
12 | 13 | Args: 14 | vec (ndarray): A 1D numpy array 15 | k (int): Number of elements to keep 16 | 17 | Returns: 18 | values, indices: top k values and their indices 19 | """ 20 | assert np.ndim(vec) == 1, "'vec' must be a 1D array" 21 | assert k > 0, "k should be greater than 0" 22 | sort_perm = np.flip(np.argsort(vec)) # descending order 23 | indices = sort_perm[:k] 24 | values = vec[indices] 25 | return values, indices 26 | 27 | 28 | def sanitize_complex_conjugates(vec: ndarray, tol: float = 10.0): 29 | """This function acts on 1D complex vectors. If the real parts of two elements are close, sets them equal. Furthermore, sets to 0 the imaginary parts smaller than `tol` times the machine precision. 30 | 31 | Args: 32 | vec (ndarray): A 1D vector to sanitize. 33 | tol (float, optional): Tolerance for comparisons. Defaults to 10.0. 34 | 35 | """ 36 | assert issubclass(vec.dtype.type, np.complexfloating), "The input element should be complex" 37 | assert vec.ndim == 1 38 | rcond = tol * np.finfo(vec.dtype).eps 39 | pdist_real_part = pdist(vec.real[:, None]) 40 | # Set the same element whenever pdist is smaller than eps*tol 41 | condensed_idxs = np.argwhere(pdist_real_part < rcond)[:, 0] 42 | fuzzy_real = vec.real.copy() 43 | if condensed_idxs.shape[0] >= 1: 44 | for idx in condensed_idxs: 45 | i, j = _row_col_from_condensed_index(vec.real.shape[0], idx) 46 | avg = 0.5 * (fuzzy_real[i] + fuzzy_real[j]) 47 | fuzzy_real[i] = avg 48 | fuzzy_real[j] = avg 49 | fuzzy_imag = vec.imag.copy() 50 | fuzzy_imag[np.abs(fuzzy_imag) < rcond] = 0.0 51 | return fuzzy_real + 1j * fuzzy_imag 52 | 53 | 54 | def _row_col_from_condensed_index(d, index): 55 | # Credits to: https://stackoverflow.com/a/14839010 56 | b = 1 - (2 * d) 57 | i = (-b - sqrt(b**2 - 8 * index)) // 2 58 | j = index + i * (b + i + 2) // 2 + 1 59 | return (int(i), int(j)) 60 | -------------------------------------------------------------------------------- /linear_operator_learning/nn/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Neural network methods entry point.""" 2 | 3 | import linear_operator_learning.nn.functional as functional 4 | import linear_operator_learning.nn.stats as stats 5 | from linear_operator_learning.nn.modules import * # noqa: F403 6 | from linear_operator_learning.nn.regressors import eig, evaluate_eigenfunction, ridge_least_squares 7 | -------------------------------------------------------------------------------- /linear_operator_learning/nn/functional.py: -------------------------------------------------------------------------------- 1 | """Functional interface.""" 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from linear_operator_learning.nn.linalg import sqrtmh 7 | from linear_operator_learning.nn.stats import cov_norm_squared_unbiased, covariance 8 | 9 | # Losses_____________________________________________________________________________________________ 10 | 11 | 12 | def vamp_loss( 13 | x: Tensor, y: Tensor, schatten_norm: int = 2, center_covariances: bool = True 14 | ) -> Tensor: 15 | """See :class:`linear_operator_learning.nn.VampLoss` for details.""" 16 | cov_x, cov_y, cov_xy = ( 17 | covariance(x, center=center_covariances), 18 | covariance(y, center=center_covariances), 19 | covariance(x, y, center=center_covariances), 20 | ) 21 | if schatten_norm == 2: 22 | # Using least squares in place of pinv for numerical stability 23 | M_x = torch.linalg.lstsq(cov_x, cov_xy).solution 24 | M_y = torch.linalg.lstsq(cov_y, cov_xy.T).solution 25 | return -torch.trace(M_x @ M_y) 26 | elif schatten_norm == 1: 27 | sqrt_cov_x = sqrtmh(cov_x) 28 | sqrt_cov_y = sqrtmh(cov_y) 29 | M = torch.linalg.multi_dot( 30 | [ 31 | torch.linalg.pinv(sqrt_cov_x, hermitian=True), 32 | cov_xy, 33 | torch.linalg.pinv(sqrt_cov_y, hermitian=True), 34 | ] 35 | ) 36 | return -torch.linalg.matrix_norm(M, "nuc") 37 | else: 38 | raise NotImplementedError(f"Schatten norm {schatten_norm} not 
implemented") 39 | 40 | 41 | def dp_loss( 42 | x: Tensor, 43 | y: Tensor, 44 | relaxed: bool = True, 45 | center_covariances: bool = True, 46 | ) -> Tensor: 47 | """See :class:`linear_operator_learning.nn.DPLoss` for details.""" 48 | cov_x, cov_y, cov_xy = ( 49 | covariance(x, center=center_covariances), 50 | covariance(y, center=center_covariances), 51 | covariance(x, y, center=center_covariances), 52 | ) 53 | if relaxed: 54 | S = (torch.linalg.matrix_norm(cov_xy, ord="fro") ** 2) / ( 55 | torch.linalg.matrix_norm(cov_x, ord=2) * torch.linalg.matrix_norm(cov_y, ord=2) 56 | ) 57 | else: 58 | M_x = torch.linalg.lstsq(cov_x, cov_xy).solution 59 | M_y = torch.linalg.lstsq(cov_y, cov_xy.T).solution 60 | S = torch.trace(M_x @ M_y) 61 | return -S 62 | 63 | 64 | def l2_contrastive_loss(x: Tensor, y: Tensor) -> Tensor: 65 | """See :class:`linear_operator_learning.nn.L2ContrastiveLoss` for details.""" 66 | assert x.shape == y.shape 67 | assert x.ndim == 2 68 | 69 | npts, dim = x.shape 70 | diag = 2 * torch.mean(x * y) * dim 71 | square_term = torch.matmul(x, y.T) ** 2 72 | off_diag = ( 73 | torch.mean(torch.triu(square_term, diagonal=1) + torch.tril(square_term, diagonal=-1)) 74 | * npts 75 | / (npts - 1) 76 | ) 77 | return off_diag - diag 78 | 79 | 80 | def kl_contrastive_loss(X: Tensor, Y: Tensor) -> Tensor: 81 | """See :class:`linear_operator_learning.nn.KLContrastiveLoss` for details.""" 82 | assert X.shape == Y.shape 83 | assert X.ndim == 2 84 | 85 | npts, dim = X.shape 86 | log_term = torch.mean(torch.log(X * Y)) * dim 87 | linear_term = torch.matmul(X, Y.T) 88 | off_diag = ( 89 | torch.mean(torch.triu(linear_term, diagonal=1) + torch.tril(linear_term, diagonal=-1)) 90 | * npts 91 | / (npts - 1) 92 | ) 93 | return off_diag - log_term 94 | 95 | 96 | # Regularizers______________________________________________________________________________________ 97 | 98 | 99 | def orthonormal_fro_reg(x: Tensor) -> Tensor: 100 | r"""Orthonormality regularization with Frobenious norm 
of covariance of `x`. 101 | 102 | Given a batch of realizations of `x`, the orthonormality regularization term penalizes: 103 | 104 | 1. Orthogonality: Linear dependencies among dimensions, 105 | 2. Normality: Deviations of each dimension’s variance from 1, 106 | 3. Centering: Deviations of each dimension’s mean from 0. 107 | 108 | .. math:: 109 | 110 | \frac{1}{D} \| \mathbf{C}_{X} - I \|_F^2 + 2 \| \mathbb{E}_{X} x \|^2 = \frac{1}{D} (\text{tr}(\mathbf{C}^2_{X}) - 2 \text{tr}(\mathbf{C}_{X}) + D + 2 \| \mathbb{E}_{X} x \|^2) 111 | 112 | Args: 113 | x (Tensor): Input features. 114 | 115 | Shape: 116 | ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 117 | """ 118 | x_mean = x.mean(dim=0, keepdim=True) 119 | x_centered = x - x_mean 120 | # As ||Cx||_F^2 = E_(x,x')~p(x) [((x - E_p(x) x)^T (x' - E_p(x) x'))^2] = tr(Cx^2), involves the product of 121 | # covariances, unbiased estimation of this term requires the use of U-statistics 122 | Cx_fro_2 = cov_norm_squared_unbiased(x_centered) 123 | # tr(Cx) = E_p(x) [(x - E_p(x))^T (x - E_p(x))] ≈ 1/N Σ_n (x_n - E_p(x))^T (x_n - E_p(x)) 124 | tr_Cx = torch.einsum("ij,ij->", x_centered, x_centered) / x.shape[0] 125 | centering_loss = (x_mean**2).sum() # ||E_p(x) x||^2 126 | D = x.shape[-1] # ||I||_F^2 = D 127 | reg = Cx_fro_2 - 2 * tr_Cx + D + 2 * centering_loss 128 | return reg / D 129 | 130 | 131 | def orthonormal_logfro_reg(x: Tensor) -> Tensor: 132 | r"""Orthonormality regularization with log-Frobenious norm of covariance of x by :footcite:t:`Kostic2023DPNets`. 133 | 134 | .. math:: 135 | 136 | \frac{1}{D}\text{Tr}(C_X^{2} - C_X -\ln(C_X)). 137 | 138 | Args: 139 | x (Tensor): Input features. 140 | 141 | Shape: 142 | ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 
143 | """ 144 | cov = covariance(x) # shape: (D, D) 145 | eps = torch.finfo(cov.dtype).eps * cov.shape[0] 146 | vals_x = torch.linalg.eigvalsh(cov) 147 | vals_x = torch.where(vals_x > eps, vals_x, eps) 148 | orth_loss = torch.mean(-torch.log(vals_x) + vals_x * (vals_x - 1.0)) 149 | # TODO: Centering like this? 150 | centering_loss = (x.mean(0, keepdim=True) ** 2).sum() # ||E_p(x) x||^2 151 | reg = orth_loss + 2 * centering_loss 152 | return reg 153 | -------------------------------------------------------------------------------- /linear_operator_learning/nn/linalg.py: -------------------------------------------------------------------------------- 1 | """Linear Algebra.""" 2 | 3 | from typing import NamedTuple 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | 9 | def sqrtmh(A: Tensor) -> Tensor: 10 | """Compute the square root of a Symmetric or Hermitian positive definite matrix or batch of matrices. 11 | 12 | Used code from `this issue `_. 13 | 14 | Args: 15 | A (Tensor): Symmetric or Hermitian positive definite matrix or batch of matrices. 16 | 17 | Shape: 18 | ``A``: :math:`(N, N)` 19 | 20 | Output: :math:`(N, N)` 21 | """ 22 | L, Q = torch.linalg.eigh(A) 23 | zero = torch.zeros((), device=L.device, dtype=L.dtype) 24 | threshold = L.max(-1).values * L.size(-1) * torch.finfo(L.dtype).eps 25 | L = L.where(L > threshold.unsqueeze(-1), zero) # zero out small components 26 | return (Q * L.sqrt().unsqueeze(-2)) @ Q.mH 27 | 28 | 29 | #################################################################################################### 30 | # TODO: THIS IS JUST COPY AND PASTE FROM OLD NCP 31 | # Should topk and filter_reduced_rank_svals be in utils? 
# Sorting and parsing
class TopKReturnType(NamedTuple):
    """Result of :func:`topk`: the selected entries and their positions in the input."""

    values: Tensor
    indices: Tensor


def topk(vec: Tensor, k: int):
    """Return the ``k`` largest entries of a 1D tensor, in descending order.

    Args:
        vec (Tensor): One-dimensional tensor to select from.
        k (int): Number of entries to keep. Must be positive.

    Returns:
        TopKReturnType: ``values`` holds the selected entries and ``indices`` their positions in ``vec``.
    """
    assert vec.ndim == 1, "'vec' must be a 1D array"
    assert k > 0, "k should be greater than 0"
    sort_perm = torch.flip(torch.argsort(vec), dims=[0])  # descending order
    indices = sort_perm[:k]
    return TopKReturnType(vec[indices], indices)


def filter_reduced_rank_svals(values, vectors):
    """Keep the strictly positive, purely real part of a spectrum, sorted in descending order.

    An entry of ``values`` is discarded when its real part is not larger than a small threshold
    or, for complex input, when its imaginary part is non-zero. The matching columns of
    ``vectors`` are discarded with it. The survivors are returned as real tensors.

    Args:
        values (Tensor): 1D tensor of (possibly complex) eigen/singular values.
        vectors (Tensor): 2D tensor whose ``i``-th column belongs to ``values[i]``.

    Returns:
        tuple[Tensor, Tensor]: The retained values (descending) and the matching columns of
        ``vectors``. Both are empty along the filtered axis when nothing survives.

    Raises:
        AssertionError: If a retained column of ``vectors`` has a non-zero imaginary part.
    """
    # The threshold follows the default dtype (not ``values.dtype``), as it always has.
    eps = 2 * torch.finfo(torch.get_default_dtype()).eps
    is_invalid = torch.real(values) <= eps
    if torch.is_complex(values):
        is_invalid = torch.logical_or(is_invalid, torch.imag(values) != 0)
    is_valid = ~is_invalid

    # Every survivor has a zero imaginary part, so dropping to the real dtype is lossless.
    # Doing it unconditionally keeps the sort below from receiving complex input.
    values = torch.real(values[is_valid])
    vectors = vectors[:, is_valid]

    # ``topk`` rejects k == 0, so an empty spectrum is returned as-is.
    if values.numel() > 0:
        sort_perm = topk(values, len(values)).indices
        values = values[sort_perm]
        vectors = vectors[:, sort_perm]

    if torch.is_complex(vectors):
        # Check the vectors themselves: ``values`` is already real at this point.
        assert torch.all(torch.imag(vectors) == 0), (
            "The eigenvectors should be real. Decrease the rank or increase the regularization strength."
        )
        vectors = torch.real(vectors)
    return values, vectors
@torch.no_grad()
def forward(self, X: Tensor, Y: Tensor):
    """Update the exponential moving average of the means and covariance matrices.

    Does nothing unless the module is in training mode. The first training batch seeds the
    running statistics through ``_first_forward``; every later batch is blended in through
    ``_inplace_EMA``.

    Args:
        X: Input tensor.
        Y: Output tensor.

    Shape:
        ``X``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

        ``Y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
    """
    if not self.training:
        return
    assert X.ndim == 2
    assert X.shape == Y.shape
    assert X.shape[1] == self.mean_X.shape[0]
    if not self.is_initialized.item():
        self._first_forward(X, Y)
        return

    # Means are refreshed first: when centering is on, the batch is centered with the
    # already-updated running means, not with the raw batch means.
    self._inplace_EMA(X.mean(dim=0), self.mean_X)
    self._inplace_EMA(Y.mean(dim=0), self.mean_Y)

    if self.is_centered:
        X = X - self.mean_X
        Y = Y - self.mean_Y

    n_samples = X.shape[0]
    self._inplace_EMA(torch.mm(X.T, X) / n_samples, self.cov_X)
    self._inplace_EMA(torch.mm(Y.T, Y) / n_samples, self.cov_Y)
    self._inplace_EMA(torch.mm(X.T, Y) / n_samples, self.cov_XY)


def _first_forward(self, X: torch.Tensor, Y: torch.Tensor):
    """Seed the running means and covariances with the statistics of the first batch."""
    self._inplace_set(X.mean(dim=0), self.mean_X)
    self._inplace_set(Y.mean(dim=0), self.mean_Y)
    if self.is_centered:
        X = X - self.mean_X
        Y = Y - self.mean_Y

    n_samples = X.shape[0]
    self._inplace_set(torch.mm(X.T, X) / n_samples, self.cov_X)
    self._inplace_set(torch.mm(Y.T, Y) / n_samples, self.cov_Y)
    self._inplace_set(torch.mm(X.T, Y) / n_samples, self.cov_XY)
    # Flip the flag in place. Assigning a new tensor would rebind the registered buffer to a
    # CPU tensor and leave it on a different device than the rest of the module.
    self.is_initialized.fill_(True)
class _RegularizedLoss(Module):
    """Shared base for losses that add an orthonormality penalty on the features.

    Stores the penalty weight and resolves the regularizer name to the matching function of
    ``linear_operator_learning.nn.functional``.

    Args:
        gamma (float): Regularization strength.
        regularizer (literal): Name of the regularizer, either ``"orthn_fro"``
            (``orthonormal_fro_reg``) or ``"orthn_logfro"`` (``orthonormal_logfro_reg``).

    Raises:
        NotImplementedError: If ``regularizer`` is neither of the two supported names.
    """

    def __init__(
        self, gamma: float, regularizer: Literal["orthn_fro", "orthn_logfro"]
    ) -> None:  # TODO: Automatically determine 'gamma' from dim_x and dim_y
        super().__init__()
        self.gamma = gamma

        # Guard clause first, then a two-way pick between the supported regularizers.
        if regularizer != "orthn_fro" and regularizer != "orthn_logfro":
            raise NotImplementedError(f"Regularizer {regularizer} not supported!")
        self.regularizer = (
            F.orthonormal_fro_reg if regularizer == "orthn_fro" else F.orthonormal_logfro_reg
        )
class L2ContrastiveLoss(_RegularizedLoss):
    r"""NCP/Contrastive/Mutual Information Loss based on the :math:`L^{2}` error by :footcite:t:`Kostic2024NCP`.

    .. math::

        \mathcal{L}(x, y) = \frac{1}{N(N-1)}\sum_{i \neq j}\langle x_{i}, y_{j} \rangle^2 - \frac{2}{N}\sum_{i=1}\langle x_{i}, y_{i} \rangle.

    The value returned by :meth:`forward` additionally contains ``gamma`` times the sum of the
    regularizer evaluated on ``x`` and on ``y``.

    Args:
        gamma (float, optional): Regularization strength. Defaults to 1e-3.
        regularizer (literal, optional): Regularizer, either ``"orthn_fro"`` or ``"orthn_logfro"``.
            Defaults to ``"orthn_fro"``.
    """

    def __init__(
        self,
        gamma: float = 1e-3,
        regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro",
    ) -> None:
        super().__init__(gamma, regularizer)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Evaluate the regularized L2 contrastive loss on a batch of feature pairs.

        Args:
            x (Tensor): Input features.
            y (Tensor): Output features.

        Shape:
            ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        contrastive_term = F.l2_contrastive_loss(x, y)
        penalty = self.regularizer(x) + self.regularizer(y)
        return contrastive_term + self.gamma * penalty
math:: 126 | 127 | \mathcal{L}(x, y) = -\frac{\|x^{\top}y\|^{2}_{{\rm F}}}{\|x^{\top}x\|^{2}\|y^{\top}y\|^{2}}. 128 | 129 | Args: 130 | relaxed (bool, optional): Whether to use the relaxed (more numerically stable) or the full deep-projection loss. Defaults to True. 131 | center_covariances (bool, optional): Use centered covariances to compute the Deep Projection loss. Defaults to True. 132 | gamma (float, optional): Regularization strength. Defaults to 1e-3. 133 | regularizer (literal, optional): Regularizer. Either :func:`orthn_fro ` or :func:`orthn_logfro `. Defaults to :func:`orthn_fro `. 134 | """ 135 | 136 | def __init__( 137 | self, 138 | relaxed: bool = True, 139 | center_covariances: bool = True, 140 | gamma: float = 1e-3, 141 | regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro", 142 | ) -> None: 143 | super().__init__(gamma, regularizer) 144 | self.relaxed = relaxed 145 | self.center_covariances = center_covariances 146 | 147 | def forward(self, x: Tensor, y: Tensor) -> Tensor: 148 | """Forward pass of DPLoss. 149 | 150 | Args: 151 | x (Tensor): Features for x. 152 | y (Tensor): Features for y. 153 | 154 | Shape: 155 | ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 156 | 157 | ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 158 | """ 159 | return F.dp_loss( 160 | x, 161 | y, 162 | self.relaxed, 163 | self.center_covariances, 164 | ) + self.gamma * (self.regularizer(x) + self.regularizer(y)) 165 | 166 | 167 | class KLContrastiveLoss(_RegularizedLoss): 168 | r"""NCP/Contrastive/Mutual Information Loss based on the KL divergence. 169 | 170 | .. math:: 171 | 172 | \mathcal{L}(x, y) = \frac{1}{N(N-1)}\sum_{i \neq j}\langle x_{i}, y_{j} \rangle - \frac{2}{N}\sum_{i=1}\log\big(\langle x_{i}, y_{i} \rangle\big). 173 | 174 | Args: 175 | gamma (float, optional): Regularization strength. Defaults to 1e-3. 
class _MLPBlock(Module):
    """Hidden layer of :class:`MLP`: a linear map followed by dropout and then the activation.

    Args:
        input_size (int): Number of input features.
        output_size (int): Number of output features.
        dropout (float): Dropout probability. Defaults to 0.0.
        activation: Zero-argument callable (e.g. a ``torch.nn.Module`` subclass) that builds the
            activation layer. Defaults to ``ReLU``.
        bias (bool): Whether the linear map has a bias term. Defaults to True.
    """

    def __init__(self, input_size, output_size, dropout=0.0, activation=ReLU, bias=True):
        super().__init__()
        self.linear = Linear(input_size, output_size, bias=bias)
        self.dropout = Dropout(dropout)
        self.activation = activation()

    def forward(self, x):
        """Apply linear map, dropout and activation, in that order."""
        return self.activation(self.dropout(self.linear(x)))


class MLP(Module):
    """Multi Layer Perceptron.

    Args:
        input_shape (int): Input shape of the MLP.
        n_hidden (int): Number of hidden layers.
        layer_size (int or list of ints): Number of neurons in each layer. If an int is
            provided, it is used as the number of neurons for all hidden layers. Otherwise,
            the first ``n_hidden`` entries give the number of neurons of each hidden layer.
        output_shape (int): Output shape of the MLP.
        dropout (float): Dropout probability between layers. Defaults to 0.0.
        activation: Zero-argument callable (e.g. a ``torch.nn.Module`` subclass) that builds the
            activation layer. Defaults to ``ReLU``.
        iterative_whitening (bool): Whether to add an IterNorm layer at the end of the
            network. Not implemented; must stay False. Defaults to False.
        bias (bool): Whether to include bias in the hidden layers. The output layer never has
            a bias. Defaults to False.

    Raises:
        ValueError: If ``layer_size`` is a sequence with fewer than ``n_hidden`` entries.
        NotImplementedError: If ``iterative_whitening`` is True.
    """

    def __init__(
        self,
        input_shape,
        n_hidden,
        layer_size,
        output_shape,
        dropout=0.0,
        activation=ReLU,
        iterative_whitening=False,
        bias=False,
    ):
        super().__init__()
        if isinstance(layer_size, int):
            layer_size = [layer_size] * n_hidden
        if len(layer_size) < n_hidden:
            raise ValueError(
                f"layer_size has {len(layer_size)} entries but n_hidden={n_hidden} were requested"
            )

        # widths[i] -> widths[i + 1] is the i-th hidden layer. Only the widths actually used
        # are kept, so the output layer is always wired to the last hidden layer that exists.
        widths = [input_shape, *layer_size[:n_hidden]]
        layers = [
            _MLPBlock(fan_in, fan_out, dropout, activation, bias=bias)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        layers.append(Linear(widths[-1], output_shape, bias=False))
        if iterative_whitening:
            # layers.append(IterNorm(output_shape))
            raise NotImplementedError("IterNorm isn't implemented")
        self.model = Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the hidden layers and the final linear layer."""
        return self.model(x)
def conv3x3(
    in_planes: int,
    out_planes: int,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
    padding_mode: str = "zeros",
) -> nn.Conv2d:
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
        padding_mode=padding_mode,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        expanded = planes * self.expansion
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, padding_mode=padding_mode)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, expanded)
        self.bn3 = make_norm(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Main branch: 1x1 reduce -> 3x3 -> 1x1 expand; the last norm is not followed by ReLU.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
3-element tuple, got {replace_stride_with_dilation}" 192 | ) 193 | self.groups = groups 194 | self.base_width = width_per_group 195 | self.conv1 = nn.Conv2d( 196 | channels_in, 197 | self.inplanes, 198 | kernel_size=7, 199 | stride=2, 200 | padding=3, 201 | bias=False, 202 | padding_mode=padding_mode, 203 | ) 204 | self.bn1 = norm_layer(self.inplanes) 205 | self.relu = nn.ReLU(inplace=True) 206 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 207 | self.layer1 = self._make_layer(block, 64, layers[0], padding_mode=padding_mode) 208 | self.layer2 = self._make_layer( 209 | block, 210 | 128, 211 | layers[1], 212 | stride=2, 213 | dilate=replace_stride_with_dilation[0], 214 | padding_mode=padding_mode, 215 | ) 216 | self.layer3 = self._make_layer( 217 | block, 218 | 256, 219 | layers[2], 220 | stride=2, 221 | dilate=replace_stride_with_dilation[1], 222 | padding_mode=padding_mode, 223 | ) 224 | self.layer4 = self._make_layer( 225 | block, 226 | 512, 227 | layers[3], 228 | stride=2, 229 | dilate=replace_stride_with_dilation[2], 230 | padding_mode=padding_mode, 231 | ) 232 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 233 | self.fc = nn.Linear(512 * block.expansion, num_features, bias=False) 234 | 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 238 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 239 | nn.init.constant_(m.weight, 1) 240 | nn.init.constant_(m.bias, 0) 241 | 242 | # Zero-initialize the last BN in each residual branch, 243 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
def _make_layer(
    self,
    block: Type[Union[BasicBlock, Bottleneck]],
    planes: int,
    blocks: int,
    stride: int = 1,
    dilate: bool = False,
    padding_mode: str = "zeros",
) -> nn.Sequential:
    """Assemble one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

    Only the first block receives ``stride`` and the optional projection shortcut. When
    ``dilate`` is set, the stride is folded into ``self.dilation`` and the stage runs at stride 1.
    """
    norm_layer = self._norm_layer
    stage_width = planes * block.expansion
    entry_dilation = self.dilation
    if dilate:
        self.dilation *= stride
        stride = 1

    # A projection shortcut is needed whenever the first block changes resolution or width.
    needs_projection = stride != 1 or self.inplanes != stage_width
    downsample = (
        nn.Sequential(
            conv1x1(self.inplanes, stage_width, stride),
            norm_layer(stage_width),
        )
        if needs_projection
        else None
    )

    stage = [
        block(
            self.inplanes,
            planes,
            stride,
            downsample,
            self.groups,
            self.base_width,
            entry_dilation,
            padding_mode,
            norm_layer,
        )
    ]
    self.inplanes = stage_width
    stage.extend(
        block(
            self.inplanes,
            planes,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm_layer=norm_layer,
            padding_mode=padding_mode,
        )
        for _ in range(1, blocks)
    )

    return nn.Sequential(*stage)
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    *args: Any,
    **kwargs: Any,
) -> ResNet:
    """Build a :class:`ResNet` from a block type and per-stage block counts.

    Extra positional and keyword arguments are passed to :class:`ResNet` after ``block`` and ``layers``.
    """
    return ResNet(block, layers, *args, **kwargs)


# The factories below used to accept positional arguments and silently drop them.
# They are now forwarded to ``ResNet`` after ``block`` and ``layers``.


def resnet18(*args: Any, **kwargs: Any) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."""
    return _resnet(BasicBlock, [2, 2, 2, 2], *args, **kwargs)


def resnet34(*args: Any, **kwargs: Any) -> ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."""
    return _resnet(BasicBlock, [3, 4, 6, 3], *args, **kwargs)


def resnet50(*args: Any, **kwargs: Any) -> ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."""
    return _resnet(Bottleneck, [3, 4, 6, 3], *args, **kwargs)


def resnet101(*args: Any, **kwargs: Any) -> ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."""
    return _resnet(Bottleneck, [3, 4, 23, 3], *args, **kwargs)


def resnet152(*args: Any, **kwargs: Any) -> ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."""
    return _resnet(Bottleneck, [3, 8, 36, 3], *args, **kwargs)
def forward(self, x: Tensor) -> Tensor:
    """Apply a softmax independently to each group of ``self.dim`` consecutive features.

    The last axis of ``x`` is split into groups of size ``self.dim``, each group is normalized
    with a softmax, and the result is returned with the shape of the input.
    """
    original_shape = x.shape
    grouped = x.view(*original_shape[:-1], -1, self.dim)
    return F.softmax(grouped, dim=-1).view(*original_shape)
22 | 23 | """ 24 | dim = cov_X.shape[0] 25 | reg_input_covariance = cov_X + tikhonov_reg * torch.eye( 26 | dim, dtype=cov_X.dtype, device=cov_X.device 27 | ) 28 | values, vectors = torch.linalg.eigh(reg_input_covariance) 29 | # Divide columns of vectors by square root of eigenvalues 30 | rsqrt_evals = 1.0 / torch.sqrt(values + 1e-10) 31 | Q = vectors @ torch.diag(rsqrt_evals) 32 | result: FitResult = FitResult({"U": Q, "V": Q, "svals": values}) 33 | return result 34 | 35 | 36 | def eig( 37 | fit_result: FitResult, 38 | cov_XY: Tensor, 39 | ) -> EigResult: 40 | """Computes the eigendecomposition of a regressor. 41 | 42 | Args: 43 | fit_result (FitResult): Fit result as defined in ``linear_operator_learning.nn.structs``. 44 | cov_XY (Tensor): Cross covariance matrix between the input and output data. 45 | 46 | 47 | Shape: 48 | ``cov_XY``: :math:`(D, D)`, where :math:`D` is the number of features. 49 | 50 | Output: ``U, V`` of shape :math:`(D, R)`, ``svals`` of shape :math:`R` 51 | where :math:`D` is the number of features and :math:`R` is the rank of the regressor. 
52 | """ 53 | dtype_and_device = { 54 | "dtype": cov_XY.dtype, 55 | "device": cov_XY.device, 56 | } 57 | U = fit_result["U"] 58 | # Using the trick described in https://arxiv.org/abs/1905.11490 59 | M = torch.linalg.multi_dot([U.T, cov_XY, U]) 60 | # Conversion to numpy 61 | M = M.numpy(force=True) 62 | values, lv, rv = scipy.linalg.eig(M, left=True, right=True) 63 | r_perm = torch.tensor(np.argsort(values), device=cov_XY.device) 64 | l_perm = torch.tensor(np.argsort(values.conj()), device=cov_XY.device) 65 | values = values[r_perm] 66 | # Back to torch, casting to appropriate dtype and device 67 | values = torch.complex( 68 | torch.tensor(values.real, **dtype_and_device), torch.tensor(values.imag, **dtype_and_device) 69 | ) 70 | lv = torch.complex( 71 | torch.tensor(lv.real, **dtype_and_device), torch.tensor(lv.imag, **dtype_and_device) 72 | ) 73 | rv = torch.complex( 74 | torch.tensor(rv.real, **dtype_and_device), torch.tensor(rv.imag, **dtype_and_device) 75 | ) 76 | # Normalization in RKHS norm 77 | rv = U.to(rv.dtype) @ rv 78 | rv = rv[:, r_perm] 79 | rv = rv / torch.linalg.norm(rv, axis=0) 80 | # Biorthogonalization 81 | lv = torch.linalg.multi_dot([cov_XY.T.to(lv.dtype), U.to(lv.dtype), lv]) 82 | lv = lv[:, l_perm] 83 | l_norm = torch.sum(lv * rv, axis=0) 84 | lv = lv / l_norm 85 | result: EigResult = EigResult({"values": values, "left": lv, "right": rv}) 86 | return result 87 | 88 | 89 | def evaluate_eigenfunction( 90 | eig_result: EigResult, 91 | which: Literal["left", "right"], 92 | X: Tensor, 93 | ): 94 | """Evaluates left or right eigenfunctions of a regressor. 95 | 96 | Args: 97 | eig_result: EigResult object containing eigendecomposition results 98 | which: String indicating "left" or "right" eigenfunctions. 99 | X: Feature map of the input data 100 | 101 | 102 | Shape: 103 | ``eig_result``: ``U, V`` of shape :math:`(D, R)`, ``svals`` of shape :math:`R` 104 | where :math:`D` is the number of features and :math:`R` is the rank of the regressor. 
105 | 106 | ``X``: :math:`(N_0, D)`, where :math:`N_0` is the number of inputs to predict and :math:`D` is the number of features. 107 | 108 | Output: :math:`(N_0, R)` 109 | """ 110 | vr_or_vl = eig_result[which] 111 | return X.to(vr_or_vl.dtype) @ vr_or_vl 112 | -------------------------------------------------------------------------------- /linear_operator_learning/nn/stats.py: -------------------------------------------------------------------------------- 1 | """Statistics utilities for multi-variate random variables.""" 2 | 3 | from math import sqrt 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from linear_operator_learning.nn.linalg import filter_reduced_rank_svals, sqrtmh 9 | 10 | 11 | def covariance( 12 | X: Tensor, 13 | Y: Tensor | None = None, 14 | center: bool = True, 15 | norm: float | None = None, 16 | ) -> Tensor: 17 | """Computes the covariance of X or cross-covariance between X and Y if Y is given. 18 | 19 | Args: 20 | X (Tensor): Input features. 21 | Y (Tensor | None, optional): Output features. Defaults to None. 22 | center (bool, optional): Whether to compute centered covariances. Defaults to True. 23 | norm (float | None, optional): Normalization factor. Defaults to None. 24 | 25 | Shape: 26 | ``X``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 27 | 28 | ``Y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 29 | 30 | Output: :math:`(D, D)`, where :math:`D` is the number of features. 
31 | """ 32 | assert X.ndim == 2 33 | if norm is None: 34 | norm = sqrt(X.shape[0]) 35 | else: 36 | assert norm > 0 37 | norm = sqrt(norm) 38 | if Y is None: 39 | X = X / norm 40 | if center: 41 | X = X - X.mean(dim=0, keepdim=True) 42 | return torch.mm(X.T, X) 43 | else: 44 | assert Y.ndim == 2 45 | X = X / norm 46 | Y = Y / norm 47 | if center: 48 | X = X - X.mean(dim=0, keepdim=True) 49 | Y = Y - Y.mean(dim=0, keepdim=True) 50 | return torch.mm(X.T, Y) 51 | 52 | 53 | def cross_cov_norm_squared_unbiased(x: Tensor, y: Tensor, permutation=None): 54 | r"""Compute the unbiased estimation of :math:`\|\mathbf{C}_{xy}\|_F^2` from a batch of samples, using U-statistics. 55 | 56 | Given the Covariance matrix :math:`\mathbf{C}_{xy} = \mathbb{E}_{p(x,y)} [x^{\top} y]`, this function computes an unbiased estimation 57 | of the squared Frobenius norm of the covariance matrix from two independent sampling sets (an effective sample size of :math:`N^2`). 58 | 59 | .. math:: 60 | 61 | \begin{align} 62 | \|\mathbf{C}_{xy}\|_F^2 &= \text{tr}(\mathbf{C}_{xy}^{\top} \mathbf{C}_{xy}) 63 | = \sum_i \sum_j (\mathbb{E}_{x,y \sim p(x,y)} [x_i y_j]) (\mathbb{E}_{x',y' \sim p(x,y)} [x'_j y'_i]) \\ 64 | &= \mathbb{E}_{(x,y),(x',y') \sim p(x,y)} [(x^{\top} y') (x'^{\top} y)] \\ 65 | &\approx \frac{1}{N^2} \sum_n \sum_m [(x_{n}^{\top} y^{\prime}_m) (x^{\prime \top}_m y_n)] 66 | \end{align} 67 | 68 | .. note:: 69 | The random variable is assumed to be centered. 70 | 71 | Args: 72 | x (Tensor): Centered realizations of a random variable `x` of shape (N, D_x). 73 | y (Tensor): Centered realizations of a random variable `y` of shape (N, D_y). 74 | permutation (Tensor, optional): List of integer indices of shape (n_samples,) used to permute the samples. 75 | 76 | Returns: 77 | Tensor: Unbiased estimation of :math:`\|\mathbf{C}_{xy}\|_F^2` using U-statistics. 
78 | """ 79 | n_samples = x.shape[0] 80 | 81 | # Permute the rows independently to simulate independent sampling 82 | perm = permutation if permutation is not None else torch.randperm(n_samples) 83 | assert perm.shape == (n_samples,), f"Invalid permutation {perm.shape}!=({n_samples},)" 84 | xp = x[perm] # Independent sampling of x' 85 | yp = y[perm] # Independent sampling of y' 86 | 87 | # Compute 1/N^2 Σ_n Σ_m [(x_n.T y'_m) (x'_m.T y_n)] 88 | val = torch.einsum("nj,mj,mk,nk->", x, yp, xp, y) 89 | cov_fro_norm = val / (n_samples**2) 90 | return cov_fro_norm 91 | 92 | 93 | def cov_norm_squared_unbiased(x: Tensor, permutation=None): 94 | r"""Compute the unbiased estimation of :math:`\|\mathbf{C}_x\|_F^2` from a batch of samples. 95 | 96 | Given the Covariance matrix :math:`\mathbf{C}_x = \mathbb{E}_{p(x)} [x^{\top} x]`, this function computes an unbiased estimation 97 | of the squared Frobenius norm of the covariance matrix from a single sampling set. 98 | 99 | .. math:: 100 | 101 | \begin{align} 102 | \|\mathbf{C}_x\|_F^2 &= \text{tr}(\mathbf{C}_x^{\top} \mathbf{C}_x) = \sum_i \sum_j (\mathbb{E}_{x} [x_i x_j]) (\mathbb{E}_{x'} [x'_j x'_i]) \\ 103 | &= \mathbb{E}_{x,x' \sim p(x)} [(x^{\top} x')^2] \\ 104 | &\approx \frac{1}{N^2} \sum_n \sum_m [(x_n^{\top} x'_m)^2] 105 | \end{align} 106 | 107 | 108 | .. note:: 109 | 110 | The random variable is assumed to be centered. 111 | 112 | Args: 113 | x (Tensor): (n_samples, r) Centered realizations of a random variable x = [x_1, ..., x_r]. 114 | permutation (Tensor, optional): List of integer indices of shape (n_samples,) used to permute the samples. 115 | 116 | Returns: 117 | Tensor: Unbiased estimation of :math:`\|\mathbf{C}_x\|_F^2` using U-statistics. 118 | """ 119 | return cross_cov_norm_squared_unbiased(x=x, y=x, permutation=permutation) 120 | 121 | 122 | def whitening(u: Tensor, v: Tensor) -> tuple: 123 | """Computes whitening matrices for ``u`` and ``v``. 124 | 125 | Args: 126 | u (Tensor): Input features. 
127 | v (Tensor): Output features. 128 | 129 | 130 | Shape: 131 | ``u``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 132 | 133 | ``v``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. 134 | 135 | ``sqrt_cov_u_inv``: :math:`(D, D)` 136 | 137 | ``sqrt_cov_v_inv``: :math:`(D, D)` 138 | 139 | ``sing_val``: :math:`(D,)` 140 | 141 | ``sing_vec_l``: :math:`(D, D)` 142 | 143 | ``sing_vec_r``: :math:`(D, D)` 144 | """ 145 | cov_u = covariance(u) 146 | cov_v = covariance(v) 147 | cov_uv = covariance(u, v) 148 | 149 | sqrt_cov_u_inv = torch.linalg.pinv(sqrtmh(cov_u)) 150 | sqrt_cov_v_inv = torch.linalg.pinv(sqrtmh(cov_v)) 151 | 152 | M = sqrt_cov_u_inv @ cov_uv @ sqrt_cov_v_inv 153 | e_val, sing_vec_l = torch.linalg.eigh(M @ M.T) 154 | e_val, sing_vec_l = filter_reduced_rank_svals(e_val, sing_vec_l) 155 | sing_val = torch.sqrt(e_val) 156 | sing_vec_r = (M.T @ sing_vec_l) / sing_val 157 | 158 | return sqrt_cov_u_inv, sqrt_cov_v_inv, sing_val, sing_vec_l, sing_vec_r 159 | -------------------------------------------------------------------------------- /linear_operator_learning/nn/structs.py: -------------------------------------------------------------------------------- 1 | """Structs used by the `nn` algorithms.""" 2 | 3 | from typing import TypedDict 4 | 5 | from torch import Tensor 6 | 7 | 8 | class FitResult(TypedDict): 9 | """Return type for nn regressors.""" 10 | 11 | U: Tensor 12 | V: Tensor 13 | svals: Tensor | None 14 | 15 | 16 | class EigResult(TypedDict): 17 | """Return type for eigenvalue decompositions of nn regressors.""" 18 | 19 | values: Tensor 20 | left: Tensor | None 21 | right: Tensor 22 | -------------------------------------------------------------------------------- /linear_operator_learning/py.typed: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/linear_operator_learning/py.typed -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/logo.png -------------------------------------------------------------------------------- /logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "linear-operator-learning" 3 | version = "0.2.5" 4 | description = "A package to learn linear operators" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Alek Frohlich", email = "alek.frohlich@gmail.com" }, 8 | { name = "Pietro Novelli", email = "pietronvll@gmail.com"} 9 | ] 10 | requires-python = ">=3.10" 11 | dependencies = [ 12 | "numpy>=1.26", 13 | "scipy>=1.12", 14 | "torch>=2.0", 15 | ] 16 | 17 | [build-system] 18 | requires = ["hatchling"] 19 | build-backend = "hatchling.build" 20 | 21 | [dependency-groups] 22 | dev = [ 23 | "bumpver>=2024.1130", 24 | "pre-commit>=4.1.0", 25 | "pytest>=8.3.4", 26 | "ruff>=0.9.4", 27 | "ty>=0.0.1a7", 28 | ] 29 | docs = [ 30 | "myst-parser>=4.0.0", 31 | "setuptools>=75.8.0", 32 | "sphinx-design>=0.6.1", 33 | "sphinx>=8.1.3", 34 | "sphinx-autobuild>=2024.10.3", 35 | "sphinxawesome-theme>=5.3.2", 36 | "sphinxcontrib-applehelp==2.0.0", 37 | "sphinxcontrib-bibtex>=2.6.3", 38 | "sphinxcontrib-jsmath==1.0.1", 39 | "myst-nb>=1.2.0", 40 | ] 41 | examples = [ 42 | "ipykernel>=6.29.5", 43 | 
"lightning>=2.5.0.post0", 44 | "loguru>=0.7.3", 45 | "matplotlib>=3.10.0", 46 | "scikit-learn>=1.6.1", 47 | "seaborn>=0.13.2", 48 | ] 49 | 50 | [tool.ruff] 51 | line-length = 100 52 | exclude = ["docs"] 53 | 54 | [tool.ruff.lint] 55 | select = [ 56 | "D", # pydocstyle rules, limiting to those that adhere to the google convention 57 | "E4", # Errors 58 | "E7", 59 | "E9", 60 | "F", # Pyflakes 61 | "I" 62 | ] 63 | 64 | ignore = [ 65 | "F401", # Don't remove unused imports 66 | "D107", # Document __init__ arguments inside class docstring 67 | ] 68 | 69 | 70 | 71 | [tool.ruff.lint.pydocstyle] 72 | convention = "google" 73 | 74 | [bumpver] 75 | current_version = "0.2.5" 76 | version_pattern = "MAJOR.MINOR.PATCH" 77 | 78 | [bumpver.file_patterns] 79 | "pyproject.toml" = [ 80 | 'version = "{version}"' 81 | ] 82 | --------------------------------------------------------------------------------