├── .DS_Store
├── .github
└── workflows
│ └── version-bump.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
├── .DS_Store
├── Makefile
├── _static
│ ├── .DS_Store
│ ├── custom.css
│ └── logo.png
├── _templates
│ ├── header.html
│ └── sidebar_toc.html
├── bibliography.bib
├── conf.py
├── contributing.md
├── examples
├── favicon.ico
├── index.rst
├── make.bat
├── reference
│ ├── kernel
│ │ ├── index.rst
│ │ └── kernel.rst
│ └── nn
│ │ ├── functional.rst
│ │ ├── index.rst
│ │ ├── linalg.rst
│ │ ├── nn.rst
│ │ └── stats.rst
└── requirements.txt
├── examples
├── detecting_independence
│ ├── .gitignore
│ └── detecting_independence.ipynb
└── lorenz63
│ ├── .gitignore
│ └── main.py
├── linear_operator_learning
├── __init__.py
├── kernel
│ ├── __init__.py
│ ├── linalg.py
│ ├── regressors.py
│ ├── structs.py
│ └── utils.py
├── nn
│ ├── __init__.py
│ ├── functional.py
│ ├── linalg.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── ema_covariance.py
│ │ ├── loss.py
│ │ ├── mlp.py
│ │ ├── resnet.py
│ │ └── simnorm.py
│ ├── regressors.py
│ ├── stats.py
│ └── structs.py
└── py.typed
├── logo.png
├── logo.svg
├── pyproject.toml
└── uv.lock
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/.DS_Store
--------------------------------------------------------------------------------
/.github/workflows/version-bump.yml:
--------------------------------------------------------------------------------
1 | name: Version Bump
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | # References:
9 | # https://docs.astral.sh/uv/guides/integration/github/
10 | # https://stackoverflow.com/a/58393457/10344434
11 | # https://github.com/mbarkhau/bumpver?tab=readme-ov-file#reference
12 | jobs:
13 | bump-version:
14 | name: Bump package version
15 | if: "!contains(github.event.head_commit.message, 'Bump version')"
16 | runs-on: ubuntu-latest
17 |
18 | permissions:
19 | contents: write
20 | id-token: write
21 |
22 | steps:
23 | - name: Checkout repository
24 | uses: actions/checkout@v4
25 |
26 | - name: Install uv
27 | uses: astral-sh/setup-uv@v5
28 |
29 | - name: Set up Python
30 | run: uv python install
31 |
32 | - name: Install the project
33 | run: uv sync --all-extras --dev
34 |
35 | - name: Bump version
36 | run: uv run bumpver update --patch
37 |
38 | - name: Bump uv lock
39 | run: uv lock --upgrade-package linear-operator-learning
40 |
41 | - name: Commit
42 | run: |
43 | git config --global user.name 'Alek Frohlich'
44 | git config --global user.email 'alekfrohlich@users.noreply.github.com'
45 | git commit -am "Bump version"
46 | git push
47 |
48 | - name: Build and Publish to PyPI
49 | run: |
50 | uv build
51 | uv publish --trusted-publishing always
52 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | .idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.9.4
5 | hooks:
6 | # Run the linter.
7 | - id: ruff
8 | types_or: [ python, pyi ]
9 | args: [ --fix ]
10 | # Run the formatter.
11 | - id: ruff-format
12 | types_or: [ python, pyi ]
13 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version, and other tools you might need
8 | build:
9 | os: ubuntu-24.04
10 | tools:
11 | python: "3.11"
12 |
13 | # Build documentation in the "docs/" directory with Sphinx
14 | sphinx:
15 | configuration: docs/conf.py
16 |
17 | # Optionally, but recommended,
18 | # declare the Python requirements required to build your documentation
19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
20 | python:
21 | install:
22 | - requirements: docs/requirements.txt
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 CSML @ IIT
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## Install
6 |
7 | To install this package as a dependency, run:
8 |
9 | ```bash
10 | pip install linear-operator-learning
11 | ```
12 |
13 | ## Development
14 |
15 | To develop this project, please set up the [`uv` project manager](https://astral.sh/uv) by running the following commands:
16 |
17 | ```bash
18 | curl -LsSf https://astral.sh/uv/install.sh | sh
19 | git clone git@github.com:CSML-IIT-UCL/linear_operator_learning.git
20 | cd linear_operator_learning
21 | uv sync --dev
22 | uv run pre-commit install
23 | ```
24 |
25 | ### Optional
26 | Set up your IDE to automatically apply the `ruff` styling.
27 | - [VS Code](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff)
28 | - [PyCharm](https://plugins.jetbrains.com/plugin/20574-ruff)
29 |
30 | ## Development principles
31 |
32 | Please adhere to the following principles while contributing to the project:
33 |
34 | 1. Adopt a functional style of programming. Avoid abstractions (classes) at all cost.
35 | 2. To add a new feature, create a branch and when done open a Pull Request. You should _**not**_ approve your own PRs.
36 | 3. The package contains both `numpy` and `torch` based algorithms. Let's keep them separated.
37 | 4. The functions shouldn't change the `dtype` or device of the inputs (that is, keep a functional approach).
38 | 5. Try to complement your contributions with simple examples to be added in the `examples` folder. If you need some additional dependency add it to the `examples` dependency group as `uv add --group examples _your_dependency_`.
--------------------------------------------------------------------------------
/docs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/.DS_Store
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/_static/.DS_Store
--------------------------------------------------------------------------------
/docs/_static/custom.css:
--------------------------------------------------------------------------------
1 | h2 {
2 | border-bottom-width: 0px !important;
3 | }
4 |
5 | img {
6 | border-radius: 10px !important;
7 | }
8 |
9 | :root {
10 | --mystnb-stdout-bg-color: theme(colors.background);
11 | --mystnb-source-bg-color: theme(colors.background);
12 | }
13 |
14 | div.cell_input {
15 | border-color: hsl(var(--border)) !important;
16 | border-width: 1px !important;
17 | border-radius: var(--radius) !important;
18 | }
19 |
20 | div.cell_output .output {
21 | border: none !important;
22 | }
--------------------------------------------------------------------------------
/docs/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/_static/logo.png
--------------------------------------------------------------------------------
/docs/_templates/header.html:
--------------------------------------------------------------------------------
1 | {% extends "!header.html" %}
2 |
3 | {% block header_logo %}
4 |
94 |
95 |
96 | {%- if logo_url %}
97 |
98 | {%- endif -%}
99 | {%- if theme_logo_dark and not logo_url %}
100 |
102 | {%- endif -%}
103 | {%- if theme_logo_light and not logo_url %}
104 |
106 | {%- endif -%}
107 | {{ docstitle }}
108 |
109 |
110 | {% endblock header_logo %}
111 |
112 | {%- block header_right %}
113 |
114 | {%- if docsearch or hasdoc('search') %}
115 |
116 |
130 |
159 |
160 |
161 | {%- include "searchbox.html" %}
162 |
163 |
164 |
184 |
185 | {%- endif %}
186 |
187 | {%- block extra_header_link_icons %}
188 |
220 | {%- endblock extra_header_link_icons %}
221 |
222 | {%- endblock header_right %}
--------------------------------------------------------------------------------
/docs/_templates/sidebar_toc.html:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/bibliography.bib:
--------------------------------------------------------------------------------
1 |
2 | @article{Kostic2022,
3 | title={Learning dynamical systems via koopman operator regression in reproducing kernel hilbert spaces},
4 | author={Kostic, Vladimir and Novelli, Pietro and Maurer, Andreas and Ciliberto, Carlo and Rosasco, Lorenzo and Pontil, Massimiliano},
5 | journal={Advances in Neural Information Processing Systems},
6 | volume={35},
7 | pages={4017--4031},
8 | year={2022}
9 | }
10 |
11 | @inproceedings{he2016deep,
12 | title={Deep residual learning for image recognition},
13 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
14 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
15 | pages={770--778},
16 | year={2016}
17 | }
18 |
19 | @article{lavoie2022simplicial,
20 | title={Simplicial embeddings in self-supervised learning and downstream classification},
21 | author={Lavoie, Samuel and Tsirigotis, Christos and Schwarzer, Max and Vani, Ankit and Noukhovitch, Michael and Kawaguchi, Kenji and Courville, Aaron},
22 | journal={arXiv preprint arXiv:2204.00616},
23 | year={2022}
24 | }
25 |
26 | @inproceedings{Kostic2024NCP,
27 | author = {Kostic, Vladimir and Pacreau, Gr\'{e}goire and Turri, Giacomo and Novelli, Pietro and Lounici, Karim and Pontil, Massimiliano},
28 | booktitle = {Advances in Neural Information Processing Systems},
29 | title = {Neural Conditional Probability for Uncertainty Quantification},
30 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/705b97ecb07ae86524d438abac97a3e2-Paper-Conference.pdf},
31 | year = {2024}
32 | }
33 |
34 |
35 | @inproceedings{Kostic2023SpectralRates,
36 | title={Sharp Spectral Rates for Koopman Operator Learning},
37 | author={Vladimir R Kostic and Karim Lounici and Pietro Novelli and Massimiliano Pontil},
38 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
39 | year={2023},
40 | url={https://openreview.net/forum?id=Lt3jqxsbVO}
41 | }
42 |
43 | @inproceedings{
44 | Kostic2023DPNets,
45 | title={Learning invariant representations of time-homogeneous stochastic dynamical systems},
46 | author={Vladimir R. Kostic and Pietro Novelli and Riccardo Grazzi and Karim Lounici and Massimiliano Pontil},
47 | booktitle={The Twelfth International Conference on Learning Representations},
48 | year={2024},
49 | url={https://openreview.net/forum?id=twSnZwiOIm}
50 | }
51 |
52 | @inproceedings{Fathi2023,
53 | title={Course Correcting Koopman Representations},
54 | author={Mahan Fathi and Clement Gehring and Jonathan Pilault and David Kanaa and Pierre-Luc Bacon and Ross Goroshin},
55 | booktitle={The Twelfth International Conference on Learning Representations},
56 | year={2024},
57 | url={https://openreview.net/forum?id=A18gWgc5mi}
58 | }
59 |
60 | @inproceedings{Azencot2020CAE,
61 | title={Forecasting sequential data using consistent Koopman autoencoders},
62 | author={Azencot, Omri and Erichson, N Benjamin and Lin, Vanessa and Mahoney, Michael},
63 | booktitle={International Conference on Machine Learning},
64 | pages={475--485},
65 | year={2020},
66 | organization={PMLR}
67 | }
68 |
69 | @article{Arbabi2017,
70 | doi = {10.1137/17m1125236},
71 | url = {https://doi.org/10.1137/17m1125236},
72 | year = {2017},
73 | month = jan,
74 | publisher = {Society for Industrial {\&} Applied Mathematics ({SIAM})},
75 | volume = {16},
76 | number = {4},
77 | pages = {2096--2126},
78 | author = {Hassan Arbabi and Igor Mezi{\'{c}}},
79 | title = {Ergodic Theory, Dynamic Mode Decomposition, and Computation of Spectral Properties of the Koopman Operator},
80 | journal = {{SIAM} Journal on Applied Dynamical Systems}
81 | }
82 |
83 | @inproceedings{Meanti2023,
84 | title={Estimating Koopman operators with sketching to provably learn large scale dynamical systems},
85 | author={Giacomo Meanti and Antoine Chatalic and Vladimir R Kostic and Pietro Novelli and Massimiliano Pontil and Lorenzo Rosasco},
86 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
87 | year={2023},
88 | url={https://openreview.net/forum?id=GItLpB1vhK}
89 | }
90 |
91 | @inproceedings{Morton2018,
92 | author = {Morton, Jeremy and Witherden, Freddie D. and Jameson, Antony and Kochenderfer, Mykel J.},
93 | title = {Deep Dynamical Modeling and Control of Unsteady Fluid Flows},
94 | year = {2018},
95 | publisher = {Curran Associates Inc.},
96 | address = {Red Hook, NY, USA},
97 | abstract = {The design of flow control systems remains a challenge due to the nonlinear nature of the equations that govern fluid flow. However, recent advances in computational fluid dynamics (CFD) have enabled the simulation of complex fluid flows with high accuracy, opening the possibility of using learning-based approaches to facilitate controller design. We present a method for learning the forced and unforced dynamics of airflow over a cylinder directly from CFD data. The proposed approach, grounded in Koopman theory, is shown to produce stable dynamical models that can predict the time evolution of the cylinder system over extended time horizons. Finally, by performing model predictive control with the learned dynamical models, we are able to find a straightforward, interpretable control law for suppressing vortex shedding in the wake of the cylinder.},
98 | booktitle = {Proceedings of the 32nd International Conference on Neural Information Processing Systems},
99 | pages = {9278–9288},
100 | numpages = {11},
101 | location = {Montr\'{e}al, Canada},
102 | series = {NIPS'18}
103 | }
104 |
105 | @article{Schmid2010,
106 | title = {Dynamic mode decomposition of numerical and experimental data},
107 | volume = {656},
108 | ISSN = {1469-7645},
109 | url = {http://dx.doi.org/10.1017/s0022112010001217},
110 | DOI = {10.1017/s0022112010001217},
111 | journal = {Journal of Fluid Mechanics},
112 | publisher = {Cambridge University Press (CUP)},
113 | author = {Schmid, Peter J.},
114 | year = {2010},
115 | month = jul,
116 | pages = {5–28}
117 | }
118 |
119 | @article{Williams2015_KDMD,
120 | title = {A kernel-based method for data-driven koopman spectral analysis},
121 | volume = {2},
122 | ISSN = {2158-2505},
123 | url = {http://dx.doi.org/10.3934/jcd.2015005},
124 | DOI = {10.3934/jcd.2015005},
125 | number = {2},
126 | journal = {Journal of Computational Dynamics},
127 | publisher = {American Institute of Mathematical Sciences (AIMS)},
128 | author = {O. Williams, Matthew and W. Rowley, Clarence and G. Kevrekidis, Ioannis},
129 | year = {2015},
130 | pages = {247–265}
131 | }
132 |
133 | @article{Williams2015_EDMD,
134 | title = {A Data–Driven Approximation of the Koopman Operator: Extending Dynamic Mode Decomposition},
135 | volume = {25},
136 | ISSN = {1432-1467},
137 | url = {http://dx.doi.org/10.1007/s00332-015-9258-5},
138 | DOI = {10.1007/s00332-015-9258-5},
139 | number = {6},
140 | journal = {Journal of Nonlinear Science},
141 | publisher = {Springer Science and Business Media LLC},
142 | author = {Williams, Matthew O. and Kevrekidis, Ioannis G. and Rowley, Clarence W.},
143 | year = {2015},
144 | month = jun,
145 | pages = {1307–1346}
146 | }
147 |
148 | @article{Lusch2018,
149 | doi = {10.1038/s41467-018-07210-0},
150 | url = {https://doi.org/10.1038/s41467-018-07210-0},
151 | year = {2018},
152 | month = nov,
153 | publisher = {Springer Science and Business Media {LLC}},
154 | volume = {9},
155 | number = {1},
156 | author = {Bethany Lusch and J. Nathan Kutz and Steven L. Brunton},
157 | title = {Deep learning for universal linear embeddings of nonlinear dynamics},
158 | journal = {Nature Communications}
159 | }
160 |
161 | @article{Wu2019,
162 | doi = {10.1007/s00332-019-09567-y},
163 | url = {https://doi.org/10.1007/s00332-019-09567-y},
164 | year = {2019},
165 | month = aug,
166 | publisher = {Springer Science and Business Media {LLC}},
167 | volume = {30},
168 | number = {1},
169 | pages = {23--66},
170 | author = {Hao Wu and Frank No{\'{e}}},
171 | title = {Variational Approach for Learning Markov Processes from Time Series Data},
172 | journal = {Journal of Nonlinear Science}
173 | }
174 | @article{Mardt2018,
175 | doi = {10.1038/s41467-017-02388-1},
176 | url = {https://doi.org/10.1038/s41467-017-02388-1},
177 | year = {2018},
178 | month = jan,
179 | publisher = {Springer Science and Business Media {LLC}},
180 | volume = {9},
181 | number = {1},
182 | author = {Andreas Mardt and Luca Pasquali and Hao Wu and Frank No{\'{e}}},
183 | title = {{VAMPnets} for deep learning of molecular kinetics},
184 | journal = {Nature Communications}
185 | }
186 |
187 | @misc{Turri2023,
188 | doi = {10.48550/ARXIV.2312.17348},
189 | url = {https://arxiv.org/abs/2312.17348},
190 | author = {Turri, Giacomo and Kostic, Vladimir and Novelli, Pietro and Pontil, Massimiliano},
191 | keywords = {Machine Learning (cs.LG), Numerical Analysis (math.NA), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Mathematics, FOS: Mathematics},
192 | title = {A randomized algorithm to solve reduced rank operator regression},
193 | publisher = {arXiv},
194 | year = {2023},
195 | copyright = {arXiv.org perpetual, non-exclusive license}
196 | }
197 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 | from sphinxawesome_theme.postprocess import Icons
9 |
10 | html_permalinks_icon = Icons.permalinks_icon # SVG as a string
11 |
12 | project = "Linear Operator Learning"
13 | copyright = "2025, Linear Operator Learning Team"
14 | author = "Linear Operator Learning Team"
15 |
16 | # -- General configuration ---------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
18 |
19 | extensions = [
20 | "sphinx.ext.autodoc",
21 | "sphinx.ext.autosummary",
22 | "sphinx.ext.napoleon",
23 | 'sphinx.ext.viewcode',
24 | "sphinxawesome_theme",
25 | "sphinxcontrib.bibtex",
26 | "sphinx_design",
27 | "myst_nb"
28 | ]
29 |
30 | myst_enable_extensions = ["amsmath", "dollarmath", "html_image"]
31 |
32 | templates_path = ["_templates"]
33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
34 |
35 | bibtex_bibfiles = ["bibliography.bib"]
36 |
37 |
38 | # -- Options for HTML output -------------------------------------------------
39 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
40 |
41 | html_theme = "sphinxawesome_theme"
42 | autodoc_class_signature = "separated"
43 | autoclass_content = "class"
44 |
45 | autodoc_typehints = "description"
46 | autodoc_member_order = "groupwise"
47 | napoleon_preprocess_types = True
48 | napoleon_use_rtype = False
49 | autodoc_mock_imports = ["escnn", "escnn.group"]
50 |
51 | master_doc = "index"
52 |
53 | source_suffix = {
54 | ".rst": "restructuredtext",
55 | ".txt": "restructuredtext",
56 | ".md": "myst-nb",
57 | ".ipynb": "myst-nb",
58 | }
59 |
60 | # Favicon configuration
61 | # html_favicon = '_static/favicon.ico'
62 |
63 | # Configure syntax highlighting for Awesome Sphinx Theme
64 | pygments_style = "tango"
65 | pygments_style_dark = "material"
66 |
67 | # Additional theme configuration
68 | html_title = "Linear Operator Learning"
69 | html_theme_options = {
70 | "show_prev_next": False,
71 | "show_scrolltop": True,
72 | "extra_header_link_icons": {
73 | "GitHub": {
74 | "link": "https://github.com/CSML-IIT-UCL/linear_operator_learning",
75 | "icon": """""",
76 | },
77 | },
78 | "show_breadcrumbs": True,
79 | }
80 |
81 | html_sidebars = {
82 | "**": ["sidebar_main_nav_links.html", "sidebar_toc.html"]
83 | }
84 | nb_execution_mode = "off"
85 |
86 | html_css_files = ["custom.css"]
87 |
88 | html_favicon = "favicon.ico"
89 | html_static_path = ["_static"]
90 | templates_path = ["_templates"]
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | (contributing)=
2 | # Developer guide
3 |
4 | To develop this project, please set up the [`uv` project manager](https://astral.sh/uv) by running the following commands:
5 |
6 | ```bash
7 | curl -LsSf https://astral.sh/uv/install.sh | sh
8 | git clone git@github.com:CSML-IIT-UCL/linear_operator_learning.git
9 | cd linear_operator_learning
10 | uv sync --dev
11 | uv run pre-commit install
12 | ```
13 |
14 | ### Optional
15 | Set up your IDE to automatically apply the `ruff` styling.
16 | - [VS Code](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff)
17 | - [PyCharm](https://plugins.jetbrains.com/plugin/20574-ruff)
18 |
19 | ## Development principles
20 |
21 | Please adhere to the following principles while contributing to the project:
22 |
23 | 1. Adopt a functional style of programming. Avoid abstractions (classes) at all cost.
24 | 2. To add a new feature, create a branch and when done open a Pull Request. You should _**not**_ approve your own PRs.
25 | 3. The package contains both `numpy` and `torch` based algorithms. Let's keep them separated.
26 | 4. The functions shouldn't change the `dtype` or device of the inputs (that is, keep a functional approach).
27 | 5. Try to complement your contributions with simple examples to be added in the `examples` folder. If you need some additional dependency add it to the `examples` dependency group as `uv add --group examples _your_dependency_`.
--------------------------------------------------------------------------------
/docs/examples:
--------------------------------------------------------------------------------
1 | ../examples
--------------------------------------------------------------------------------
/docs/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/docs/favicon.ico
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | ========================
2 | Linear Operator Learning
3 | ========================
4 |
5 | Linear Operator Learning — LOL for short — is a package to learn linear operators, developed by the `Computational Statistics & Machine Learning lab `_ at the Italian Institute of Technology.
6 |
7 | |
8 |
9 | You can install it with :code:`pip` as
10 |
11 | .. code::
12 |
13 | pip install linear-operator-learning
14 |
15 | |
16 |
17 | If you want to contribute to the project, please follow :ref:`these guidelines `.
18 |
19 | .. toctree::
20 | :maxdepth: 2
21 | :caption: Getting Started
22 | :hidden:
23 |
24 | Quickstart
25 | contributing.md
26 |
27 | .. toctree::
28 | :maxdepth: 2
29 | :caption: Examples
30 | :hidden:
31 |
32 | Independence Testing
33 |
34 | .. toctree::
35 | :maxdepth: 2
36 | :caption: API Reference
37 | :hidden:
38 |
39 | reference/kernel/index
40 | reference/nn/index
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/reference/kernel/index.rst:
--------------------------------------------------------------------------------
1 | .. _kernel_index:
2 | ==============
3 | Kernel Methods
4 | ==============
5 |
6 | Algorithms and utilities to learn linear operators via kernel methods.
7 |
8 | .. toctree::
9 |
10 | kernel
11 |
--------------------------------------------------------------------------------
/docs/reference/kernel/kernel.rst:
--------------------------------------------------------------------------------
1 | .. _kernel_reference:
2 | ==============
3 | :code:`kernel`
4 | ==============
5 | .. module:: linear_operator_learning.kernel
6 |
7 | .. rst-class:: lead
8 |
9 | Kernel Methods
10 |
11 | Table of Contents
12 | -----------------
13 |
14 | - :ref:`Regressors `
15 | - :ref:`Types `
16 | - :ref:`Linear Algebra `
17 | - :ref:`Utilities `
18 |
19 |
20 | .. _kernel_regressors:
21 | Regressors
22 | ----------
23 |
24 | Common
25 | ~~~~~~
26 |
27 | .. autofunction:: linear_operator_learning.kernel.predict
28 |
29 | .. autofunction:: linear_operator_learning.kernel.eig
30 |
31 | .. autofunction:: linear_operator_learning.kernel.evaluate_eigenfunction
32 |
33 | .. _rrr:
34 | Reduced Rank
35 | ~~~~~~~~~~~~
36 | .. autofunction:: linear_operator_learning.kernel.reduced_rank
37 |
38 | .. autofunction:: linear_operator_learning.kernel.nystroem_reduced_rank
39 |
40 | .. autofunction:: linear_operator_learning.kernel.rand_reduced_rank
41 |
42 | .. _pcr:
43 | Principal Component Regression
44 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
45 | .. autofunction:: linear_operator_learning.kernel.pcr
46 |
47 | .. autofunction:: linear_operator_learning.kernel.nystroem_pcr
48 |
49 | .. _kernel_types:
50 | Types
51 | -----
52 |
53 | .. autoclass:: linear_operator_learning.kernel.structs.FitResult
54 | :members:
55 |
56 | .. autoclass:: linear_operator_learning.kernel.structs.EigResult
57 | :members:
58 |
59 | .. _kernel_linalg:
60 | Linear Algebra Utilities
61 | ------------------------
62 |
63 | .. autofunction:: linear_operator_learning.kernel.linalg.weighted_norm
64 |
65 | .. autofunction:: linear_operator_learning.kernel.linalg.stable_topk
66 |
67 | .. autofunction:: linear_operator_learning.kernel.linalg.add_diagonal_
68 |
69 | .. _kernel_utils:
70 | General Utilities
71 | -----------------
72 |
73 | .. autofunction:: linear_operator_learning.kernel.utils.topk
74 |
75 | .. autofunction:: linear_operator_learning.kernel.utils.sanitize_complex_conjugates
76 |
77 |
78 | .. footbibliography::
--------------------------------------------------------------------------------
/docs/reference/nn/functional.rst:
--------------------------------------------------------------------------------
1 | .. _nn_functional:
2 | =====================
3 | :code:`nn.functional`
4 | =====================
5 |
6 | .. rst-class:: lead
7 |
8 | Functional implementations
9 |
10 | .. module:: linear_operator_learning.nn
11 |
12 | Table of Contents
13 | ~~~~~~~~~~~~~~~~~
14 |
15 | - :ref:`Loss Functions `
16 | - :ref:`Regularization Functions `
17 |
18 | .. _nn_func_loss_fns:
19 | Loss Functions
20 | ~~~~~~~~~~~~~~
21 |
22 | .. autofunction:: linear_operator_learning.nn.functional.l2_contrastive_loss
23 |
24 | .. autofunction:: linear_operator_learning.nn.functional.kl_contrastive_loss
25 |
26 | .. autofunction:: linear_operator_learning.nn.functional.vamp_loss
27 |
28 | .. autofunction:: linear_operator_learning.nn.functional.dp_loss
29 |
30 | .. _nn_func_reg_fns:
31 | Regularization Functions
32 | ~~~~~~~~~~~~~~~~~~~~~~~~
33 |
34 | .. autofunction:: linear_operator_learning.nn.functional.orthonormal_fro_reg
35 |
36 | .. autofunction:: linear_operator_learning.nn.functional.orthonormal_logfro_reg
37 |
38 | .. footbibliography::
--------------------------------------------------------------------------------
/docs/reference/nn/index.rst:
--------------------------------------------------------------------------------
1 | .. _nn_index:
2 | ===============
3 | Neural Networks
4 | ===============
5 |
6 | Functions and modules to learn linear operators via neural networks.
7 |
8 | .. toctree::
9 |
10 | nn
11 | functional
12 | stats
13 | linalg
14 |
--------------------------------------------------------------------------------
/docs/reference/nn/linalg.rst:
--------------------------------------------------------------------------------
1 | .. _nn_linalg:
2 | =====================
3 | :code:`nn.linalg`
4 | =====================
5 |
6 | .. rst-class:: lead
7 |
8 | Linear Algebra Utilities
9 |
10 | .. module:: linear_operator_learning.nn
11 |
12 | .. autofunction:: linear_operator_learning.nn.linalg.sqrtmh
13 |
14 | .. footbibliography::
--------------------------------------------------------------------------------
/docs/reference/nn/nn.rst:
--------------------------------------------------------------------------------
1 | .. _nn:
2 | ==========
3 | :code:`nn`
4 | ==========
5 |
6 | .. rst-class:: lead
7 |
8 | Neural Network Modules
9 |
10 | .. module:: linear_operator_learning.nn
11 |
12 | Table of Contents
13 | ~~~~~~~~~~~~~~~~~
14 |
15 | - :ref:`Regressors `
16 | - :ref:`Loss Functions `
17 | - :ref:`Modules `
18 |
19 | .. _nn_regressors:
20 | Regressors (see also :ref:`kernel regressors `)
21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22 |
23 | .. autofunction:: linear_operator_learning.nn.ridge_least_squares
24 |
25 | .. autofunction:: linear_operator_learning.nn.eig
26 |
27 | .. autofunction:: linear_operator_learning.nn.evaluate_eigenfunction
28 |
29 | .. _nn_loss_fns:
30 | Loss Functions
31 | ~~~~~~~~~~~~~~
32 |
33 | .. autoclass:: linear_operator_learning.nn.L2ContrastiveLoss
34 | :members:
35 | :exclude-members: __init__, __new__
36 |
37 | .. autoclass:: linear_operator_learning.nn.KLContrastiveLoss
38 | :members:
39 | :exclude-members: __init__, __new__
40 |
41 | .. autoclass:: linear_operator_learning.nn.VampLoss
42 | :members:
43 | :exclude-members: __init__, __new__
44 |
45 | .. autoclass:: linear_operator_learning.nn.DPLoss
46 | :members:
47 | :exclude-members: __init__, __new__
48 |
49 | .. _nn_modules:
50 | Modules
51 | ~~~~~~~
52 |
53 | .. autoclass:: linear_operator_learning.nn.MLP
54 | :members:
55 | :exclude-members: __init__, __new__, forward
56 |
57 | .. autoclass:: linear_operator_learning.nn.ResNet
58 | :members:
59 | :exclude-members: __init__, __new__, forward
60 |
61 | .. autoclass:: linear_operator_learning.nn.SimNorm
62 | :members:
63 | :exclude-members: __init__, __new__, forward
64 |
65 | .. autoclass:: linear_operator_learning.nn.EMACovariance
66 | :members:
67 | :exclude-members: __init__, __new__, forward
68 |
69 | .. footbibliography::
70 |
71 |
--------------------------------------------------------------------------------
/docs/reference/nn/stats.rst:
--------------------------------------------------------------------------------
1 | .. _nn_stats:
2 | =====================
3 | :code:`nn.stats`
4 | =====================
5 |
6 | .. rst-class:: lead
7 |
8 | Statistics Utilities
9 |
10 | .. module:: linear_operator_learning.nn
11 |
12 | .. autofunction:: linear_operator_learning.nn.stats.covariance
13 |
14 | .. autofunction:: linear_operator_learning.nn.stats.cov_norm_squared_unbiased
15 |
16 | .. autofunction:: linear_operator_learning.nn.stats.cross_cov_norm_squared_unbiased
17 |
18 | .. autofunction:: linear_operator_learning.nn.stats.whitening
19 |
20 | .. footbibliography::
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | Sphinx==7.3.7
2 | sphinxawesome-theme==5.2.0
3 | sphinxcontrib-applehelp==2.0.0
4 | sphinxcontrib-devhelp==2.0.0
5 | sphinxcontrib-htmlhelp==2.1.0
6 | sphinxcontrib-jsmath==1.0.1
7 | sphinxcontrib-qthelp==2.0.0
8 | sphinxcontrib-serializinghtml==2.0.0
9 | sphinxcontrib-bibtex
10 | sphinx_design
11 | myst-parser
12 | myst-nb
13 | git+https://github.com/CSML-IIT-UCL/linear_operator_learning
--------------------------------------------------------------------------------
/examples/detecting_independence/.gitignore:
--------------------------------------------------------------------------------
1 | runs/
--------------------------------------------------------------------------------
/examples/detecting_independence/detecting_independence.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Detecting Independence\n",
8 | "\n",
 9 |     "In *Neural Conditional Probability for Uncertainty Quantification* (Kostic et al., 2024), the authors claim that the (deflated) conditional expectation operator can be used to detect the independence of two random variables X and Y by verifying whether it is zero. Here, we show this equivalence in practice.\n",
10 | "\n",
11 | "## Dataset\n",
12 | "\n",
13 | "We consider the data model\n",
14 | "\n",
15 | "$$Y = tX + (1-t)X',$$\n",
16 | "\n",
17 | "where $X$ and $X'$ are independent standard Gaussians in $\\mathbb{R}$, and $t \\in [0,1]$ is an interpolating factor. This model allows us to explore both extreme cases ($t = 0$ for independence and $t = 1$ where $Y = X$) and the continuum in between, to assess the robustness of NCP in detecting independence."
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import torch\n",
27 | "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
28 | "\n",
29 | "\n",
30 | "def make_dataset(n_samples: int = 200, t: float = 0.0):\n",
31 | " \"\"\"Draw sample from data model Y = tX + (1-t)X_, where X and X_ are independent gaussians.\n",
32 | "\n",
33 | " If t = 0, then X and Y are independent. Otherwise, if t->1, X and Y become ever more dependent.\n",
34 | "\n",
35 | " Args:\n",
36 | " n_samples (int, optional): Number of samples. Defaults to 200.\n",
37 | " t (float, optional): Interpolation factor. Defaults to 0.0.\n",
38 | " \"\"\"\n",
39 | " X = torch.normal(mean=0, std=1, size=(n_samples, 1))\n",
40 | " X_ = torch.normal(mean=0, std=1, size=(n_samples, 1))\n",
41 | " Y = t * X + (1 - t) * X_\n",
42 | "\n",
43 | " ds = TensorDataset(X, Y)\n",
44 | "\n",
45 | " # Split data into train and val sets\n",
46 | " train_ds, val_ds = random_split(ds, lengths=[0.85, 0.15])\n",
47 | "\n",
48 | " return train_ds, val_ds"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Learning the conditional expectation operator\n",
56 | "\n",
57 | "Now, we go through the process of learning the conditional expectation operator $\\mathbb{E}_{Y \\mid X}: L^2_Y \\mapsto L^2_X$\n",
58 | "\n",
59 | "$$[\\mathbb{E}_{Y \\mid X}g](x) = \\mathbb{E}[g(Y) \\mid X = x],$$\n",
60 | "\n",
61 | "where $g \\in L^2_Y$. We begin by noting that, if $\\{u_i\\}_{i=1}^\\infty$ and $\\{v_j\\}_{j=1}^\\infty$ were orthonormal bases of $L^2_X$ and $L^2_Y$ [(Orthonormal basis wikipedia)](https://en.wikipedia.org/wiki/Orthonormal_basis), then we could see the conditional expectation operator as an (infinite) matrix $\\mathbf{E}$, where\n",
62 | "\n",
63 | "$$\\mathbf{E}_{ij} = \\langle u_i, \\mathbb{E}_{Y \\mid X}v_j \\rangle_{L^2_X} = \\mathbb{E}_X[u_i(X)[\\mathbb{E}_{Y \\mid X}v_j](X)] = \\mathbb{E}_{XY}[u_i(X)v_j(Y)].$$\n",
64 | "\n",
65 | "Hence, to learn the operator, we \"only\" need to learn the most important parts of $\\mathbf{E}$. The standard way to deal with such problems is to restrict oneself to finite subspaces of $L^2_X$ and $L^2_Y$ and then learn the (finite) matrix there. This corresponds to finding orthonormal functions $\\{u_i\\}_{i=1}^d$ and $\\{v_j\\}_{j=1}^d$ s.t.\n",
66 | "\n",
67 | "$$\\lVert \\mathbb{E}_{Y \\mid X} - \\mathbb{E}_{Y \\mid X}^d \\rVert$$\n",
68 | "\n",
69 | "is minimized, where $d \\in \\mathbb{N}$ is the dimension and $\\mathbb{E}_{Y \\mid X}^d$ is the truncated operator that acts on $span\\{v_j\\}_{j=1}^d$ and $span\\{u_i\\}_{i=1}^d$. The theoretical solution of this problem is given by the truncated (rank d) Singular Value Decomposition [(Low-rank matrix approximation wikipedia)](https://en.wikipedia.org/wiki/Low-rank_approximation), which also has the nice benefit of ordering the bases by their importance a la PCA, meaning that $u_1$ is more important than $u_2$, and so on and so forth.\n",
70 | "\n",
71 | "## A representation learning problem\n",
72 | "\n",
73 | "A key insight of Kostic et al. (2024) is that this problem corresponds to a representation learning problem, where the goal is to find latent variables $u,v \\in \\mathbb{R}^d$ that are\n",
74 | "\n",
75 | "1. [(Whitened, wikipedia)](https://en.wikipedia.org/wiki/Whitening_transformation): $\\mathbb{E}[u_i(X)u_j(X)] = \\mathbb{E}[v_i(Y)v_j(Y)] = \\delta_{ij}$; and\n",
76 | "2. Minimize the contrastive loss\n",
77 | "\n",
 78 |     "$$\\frac{1}{N(N-1)}\\sum_{i \\neq j}\\langle u_{i}, Sv_{j} \\rangle^2 - \\frac{2}{N}\\sum_{i=1}^{N}\\langle u_{i}, Sv_{i} \\rangle,$$\n",
79 | "\n",
80 | "where $S$ is the matrix of the conditional expectation operator on these subspaces/features, which can be learned end-to-end with backpropagation or estimated with running means."
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "## Representation learning in Torch"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 2,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "from torch.nn import Module\n",
97 | "import torch\n",
98 | "import math\n",
99 | "from torch import Tensor\n",
100 | "\n",
101 | "\n",
102 | "class _Matrix(Module):\n",
103 | " \"\"\"Module representing the matrix form of the truncated conditional expectation operator.\"\"\"\n",
104 | "\n",
105 | " def __init__(\n",
106 | " self,\n",
107 | " dim_u: int,\n",
108 | " dim_v: int,\n",
109 | " ) -> None:\n",
110 | " super().__init__()\n",
111 | " self.weights = torch.nn.Parameter(\n",
112 | " torch.normal(mean=0.0, std=2.0 / math.sqrt(dim_u * dim_v), size=(dim_u, dim_v))\n",
113 | " )\n",
114 | "\n",
115 | " def forward(self, v: Tensor) -> Tensor:\n",
116 | " \"\"\"Forward pass of the truncated conditional expectation operator's matrix (v -> Sv).\"\"\"\n",
117 | " # TODO: Unify Pietro, Giacomo and Dani's ideas on how to normalize\\symmetrize the operator.\n",
118 | " out = v @ self.weights.T\n",
119 | " return out\n",
120 | "\n",
121 | "\n",
122 | "class NCP(Module):\n",
123 | " \"\"\"Neural Conditional Probability in PyTorch.\n",
124 | "\n",
125 | " Args:\n",
126 | " embedding_x (Module): Neural embedding of x.\n",
127 | " embedding_dim_x (int): Latent dimension of x.\n",
128 | " embedding_y (Module): Neural embedding of y.\n",
129 | " embedding_dim_y (int): Latent dimension of y.\n",
130 | " \"\"\"\n",
131 | "\n",
132 | " def __init__(\n",
133 | " self,\n",
134 | " embedding_x: Module,\n",
135 | " embedding_y: Module,\n",
136 | " embedding_dim_x: int,\n",
137 | " embedding_dim_y: int,\n",
138 | " ) -> None:\n",
139 | " super().__init__()\n",
140 | " self.U = embedding_x\n",
141 | " self.V = embedding_y\n",
142 | "\n",
143 | " self.dim_u = embedding_dim_x\n",
144 | " self.dim_v = embedding_dim_y\n",
145 | "\n",
146 | " self.S = _Matrix(self.dim_u, self.dim_v)\n",
147 | "\n",
148 | " def forward(self, x: Tensor, y: Tensor) -> Tensor:\n",
149 | " \"\"\"Forward pass of NCP.\"\"\"\n",
150 | " u = self.U(x)\n",
151 | " v = self.V(y)\n",
152 | " Sv = self.S(v)\n",
153 | "\n",
154 | " return u, Sv"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "## Training NCP\n",
162 | "\n",
 163 |     "We now show how to train the NCP module above with the contrastive loss from `linear_operator_learning.nn` with orthonormality regularization and standard deep learning techniques."
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 3,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "from torch.optim import Optimizer\n",
173 | "\n",
174 | "def train(\n",
175 | " ncp: NCP,\n",
176 | " train_dataloader: DataLoader,\n",
177 | " device: str,\n",
178 | " loss_fn: callable,\n",
179 | " optimizer: Optimizer,\n",
180 | ") -> Tensor:\n",
181 | " \"\"\"Training logic of NCP.\"\"\"\n",
182 | " ncp.train()\n",
183 | " for batch, (x, y) in enumerate(train_dataloader):\n",
184 | " x, y = x.to(device), y.to(device)\n",
185 | "\n",
186 | " u, Sv = ncp(x, y)\n",
187 | " loss = loss_fn(u, Sv)\n",
188 | " loss.backward()\n",
189 | " optimizer.step()\n",
190 | " optimizer.zero_grad()"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 4,
196 | "metadata": {},
197 | "outputs": [
198 | {
199 | "name": "stdout",
200 | "output_type": "stream",
201 | "text": [
202 | "Using cpu device\n",
203 | "run_id = (0.0, 0)\n",
204 | "run_id = (0.1, 0)\n",
205 | "run_id = (0.2, 0)\n",
206 | "run_id = (0.3, 0)\n",
207 | "run_id = (0.4, 0)\n",
208 | "run_id = (0.5, 0)\n",
209 | "run_id = (0.6, 0)\n",
210 | "run_id = (0.7, 0)\n",
211 | "run_id = (0.8, 0)\n",
212 | "run_id = (0.9, 0)\n",
213 | "run_id = (1.0, 0)\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "import torch\n",
219 | "import linear_operator_learning as lol\n",
220 | "\n",
221 | "\n",
222 | "\n",
223 | "SEED = 1\n",
224 | "REPEATS = 1\n",
225 | "BATCH_SIZE = 256\n",
226 | "N_SAMPLES = 5000\n",
227 | "MLP_PARAMS = dict(\n",
228 | " output_shape=2,\n",
229 | " n_hidden=2,\n",
230 | " layer_size=32,\n",
231 | " activation=torch.nn.ELU,\n",
232 | " bias=False,\n",
233 | " iterative_whitening=False,\n",
234 | ")\n",
235 | "EPOCHS = 100\n",
236 | "WHITENING_N_SAMPLES = 2000\n",
237 | "\n",
238 | "torch.manual_seed(SEED)\n",
239 | "\n",
240 | "# device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else \"cpu\"\n",
241 | "device = \"cpu\"\n",
242 | "print(f\"Using {device} device\")\n",
243 | "\n",
244 | "results = dict()\n",
245 | "for t in torch.linspace(start=0, end=1, steps=11):\n",
246 | " for r in range(REPEATS):\n",
247 | " run_id = (round(t.item(), 2), r)\n",
248 | " print(f\"run_id = {run_id}\")\n",
249 | "\n",
250 | " # Load data_________________________________________________________________________________\n",
251 | " train_ds, val_ds = make_dataset(n_samples=N_SAMPLES, t=t.item())\n",
252 | "\n",
253 | " train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)\n",
254 | " val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)\n",
255 | "\n",
256 | " # Build NCP_________________________________________________________________________________\n",
257 | " ncp = NCP(\n",
258 | " embedding_x=lol.nn.MLP(input_shape=1, **MLP_PARAMS),\n",
259 | " embedding_dim_x=MLP_PARAMS[\"output_shape\"],\n",
260 | " embedding_y=lol.nn.MLP(input_shape=1, **MLP_PARAMS),\n",
261 | " embedding_dim_y=MLP_PARAMS[\"output_shape\"],\n",
262 | " ).to(device)\n",
263 | "\n",
264 | " # Train NCP_________________________________________________________________________________\n",
265 | " loss_fn = lol.nn.L2ContrastiveLoss()\n",
266 | " optimizer = torch.optim.Adam(ncp.parameters(), lr=5e-4)\n",
267 | "\n",
268 | " for epoch in range(EPOCHS):\n",
269 | " train(ncp, train_dl, device, loss_fn, optimizer)\n",
270 | "\n",
271 | " # Extract norm______________________________________________________________________________\n",
272 | " x = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1)).to(device)\n",
273 | " x_ = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1)).to(device)\n",
274 | " y = t * x + (1 - t) * x_\n",
275 | " u, Sv = ncp(x, y)\n",
276 | "\n",
277 | " _, _, svals, _, _ = lol.nn.stats.whitening(u, Sv)\n",
278 | " results[run_id] = svals.max().item()"
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "## Plots"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": 5,
291 | "metadata": {},
292 | "outputs": [
293 | {
294 | "data": {
295 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQblJREFUeJzt3Ql8jVf+x/Ff9ghZrAkSYt9CopZYalqt0tYo02lLtailq5qWaosuSlu6Gv2XUlRppy2tVjeqVJmOoULs+y4REoIsgqz3/zrHJHIjidzIvc9dPu/X6zae5z6XnzRyvznPOb/jZjKZTAIAAOAk3I0uAAAAoCIRbgAAgFMh3AAAAKdCuAEAAE6FcAMAAJwK4QYAADgVwg0AAHAqnuJi8vLy5OTJk+Lv7y9ubm5GlwMAAMpAteVLT0+XOnXqiLt76WMzLhduVLAJCwszugwAAFAO8fHxEhoaWuo1Lhdu1IhN/icnICDA6HIAAEAZpKWl6cGJ/Pfx0rhcuMm/FaWCDeEGAADHUpYpJUwoBgAAToVwAwAAnArhBgAAOBXCDQAAcCqEGwAA4FQINwAAwKkQbgAAgFMh3AAAAKfick38AACAdfZ+2hqfIqv2JEnqpWwJrOQld7QMlrZhQTbfy5FwAwAAbsiBpHQZ+8122XEi1ez8rLWHpU1ooLx3f6Q0Db7+tglOcVvqjz/+kD59+ugdPlWq+/7776/7mrVr18pNN90kPj4+0rhxY1mwYIFNagUAwJajIFvizsvbK/bJhKU79Ud1rM7bY7C5b9b6a4JNPnVePa+uc4mRm4yMDImMjJRhw4bJvffee93rjx49Kr1795YnnnhCvvjiC1m9erWMGDFCateuLb169bJJzQAAuNIoSGlU2FK1pl3OKfU69fzz32yX70d2tcktKjeTncRA9ZddunSp9OvXr8RrXnzxRVm2bJns2rWr4NyAAQMkJSVFVqxYUeZdRQMDAyU1NZWNMwEAdjkKUlpYCPD1lCVPdqmwgJOTmycXs3PlYmauZGTlXP2YlSMZmbnmH7PUdf/7mJUjp1Iu63k2ZbX0qS7Stl7VctVpyfu3Q8252bBhg/To0cPsnBqxefbZZ0t8TWZmpn4U/uQAAGBvLBkFeearrTKtf6Rcys6zKIxkFPN8Vk6ezf6OK/cklTvcWMKhwk1iYqIEBwebnVPHKrBcunRJKlWqdM1rpk6dKpMmTbJhlQAAWE6NgJQ0b6WovYnpctcH68TRpF7Ktsmf41DhpjzGjx8vY8aMKThWQSgsLMzQmgAArk2N0sSduyi7T6bJ7pOp+uPGo+fE2QVW8rLJn+NQ4SYkJESSkpLMzqljde+tuFEbRa2qUg8AAIyQnZsnh05fMAsye0+mSXpm6befrMHbw138fDyksren+Hl7iJ+Pp1RWH/93XNnnyq8rF3mu4HzB85769zmYdEEenPtnmf/8ni3N775Yi0OFm86dO8vy5cvNzq1atUqfBwDAaGoey95T6bLnfyFGPfYnpVtlXktEnQC5uUnNgiCiw4oKKN5XgkflImGkkreHeHtWbAeY6pW99QqustxOiwwNlKiwIHH6cHPhwgU5dOiQ2VLvbdu2SbVq1aRevXr6llJCQoJ89tln+nm1BHzGjBnywgsv6OXjv//+u3z99dd6BRUAALZ0PiPLbDRGfTyanCF55VyDrFZBXW8ycWGv94uwyeTc6610VkvTy7LC6937I23WqdjQcLN582bp3r17wXH+3JghQ4bo5nynTp2SuLi4gucbNGigg8zo0aPlgw8+kNDQUJk3bx49bgAAVtsWQL3+ZOpl2Z1wdTRGjcyoc+VVN6iStKoTIK3qBF75WDdAgv19pN9HJTfDM2oU5HrUknS1NL243jz5tb5r4948dtPnxlbocwMArqOkhnhKcQ3xcvNMcjQ5f37M1VGZlIvlW+Xj7ibSqGYVsyDTsk6ABPl5202fm4qi4sS2+BS93Ds/RKo5
NiqEVcSIjSXv34QbAIBTKktQUHNSHu3WUJIvZOoQs+9UulzKzi3Xn+fj6S7NQ/ylZf5oTJ0AaR4SoOe6WFq3PY2C2AvCTSkINwDg/NRbW9+Z/y1z3xhL+ft6mt9WqhMojWpWFk8Pd4cYBXFETtuhGACAim6Idz3BAT6FQsyVIBNatZJVQ4b6vdVkYaMnDDsqwg0AwKlcysqVGb9fXYlriQY1Kus5MYVHZWpUoVeaoyHcAACcwp6TabJoU5ws3Zog6RYsqVbU6ql/9o+SKj68LToD/i8CABxWRmaO/LT9pHwVEyfbb+A2VONaVQg2ToT/kwAAh6Im26r5NGqU5sdtJ/XO1jfKVtsCwDYINwAAh6BWDf2wLUG+iomXvafSSrzOy8NN7mgRrHfOVh2DHakhHioG4QYAYNejNJuPn9e3nZbvPCWXs0veo6lhjcoyoGOY3HtTqJ4EXNaGeLbcFgC2QbgBANidcxlZ8t2WE7JoU7zeUbskaiPIuyNCZEDHehLdoJpZSLHHbQFgG4QbAIBdyMszyZ9HzspXm+Ll112JkpVb8ihNs2B/PUrzt7Z1S9zKQFHB5YeRXWmI52IINwAAQ51OvyxLYk/I4k3xcvzsxRKvq+TlIX9tU1sejK5X5k0vFRriuR7CDQDA5tQGlX8cPCOLYuJk9d7TkpNX8k5AEXUDZECHetI3qo74+3rZtE44JsINAMBmTqZckq83x8s3m09IQsqlEq9TPWdUmHmwYz2JqBto0xrh+Ag3AACrysnNk9/3ndaTg9fuPy2lDNJI23pB8mCHevLXyNri581bFMqHrxwAQJmXZasNKVcVmpirti0oaf5L/LmLutGeGqU5nZ5Z4u+rfh81MViN0jQLYeUSbhzhBgBwXapnTHFLqmetPSxtQgPlvf8tqc7KyZOVexJlUUy8rDuUXOrvqZZuq0BzZ0SI+Hp5WPlvAFdCuAEAlOp6zfBU4Ln3o/XSq1WwrNl/RveoKUn1yt5yX7tQ6d8hTBrWrGLFquHKCDcAgFJvRakRm9K6/CoXMnPk2y0JJT7frUkNPUrTo0WwbrwHWBPhBgBQIjXHprjuvmVRy99HHmgfpkdpwqr5VXhtQEkINwCAEqnJw5a6vXktvR1C92Y1xdODURrYHuEGAFAitSrKEv3a1pHp/dtarR6gLIjUAIBSd822RO3ASlarBSgrwg0AoMRuwusPn7XoNWpDSsBo3JYCAFyzQmrp1gSZ+ONuSb/OKqnCIkMD9U7bgNEINwCAAskXMuWlpTvl191JFt++evf+yDLv1A1YE7elAADail2J0uuff1wTbFRfmsf/0lBal7CBpRqxWfJkF92hGLAHjNwAgItTK6Im/bRbviumCZ8KNNMeiJQmwf76dtW2+BRZWWhvKTXHRt2KYsQG9oRwAwAubN3BZHl+yXY5lXrZ7LyHu5uMuq2xjOzeWLz+16tGBZi29arqB2DPCDcA4IIuZeXKW7/slYUbjl/zXONaVfRoTZtQJgfDMRFuAMDFxB4/r/eLOpqcYXZe3Vka3rWBjO3VjF264dAINwDgIjJzcuWD3w7K7H8fljyT+XOhVSvJ+/dHSnTD6kaVB1QYwg0AuIC9p9Jk9OJtsi8x/ZrnHuwYJi/1bilVfHhLgHPgKxkAnFhObp58/McRmf7bAcnONV2za/fbf28j3ZvXMqw+wBoINwDgpNScmjFfb5OtcSnXPNcnso5MvqeVVK3sbUhtgDURbgDAyeTlmeTzP4/L1F/2yuXsPLPngvy85I1+EfLXNnUMqw+wNsINADjZZpcvLNkh6w4lX/Nc92Y19W2oWgG+htQG2ArhBgCcgOoerDoMv/bTtZtdVvb2kFf+2lL6dwijkzBcAuEGAJxgs8sJ3+3U2yIU1bFBNb3EO6yanyG1AUYg3ACAg292qXbxPpuRdc1mly/0aibDujYQd3dGa+BaCDcA4KibXf64W77beu1ml21Cr2x22bgWu3TDNRFuAMDB/OfgGT1puOhm
l57ubvJ0kc0uAVdEuAEAB3ExK0fe+mWffFbMZpdN9GaXUdI6NNCQ2gB7QrgBAAfZ7PK5r7fJsbMXzc6rxU8jbm4gz/Vks0sgH+EGAOx8s8vpvx2Uj4vZ7DKsWiV57z42uwSKItwAgJ3aczJNb59Q/GaX9eSl3i3Y7BIoBv8qAMCgpntb41Nk1Z4kvfIpsJKX3NEyWNqGBUlunonNLoEbQLgBABs7kJQuY7/ZLjtOpJqdn7X2sDQL9tfzaIobrVGbXb7et5UE+bHZJVAawg0A2DjY3DdrvaQV2SIh3/6ka0MNm10CliHcAIANb0WpEZuSgk1xbmteS966tzWbXQIWINwAgI2oOTZFb0WVZmT3RjK2ZzM2uwQsRAtLALARNXnYEmrpN8EGsBzhBgBsRK2Ksub1AK4g3ACAjajl3ta8HsAVhBsAsBHVx8YSPS28HsAVhBsAsJHwan7i41m2b7uRoYESFRZk9ZoAZ0S4AQAbuJydK49+HiuZOXnXvTbA11PevT+SycRAORFuAMDK8vJMeo8otbN3WUZsljzZRZoG+9ukNsAZ0ecGAKxsyvK9snxnotm5l3u3kHb1q8rKQntLqTk26lYUIzaAg4/czJw5U8LDw8XX11eio6MlJiam1OunT58uzZo1k0qVKklYWJiMHj1aLl++bLN6AcASn/73qMxbd9Ts3NCu4TKiW0NpW6+qvHhnc5nyt9b6ozom2AAOHm4WL14sY8aMkYkTJ8qWLVskMjJSevXqJadPny72+i+//FLGjRunr9+7d6988skn+veYMGGCzWsHgOv5dXeiTP55j9m5O1uFyMu9WxpWE+AK3ExqsxODqJGaDh06yIwZM/RxXl6eHo0ZNWqUDjFFPf300zrUrF69uuDcc889Jxs3bpR169YV+2dkZmbqR760tDT9Z6SmpkpAQIBV/l4AsCXuvDw450+zCcRt6wXJV492El8vD0NrAxyRev8ODAws0/u3YSM3WVlZEhsbKz169LhajLu7Pt6wYUOxr+nSpYt+Tf6tqyNHjsjy5cvl7rvvLvHPmTp1qv5k5D9UsAEAazqWnCEjFm42Czbh1f1k3uD2BBvAmScUJycnS25urgQHmzepUsf79u0r9jUDBw7Ur7v55pv17ro5OTnyxBNPlHpbavz48frWV9GRGwCwhnMZWfLIpzH6Y75qlb1lwdCOUr2Kj6G1Aa7C8AnFlli7dq1MmTJFPvroIz1H57vvvpNly5bJ66+/XuJrfHx89PBV4QcAWKuXzYiFm+TY2YsF51TTvnlD2kt4jcqG1ga4EsNGbmrUqCEeHh6SlGS+S646DgkJKfY1r7zyigwaNEhGjBihj1u3bi0ZGRny2GOPyUsvvaRvawGAEXLzTPLsom2yJS6l4Jxa+PTBgLZyU72qhtYGuBrD0oC3t7e0a9fObHKwmlCsjjt37lzsay5evHhNgFEBSTFwXjQAyJvL9sqK3ea9bCb+taXcGVH8D2sAnLSJn5oLM2TIEGnfvr107NhR97BRIzFDhw7Vzw8ePFjq1q2rJwUrffr0kWnTpknbtm31SqtDhw7p0Rx1Pj/kAICtfbLuqMz/r3kvmxE3N5BHujYwrCbAlRkabvr37y9nzpyRV199VRITEyUqKkpWrFhRMMk4Li7ObKTm5Zdf1g2u1MeEhASpWbOmDjZvvvmmgX8LAK7sl52n5I1l5r1s7m4dIhPubmFYTYCrM7TPjb2vkweA0sQePycD5240W/Ldvn5V+deIaJZ8A67Y5wYAHNnRYnrZNKhRWebSywYwHOEGACyUfCFT97I5fzG74Fx13cumg1St7G1obQAINwBgkUtZqpfNZjleqJeNr5e7fPJIB6lfnV42gD0g3ACABb1snlm0VbbFX+1l4+4m8uGDN0lUWJChtQG4inADAGWg1l68/vMeWbnHvPHoa/e0kjtamm8jA8BYhBsAKGMvmwXrj5mde+wvDWVw53DDagJQPMINAFzHsh2ql81es3O929SW
cXc2N6wmACUj3ABAKTYdOyejv95mdq5DeFV5//5IcVcTbgDYHcINAJTg8JkL8uhnmyWrUC+bhjXpZQPYO8INABTjTPqVXjYphXrZ1KjiLQuHdpQgP3rZAPaMcAMARVzMypERCzdJ/LlLBecqeXnI/Ec6SFg1P0NrA3B9hBsAKNLL5h9fbZXtJ1ILzqmpNTMGtpU2ofSyARwB4QYACvWyee3H3fLb3tNm5yf3jZDbW9DLBnAUhBsA+J+5/zkin/953OzcE7c0koc71TesJgCWI9wAgIj8tP2kTFm+z+xcn8g68kKvZobVBKB8CDcAXF7M0XPy3Nfbzc5FN6gm793fhl42gAMi3ABwaYdO/6+XTe7VXjaNa1WROYPai48nvWwAR0S4AeCyTqdf1r1sUi8V7mXjI58+0kEC/bwMrQ1A+RFuALhsL5vhCzbLifNXe9n4eXvoYEMvG8CxEW4AuJyc3DwZ9eVW2Zlg3stm5sCbpHVooKG1AbhxhBsALtfLZuKPu2X1PvNeNm/0ay3dm9cyrC4AFYdwA8ClzP73EfliY5zZuadubSQDo+sZVhOAikW4AeAyftiWIG+vMO9l0zeqjjxPLxvAqRBuALiEP4+clee/2WF2rlPDavLOfW3EzY1eNoAzIdwAcHoHk9LlsSK9bJrUqiIf08sGcEqEGwBO7XSa6mWzSdIu5xScq+XvIwuGdZTASvSyAZwR4QaA08rIzJFhCzdJQop5L5v5j3SQukGVDK0NgPUQbgA4bS+bkV9ukV0JaQXnPNzd5KOHbpKIuvSyAZyZp9EFAEBF9K7ZGp8iq/Yk6a0UAn095dCZC7J2/xmz697sFyG3NqOXDeDsCDcAHNqBpHQZ+8122XHiarfh4oy6rbEM6EgvG8AVEG4AOHSwuW/WerPJwsXp0byWjLmjqc3qAmAs5twAcNhbUWrE5nrBRklKz7RJTQDsA+EGgENSc2yudysqn9ogc1t8itVrAmAfCDcAHJKaPGyJlRZeD8BxEW4AOCS1Ksqa1wNwXIQbAA7J0u7CdCMGXAfhBoBDysrJtej6ni2DrVYLAPvCUnAADrdKavpvB+WTdcfK/JrI0ECJCguyal0A7AfhBoDDyMszyeSf98iC9WUPNgG+nvLu/ZHi5uZm1doA2A/CDQCHkJ2bJy8s2SFLtyaYnW8W4i8ebm6y59TVPaQKj9ioYNM02N+GlQIwGuEGgN27nJ0rI7/YIqv3nTY736VRdZkzuL1U9vbQfWxW5u8tVclLz7FRt6IYsQFcD+EGgF1Lu5wtIxZulpij58zO92oVLB8MaCu+Xh76uG29qvoBAIQbAHYr+UKmDJkfI7tPmt9yur9dqEy9t7V4erDgE8C1CDcA7FJCyiUZNG+jHEnOMDv/aLcGMuHuFtxuAlAiwg0Au3Po9AUZ9MlGOZV62ez8872ayVO3NiLYACgV4QaAXdl5IlWGfBoj5zKyCs6pLPN63wh5uFN9Q2sD4BgINwDsxobDZ+XRzzbLhcycgnOe7m4yrX+U3BNZx9DaADgOwg0Au9nle+SXWyQrJ6/gnK+Xu8x6uJ10b1bL0NoAOBbCDQDDfRt7Ql74dofk5pkKzvn7esqnj3SQ9uHVDK0NgOMh3AAw1Px1R/WWCoXVqOIjnw3rKC3rBBhWFwDHRbgBYNgGmP/87aD83+qDZudDq1aSfw2PlvAalQ2rDYBjI9wAMGQDzEk/7ZaFG46bnW9Sq4p8PjxaQgJ9DasNgOMj3ACw+QaYz3+zXb7fdtLsfGRYkCx4pINUrextWG0AnAPhBoBNN8B86ost8nuRDTC7Nq4ucwa1l8o+fEsCcOP4TgLA0A0w72wVIh88GCU+nlc2wASAG0W4AWDYBpgPtA+VKX9jA0wAFYtwA8CqTpy/KIM/iblmA8zH/tJQxt/VnH2iAFQ4wg0Aqzl0Ol0GfRJzzQaYL9zZTJ68hQ0wAViH4WPBM2fOlPDwcPH19ZXo6GiJiYkp9fqUlBQZOXKk1K5d
W3x8fKRp06ayfPlym9ULoGx2nEiR+2dvMAs2Ksu8+bcIeerWxgQbAM45crN48WIZM2aMzJ49Wweb6dOnS69evWT//v1Sq9a1e8lkZWXJHXfcoZ9bsmSJ1K1bV44fPy5BQUGG1A+geOsPJ8ujCzdLRlZuwTkvDzeZ9kCU9GEDTABW5mZSbUINogJNhw4dZMaMGfo4Ly9PwsLCZNSoUTJu3Lhrrlch6N1335V9+/aJl5dXmf6MzMxM/ciXlpam/4zU1FQJCKC1O1DRVu5OlKe/2nrNBpizH24nt7IBJoByUu/fgYGBZXr/Nuy2lBqFiY2NlR49elwtxt1dH2/YsKHY1/z444/SuXNnfVsqODhYIiIiZMqUKZKbe/Wnw6KmTp2qPxn5DxVsAFjHktgT8uQX5jt7B/h66u0UCDYAbMWwcJOcnKxDiQophanjxMTEYl9z5MgRfTtKvU7Ns3nllVfk/ffflzfeeKPEP2f8+PE65eU/4uPjK/zvAuDKBphjv9lutrO32gBz8eOd2dkbgE051GopddtKzbeZM2eOeHh4SLt27SQhIUHfqpo4cWKxr1GTjtUDgBU3wFx1QP7v90Nm58OqXdkAs351NsAE4CLhpkaNGjqgJCUlmZ1XxyEhIcW+Rq2QUnNt1OvytWjRQo/0qNtc3t7sSQPYegPM137aLZ8V2QCzafCVDTCDA9gAE4AL3ZZSQUSNvKxevdpsZEYdq3k1xenatascOnRIX5fvwIEDOvQQbADbb4A5+utt1wSbqLAg+frxzgQbAK7Z50YtA587d64sXLhQ9u7dK08++aRkZGTI0KFD9fODBw/Wc2byqefPnTsnzzzzjA41y5Yt0xOK1QRjALZzKStXHv88Vn4osrP3zY1ryBcjoiXIjx82ALjonJv+/fvLmTNn5NVXX9W3lqKiomTFihUFk4zj4uL0Cqp8aqXTr7/+KqNHj5Y2bdroPjcq6Lz44osG/i0AF9wAc8FmiTlmvgHmXREhMn0AG2ACcPE+N/a+Th6AuTPpVzbA3HPKfAPM/u3DZMq9rcXDna7DAIx//3ao1VIAjN0AU+0TdbTIBpiP/6WhjGMDTAB2hHADwIwazN0anyKr9iRJ6qVsCazkJc1D/GXKsr2SlH6127fy4p3N5clbGxlWKwAUh3ADoMCBpHTdiG/HidRSr9MbYPZrLQOj69msNgAoK8INgIJgc9+s9ZJ2OafU6zzdRaYPaCt/bcMGmADsk6FLwQHYz60oNWJzvWCjhFWrLL1b17ZJXQBQHoQbAHqOzfVuReVTE4q3xadYvSYAKC/CDQA9edgSKy28HgBsiXADQK+Ksub1AOBQE4ovXLhgtteTQnM8wLGo5d7WvB4A7H7k5ujRo9K7d2+pXLmy7hZYtWpV/QgKCtIfATgWfx/Lfs7p2fLKFikA4DQjNw8//LBeXTF//ny9DxSdSQHHnm/z/sr9Zb4+MjRQ7/wNAE4VbrZv3y6xsbHSrFmziq8IgM2s3pskT30RK7ll3GEuwNdT3r0/kh9oADjfbakOHTpIfHx8xVcDwGbW7DstT/5ri2QXSjY1q3jrrRZKGrFZ8mQXaRpc/PMA4NAjN/PmzZMnnnhCEhISJCIiQry8zCcXtmnTpqLqA2AFa/eflsc/j5Ws3KuLAUICfGXx452kXjU/3cdmZaG9pdQcG3UrihEbAE4bbs6cOSOHDx+WoUOHFpxT3/TUPBz1MTc3tyJrBFCB/jhwRh4rEmyCA3xk0WOdpH71yvq4bb2q+gEALhNuhg0bJm3btpWvvvqKCcWAA1l3MFke/WyzZOVcDTa1/H3kq0c7SXiNK8EGAFwy3Bw/flx+/PFHady4ccVXBMAq1h9KlhGfbZLMQsGmRhUf+fLRTtKwZhVDawMAwycU33bbbXrFFADHsOHwWRm2cJNczi4cbLxl0WPR0rgWwQaAcynXyE2fPn1k9OjRsnPnTmnduvU1E4rvueeeiqoPwA3aeOSsDFtgHmyqV/bWIzaN
a7HyCYDzcTOpWcAWcncvecDH3icUp6Wl6a7KqampbBMBp7fp2DkZMj9GLmZd/TdZTQebaGkewtc/AMdhyft3uUZuiu4lBcD+xB4/J48UCTZV/bzkixEEGwDOzeI5N9nZ2eLp6Sm7du2yTkUAbtiWuPMyZP4mySgUbIL8vORfI6KlRW2CDQDnZnG4UfNr6tWrZ9e3ngBXphrwDfkkRi5k5phtm/Cv4dHSqk6gobUBgN2ulnrppZdkwoQJcu7cuYqvCEC57TiRIoM+2SjphYKNv6+nfDGik0TUJdgAcA3lmnMzY8YMOXTokNSpU0fq168vlSubN//asmVLRdUHoIx2JaTKw/M2SvrlQsHG58qITetQgg0A11GucNOvX7+KrwTADQWbh+ZtlLQiweaz4R0lMizI0NoAwCGWgjsyloLD2ew5mSYD5/0pKRezC85V9vaQz4ZHS7v67A8FwDlYfSl4vtjYWNm7d6/+datWrfR+UwBsZ++pNHmoSLDx8/aQhcM6EmwAuKxyhZvTp0/LgAEDZO3atRIUdGXIOyUlRbp37y6LFi2SmjVrVnSdAIrYn5iub0WdLxJsFgztKO3DqxlaGwA43GqpUaNGSXp6uuzevVuvmFIP1fdGDRn94x//qPgqAZg5mJQuA+f+KecysgrOVfLykPmPdJCODQg2AFxbuebcqHtev/32m3To0MHsfExMjPTs2VOP4tgr5tzA0R06fUEGzPlTki9kFpzz9XLXwaZLoxqG1gYA9vD+7V7e7ReKbpapqHNszQBYz+EzF+TBuebBxsfTXT4ZQrABgBsKN7fddps888wzcvLkyYJzCQkJeqfw22+/vTy/JYDrOKKCzZw/5Uz61WDj7eku84a0l66NCTYAcEPhRjXxU8ND4eHh0qhRI/1Qv1bnPvzww/L8lgBKcSw5Q4/YnC4SbOYObi/dmjCBHwBueLVUWFiY7kK8evXqgqXgLVq0kB49epTntwNQiuNnrwSbpLRCwcbDXT4e1E5uaUqwAYAKa+Kngo16qGXhRefZzJ8/X+wVE4rhSOLPXZT+H2+Qk6mXC855ebjpYHNb82BDawMAp2riN2nSJJk8ebK0b99eateuLW5ubuWtFUApwUatiioabGY9RLABgAoPN7Nnz5YFCxbIoEGDyvNyANdx4vxFfSsqIeVSwTlPdzeZOfAm6dGSYAMAFT6hOCsrS7p06VKelwK4jpMpl3SwOXH+arDxcHeTGQPbSs9WIYbWBgBOG25GjBghX375ZcVXA7i4U6lXgk38OfNg8+GDbeXOiNqG1gYATn1b6vLlyzJnzhzdpbhNmzbXNPSbNm1aRdUHuIyktMsycO5GOX72YsE5dzeR6f2j5O7WBBsAsGq42bFjh0RFRelfqz2lCmNyMWC502mXdYO+o8kZZsHmn/2jpE9kHUNrAwCXCDdr1qyp+EoAF3U6/bK+FXWkSLCZ9kCU9I2qa2htAOAyc24AVAy1lYK6FXX4zNVgowY/37s/Uvq1JdgAQHkQbgCDqM0vH5r3p97lu3CweefvbeTem0INrQ0AHBnhBjDAuYwseXjeRjmQdDXYKG/f20bubx9mWF0A4LJzbgCUjdrdZGt8iqzakySpl7IlsJKXdGpQTaYs3yv7iwSbqfe2lgc6EGwA4EYRbgArOZCULmO/2S47TqSanZ+19vA11775twh5sGM9G1YHAM6LcANYKdjcN2u9pF3Oue61r/dtJQ9F17dJXQDgCphzA1jhVpQasSlLsKkbVEke7kSwAYCKRLgBKpiaY1P0VlRJ1MaY2+JTrF4TALgSwg1QwdTkYUustPB6AEDpCDdABVOroqx5PQCgdIQboIKp5d7WvB4AUDrCDVDB7mgZbNH1PS28HgBQOsINUMHahgVJk1pVynRtZGigRIUFWb0mAHAlhBuggplMIu5qW+/rCPD1lHfvjxQ3taEUAKDCEG6ACvZFTJzsT0y/7ojNkie7SNNgf5vVBQCuwi7CzcyZMyU8PFx8
fX0lOjpaYmJiyvS6RYsW6Z96+/XrZ/UagbI4nXZZ3vlln9m5kd0byZO3NpKB0fX0x6VPdZHvR3Yl2ACAs26/sHjxYhkzZozMnj1bB5vp06dLr169ZP/+/VKrVq0SX3fs2DEZO3asdOvWzab1AqWZ/PMeSc+82pm4Y3g1ee6OZmW6TQUAcJKRm2nTpsmjjz4qQ4cOlZYtW+qQ4+fnJ/Pnzy/xNbm5ufLQQw/JpEmTpGHDhjatFyjJmv2n5ecdpwqOvTzc9IaYBBsAcKFwk5WVJbGxsdKjR4+rBbm76+MNGzaU+LrJkyfrUZ3hw4df98/IzMyUtLQ0swdQ0S5l5cor3+8yO/f4XxpJE249AYBrhZvk5GQ9ChMcbN7nQx0nJiYW+5p169bJJ598InPnzi3TnzF16lQJDAwseISFhVVI7UBh//f7QTlx/lLBcf3qfvL0bY0NrQkAXJXht6UskZ6eLoMGDdLBpkaNGmV6zfjx4yU1NbXgER8fb/U64Vr2JabJ3D+OmJ17o1+E+Hp5GFYTALgyQycUq4Di4eEhSUnmGweq45CQkGuuP3z4sJ5I3KdPn4JzeXl5+qOnp6eehNyoUSOz1/j4+OgHYA15eSaZ8N1OyckzFZzrG1VHujWpaWhdAODKDB258fb2lnbt2snq1avNwoo67ty58zXXN2/eXHbu3Cnbtm0reNxzzz3SvXt3/WtuOcHWvtoUJ1viUswa873cu6WhNQGAqzN8KbhaBj5kyBBp3769dOzYUS8Fz8jI0KunlMGDB0vdunX13BnVByciIsLs9UFBV1rXFz0PWNvp9MvydpGeNuPvbiE1/RkpBACXDjf9+/eXM2fOyKuvvqonEUdFRcmKFSsKJhnHxcXpFVSAvXnj572SdvlqT5v29atK//aMHgKA0dxMJrUTjutQS8HVqik1uTggIMDocuCg/jhwRgbPv9pJ29PdTZb9o5s0C2HpNwAY/f7NkAhgocvZufJykZ42j/2lIcEGAOwE4Qaw0Ie/H5S4cxcLjutV85NRtzUxtCYAwFWEG8ACB5LS5eN/m/e0eb1fhFTypqcNANgLwg1gQU+bl5aa97TpE1lHbmlKTxsAsCeEG6CMvt4cL5uOnS849vf1lFf+2sLQmgAA1yLcAGVwJj1Tpizfa3buxTubSy1/X8NqAgAUj3ADlMGby/aY9bRpWy9IBnasZ2hNAIDiEW6A61h3MFm+33bSrKfN1Htbi7u7m6F1AQCKR7gBrtvTZqfZueHdGkjzEBpAAoC9ItwApZi55pAcO3u1p01o1UryzO30tAEAe0a4AUpw6HS6zP734Wt62vh5G74lGwCgFIQboISeNhO+2yXZuVd72vRuXVu6N6tlaF0AgOsj3ADFWBJ7QmKOnSs49vfxlFf7tDS0JgBA2RBugCLOXsiUKb+Y97R54c5mEhxATxsAcASEG6CIN5fvlZSL2QXHUWFBMjC6vqE1AQDKjnADFLL+ULJ8tyWh4NjD3U2m/K21/ggAcAyEG6BQT5uXvt9ldm74zQ2kZR162gCAIyHcAP8za+1hOZqcUXBcN6iSPNuDnjYA4GgIN4DuaXNBh5vCJvdtRU8bAHBAhBu4PJPJJC8t3SlZuXkF5+6KCJHbWwQbWhcAoHwIN3B5325JkI1Hr/a0qeLjKRP7tDK0JgBA+RFu4NLOZWTJm8v2mJ0b27OphATS0wYAHBXhBi5tyvK9cr5QT5s2oYEyqHO4oTUBAG4M4QYua8Phs3qbhXyqlQ09bQDA8RFu4JIyc1RPm51m54Z2bSARdQMNqwkAUDEIN3BJs9cekSNnrva0qRPoK2PuaGpoTQCAikG4gcs5cuaCzFxzyOzcpL4RUtmHnjYA4AwIN3C5njYvf7/LrKdNr1bBckdLetoAgLMg3MClLN2aIOsPny04ruztIa/dQ08bAHAmhBu4jPMZWfLGsr1m557r2UxqB1YyrCYAQMUj3MBlvPXLPt20L1/ruoEypAs9bQDA2RBu4BI2
HjkrizfHFxzT0wYAnBfhBi7R02bCUvOeNmrEpnUoPW0AwBkRbuD05vz7iBwu1NMmJMBXz7UBADgnwg2c2rHkDPmwSE8btTpK7fwNAHBOhBs4f0+bnKs9bXq0CNZ9bQAAzotwA6f1w7aTsu5QcsGxn7eHTOrbStzcmEQMAM6McAOnlHJR9bTZY3ZO7R1VN4ieNgDg7Ag3cEpvr9gnyReu9rRpWTtAHqGnDQC4BMINnM6mY+fkq5irPW3UXaip97YWTw++3AHAFfDdHk5FTR5+qWhPm87hEhkWZFhNAADbItzAqcz9zxE5kHSh4Dg4wEee69nU0JoAALZFuIHTOH42Q/5v9UGzc6/1aSX+vl6G1QQAsD3CDZyqp01moZ42tzevJXdGhBhaFwDA9gg3cAo/7Tgl/zl4tadNJS962gCAqyLcwOGlXsyWyT+Z97QZfUcTCa3qZ1hNAADjEG7g8N7+VfW0ySw4blE7QIZ2bWBoTQAA4xBu4NBij5+XLzfGFRyru1BT/hYhXvS0AQCXxTsAHFZ2bp5M+M68p83D0fWlbb2qhtUEADAe4QYOa95/jsr+pPSC45r+PvL8nc0MrQkAYDxPowsALFnuvTU+RVbtSZKT5y/JzztPmj0/sU9LCaCnDQC4PMINHMKBpHQZ+8122XEitdjnO4RXld6ta9u8LgCA/eG2FBwi2Nw3a32JwUbZeypdDp6+uu0CAMB1EW5g97ei1IhN2uWcUq+7kJkjz3+zXV8PAHBthBvYNTXHprQRm8K2n0iVbfEpVq8JAGDfCDewa2rysCVWWng9AMD5EG5g11IvZVv1egCA8yHcwK5Zuu1lYCWWggOAq7OLcDNz5kwJDw8XX19fiY6OlpiYmBKvnTt3rnTr1k2qVq2qHz169Cj1ejiuPSfTZNnOUxa9pmfLYKvVAwBwDIaHm8WLF8uYMWNk4sSJsmXLFomMjJRevXrJ6dOni71+7dq18uCDD8qaNWtkw4YNEhYWJj179pSEhASb1w7r+fPIWen/8QZJuVj220yRoYESFRZk1boAAPbPzWTw2lk1UtOhQweZMWOGPs7Ly9OBZdSoUTJu3Ljrvj43N1eP4KjXDx48+LrXp6WlSWBgoKSmpkpAQECF/B1QsVbsSpR/LNoqWTl5ZX5NgK+nLHmyizQN9rdqbQAAY1jy/m3oyE1WVpbExsbqW0sFBbm762M1KlMWFy9elOzsbKlWrVqxz2dmZupPSOEH7NdXMXHy1BexZsHG38dT3vl7G2kTGljiiA3BBgBgF9svJCcn65GX4GDzeRLqeN++fWX6PV588UWpU6eOWUAqbOrUqTJp0qQKqRfWowYQZ/x+SN5fdcDsfI0qPrJwWAdpVSdQ7m8fqvvYqOXealWUmjys5tioW1FubpZOPQYAOCuH3lvqrbfekkWLFul5OGoycnHGjx+v5/TkUyM36rYX7Edenkle+2m3fLbhuNn5+tX95LNhHaV+9cr6WAWYtvWq6gcAAHYZbmrUqCEeHh6SlGTeeE0dh4SElPra9957T4eb3377Tdq0aVPidT4+PvoB+5SZkytjvt4uy3aYr4pqVSdAFgztKDX9+X8HALCMoXNuvL29pV27drJ69eqCc2pCsTru3Llzia9755135PXXX5cVK1ZI+/btbVQtKpraD2rYgk3XBJsujarLosc6EWwAAI55W0rdMhoyZIgOKR07dpTp06dLRkaGDB06VD+vVkDVrVtXz51R3n77bXn11Vflyy+/1L1xEhMT9fkqVaroBxxD8oVMGfrpJtmZYL5v1N2tQ+Sf/aPEx9PDsNoAAI7N8HDTv39/OXPmjA4sKqhERUXpEZn8ScZxcXF6BVW+WbNm6VVW9913n9nvo/rkvPbaazavH5aLP3dRBn2yUY6dvWh2/uFO9WTSPRHi4c7kYACAA/e5sTX63BjfdXjIpzFyJj3T7PzoHk3lH7c3ZtUTAOCG378NH7mB69h45KyMWLhZ0jNzCs6pLPN63wh5
uFN9Q2sDADgPwg1s4tfdiTLqK/Ouw94e7vLBgCi5q3VtQ2sDADgXwg2sblFMnExYulPyCt0AreLjKXMGt5MujWoYWRoAwAkRbmA1ajrXzDWH5L2V13YdXjC0g0TULX47BQAAbgThBlbrOjzpp92y8DpdhwEAqGiEG1il6/BzX2+Xn+k6DAAwAOEGFd51+InPY2XdoWSz850bVtdzbPx9vQyrDQDgGgg3sEnX4WkPRImvF12HAQDWR7hBhXUdHjw/Ro4mZ5idp+swAMDWCDe4YXtPpelgU7Tr8LM9msgztzeh6zAAwKYIN7jxrsOfbZb0y3QdBgDYB8INKrzr8PQBUXI3XYcBAAYh3KBc6DoMALBXhBtYhK7DAAB7R7iBRV2HJ/+8RxasP2Z2vl41P/l8OF2HAQD2gXCDMlHzap77Zrv8tP2k2fmWtQNkwbAOUsvf17DaAAAojHCDcncd7tSwmswZ3F4C6DoMALAjhBuU6qzqOrxgk+w4Yd51+K6IEPlnf7oOAwDsD+EGFncdfii6nkzuS9dhAIB9ItygxK7DQ+bHyOkiXYdVx2HVeZiuwwAAe0W4wTVijp6T4Qs3XdN1WI3WDKLrMADAzhFuYGbl7kR5mq7DAAAHRrhx4WZ8W+NTZNWeJEm9lC2BlbxE3WiatfawFGo6fKXr8KB20qUxXYcBAI6BcOOCDiSly9hvtl+zAqqoGlW8ZcHQjnQdBgA4FMKNCwab+2atl7RC82mKExLgK4se6yThNeg6DABwLO5GFwDb3opSIzbXCzZKtcreUr+6n03qAgCgIhFuXIiaY3O9W1H59pxKk23xKVavCQCAika4cSFq8rAlVlp4PQAA9oBw4yL2J6bLj9sSLHqNWkUFAICjYUKxkzuVekmmrTwg3245IXmF13iXgVoeDgCAoyHcOKnUi9ny0b8PyYL/HpPMQg35LNGzZXCF1wUAgLURbpzM5exc+XzDcZmx5lCxt5XUXpdlGcGJDA2UqLAg6xQJAIAVEW6cRG6eSb7fmiDTVh2QhJRL1zzv5eEmD0XXl7tbh8iIhZtLXQ4e4Osp794fyeaYAACHRLhxgt41aw+ckbd/2Sf7EtOLveaeyDryXM+mUr/6lYZ8S57sUmKHYjVio4JN02B/q9cOAIA1EG4c2Pb4FHnrl32y4cjZYp/v2ri6jLuzhbQONd8+QQWXH0Z21X1sVhbaW0rNsVG3ohixAQA4MsKNAzqWnCHvrtwvy3acKvb5lrUDZNxdzaVbkxolBhV1vm29qvoBAIAzIdw4kOQLmfLh6oPyxcY4ySlmVnDdoEoytldT6RtZV9zVzGEAAFwQ4cYBZGTmyLz/HJU5fxyWjKzca54P8vOSp7s3lkGd64uPp4chNQIAYC8IN3YsOzdPFm2Klw9+O6hHbYry9XKXYV0byBO3NpIAXxruAQCgEG7sdAXUL7sS5d1f98vR5Ixrnld3nB5oHybP9mgqIYG+htQIAIC9ItzYmY1HzsrUX/aVuCN3jxbB8uKdzaQJS7UBACgW4caONrZ8e8U++X3f6WKfv6lekIy/u4V0CK9m89oAAHAkhBuDnUy5pLsKq40tTcVsi9CwZmV58c7mugcN/WcAALg+wo2dbmxZ099HRvdoKg+0DxVPD3dDagQAwBERbgzY2PKzDcdk5prDxW5sWcXHU564paEMu7mB+HnzvwcAAEvx7mnDjS2Xqo0tV+6Xk6mXi93Y8uFO9XW/mupVfAypEQAAZ0C4qaCl21vjU2RVoX2a7mgZLG3DgvTza/ef0ZOFS9rYsm9UHXnujmZSr7qfjSsHAMD5EG5u0IGk9GJ32J619rA0rlVF/LzdZceJtGJfe3PjGnoPqIi65htbAgCA8iPc3GCwuW/Wekm7nFPs84dOXyh1Y8u/NK1p5QoBAHA9hJsbuBWlRmxKCjbFCa1aScb2bCb3RNZhY0sAAKyEcFNOao5N0VtRpRnWNVxevKs5G1sCAGBlNFApJzV52BI+Xh4EGwAA
bIBwU07F9aipyOsBAED5EG7KSS33tub1AACgfAg35aT62FhC7Q0FAACsj3BTTqpBX5vQsvWniQwNlKj/NfQDAADWRbgpJ7VD93v3R0qAb+kLztTz794fyY7eAADYCOHmBjQN9pclT3YpcQRHjdio59V1AADANuhzc4NUcPlhZFfZFp8iKwvtLaXm2KhbUYzYAADggiM3M2fOlPDwcPH19ZXo6GiJiYkp9fpvvvlGmjdvrq9v3bq1LF++XIykAkzbelXlxTuby5S/tdYf1THBBgAAFww3ixcvljFjxsjEiRNly5YtEhkZKb169ZLTp08Xe/369evlwQcflOHDh8vWrVulX79++rFr1y6b1w4AAOyPm0ltkmQgNVLToUMHmTFjhj7Oy8uTsLAwGTVqlIwbN+6a6/v37y8ZGRny888/F5zr1KmTREVFyezZs6+5PjMzUz/ypaWl6d8/NTVVAgICrPb3AgAAFUe9fwcGBpbp/dvQkZusrCyJjY2VHj16XC3I3V0fb9iwodjXqPOFr1fUSE9J10+dOlV/MvIfKtgAAADnZWi4SU5OltzcXAkONm9wp44TExOLfY06b8n148eP1ykv/xEfH1+BfwMAAGBvnH61lI+Pj34AAADXYOjITY0aNcTDw0OSksx32FbHISEhxb5GnbfkegAA4FoMDTfe3t7Srl07Wb16dcE5NaFYHXfu3LnY16jzha9XVq1aVeL1AADAtRh+W0otAx8yZIi0b99eOnbsKNOnT9eroYYOHaqfHzx4sNStW1dPDFaeeeYZueWWW+T999+X3r17y6JFi2Tz5s0yZ84cg/8mAADAHhgebtTS7jNnzsirr76qJwWrJd0rVqwomDQcFxenV1Dl69Kli3z55Zfy8ssvy4QJE6RJkyby/fffS0RERJn+vPyV72pJGQAAcAz579tl6WBjeJ8bWztx4gTLwQEAcFBq1XNoaGip17hcuFFzek6ePCn+/v4Vvj1CfoNA9Yl3lAaB1Gwb1Gwb1Gwbjlizo9ZNzVepuJKeni516tQxu6Njl7elbE19Qq6X+G6U+p/pKF+E+ajZNqjZNqjZNhyxZketm5qvUM14HWJvKQAAgIpEuAEAAE6FcFOBVCdktbu5I3VEpmbboGbboGbbcMSaHbVuai4fl5tQDAAAnBsjNwAAwKkQbgAAgFMh3AAAAKdCuAEAAE6FcGOhmTNnSnh4uPj6+kp0dLTExMSUev0333wjzZs319e3bt1ali9fLvZc8+7du+Xvf/+7vl51cFYbmRrBkprnzp0r3bp1k6pVq+pHjx49rvv/xeiav/vuO71ZbFBQkFSuXFnvqfb555+LvX8951Mb1qqvj379+ok917xgwQJdZ+GHep29f55TUlJk5MiRUrt2bb3ipGnTpjb/3mFJzbfeeus1n2f1UJsb22vNivr+1qxZM6lUqZLuqDt69Gi5fPmy2JoldWdnZ8vkyZOlUaNG+vrIyEi9H6Ot/PHHH9KnTx/dJVj9P1Z7O17P2rVr5aabbtJfy40bN9b/Lq1OrZZC2SxatMjk7e1tmj9/vmn37t2mRx991BQUFGRKSkoq9vr//ve/Jg8PD9M777xj2rNnj+nll182eXl5mXbu3Gm3NcfExJjGjh1r+uqrr0whISGmf/7znzartbw1Dxw40DRz5kzT1q1bTXv37jU98sgjpsDAQNOJEyfstuY1a9aYvvvuO/11cejQIdP06dP118qKFSvstuZ8R48eNdWtW9fUrVs3U9++fU22ZGnNn376qSkgIMB06tSpgkdiYqJd15yZmWlq37696e677zatW7dOf77Xrl1r2rZtm93WfPbsWbPP8a5du/TXs/r822vNX3zxhcnHx0d/VJ/jX3/91VS7dm3T6NGjbVZzeep+4YUXTHXq1DEtW7bMdPjwYdNHH31k8vX1NW3ZssUm9S5fvtz00ksv6e9fKkIsXbq01OuPHDli8vPzM40ZM0Z/v/vwww9t8r2OcGOBjh07mkaO
HFlwnJubq7/Ipk6dWuz1DzzwgKl3795m56Kjo02PP/64yV5rLqx+/fqGhJsbqVnJyckx+fv7mxYuXGhylJqVtm3b6gBszzWrz22XLl1M8+bNMw0ZMsTm4cbSmtWbqwq6RrK05lmzZpkaNmxoysrKMhnlRr+e1fcN9W/wwoULJnutWV172223mZ1Tb8Bdu3Y12ZKldasANmPGDLNz9957r+mhhx4y2ZqUIdyoMNaqVSuzc/379zf16tXLqrVxW6qMsrKyJDY2Vt/yKLxPlTresGFDsa9R5wtfr/Tq1avE6+2hZqNVRM0XL17UQ7fVqlUTR6hZfY9YvXq17N+/X/7yl7+IPdeshsNr1aolw4cPF1srb80XLlyQ+vXr69sOffv21bde7bnmH3/8UTp37qxvSwUHB0tERIRMmTJFcnNz7bbmoj755BMZMGCAvuVqrzV36dJFvyb/FtCRI0f0rb+7777bJjWXt+7MzMxrbq2q22rr1q0Te7TBoPdBwk0ZJScn628u6ptNYeo4MTGx2Neo85Zcbw81G60ian7xxRf1/eCi/6DsrebU1FSpUqWKeHt767kJH374odxxxx12W7P65qnetNQcJyOUp2Y1n2L+/Pnyww8/yL/+9S/Jy8vTb2onTpyw25rVm+ySJUv069Sb7SuvvCLvv/++vPHGG3Zbc2EqLOzatUtGjBghtlKemgcOHKjD+s033yxeXl56DouaOzRhwgS7rlsFg2nTpsnBgwf11/OqVav0HL5Tp06JPUos4X1Q7Rx+6dIlq/25hBs4lbfeektPdl26dKkhE0ct4e/vL9u2bZNNmzbJm2++KWPGjNET7+xRenq6DBo0SAebGjVqiKNQIyCDBw/WE7ZvueUW/SZQs2ZN+fjjj8VeqTcsNTo2Z84cadeunfTv319eeuklmT17tjgCFYDV4omOHTuKPVP/1tSI2EcffSRbtmzRXxvLli2T119/XezZBx98IE2aNNELVdQPRk8//bQMHTpUj/jgKs9Cv0Yp1Dd0Dw8PSUpKMjuvjkNCQop9jTpvyfX2ULPRbqTm9957T4eb3377Tdq0aSP2XrP6ZqRWDijqzXfv3r0ydepU/dOjvdV8+PBhOXbsmF4lUfhNWPH09NS31NRPvvZUc3HUT+ht27aVQ4cOiS2Up2a1QkrVqV6Xr0WLFvonYHUbQ72h2VvN+TIyMvQPF2pExJbKU7MaEVOBPX+ESQUyVf9jjz2mw6QtwkJ56lbhXK1QUqu6zp49q0epx40bJw0bNhR7FFLC+2BAQIC+nWYtRL0yUt9Q1E9Ram5E4W/u6lj9dFgcdb7w9YoaQizpenuo2Wjlrfmdd97RP3GpJZFqibUtVdTnWb1G3U+3x5rVT4k7d+7UI035j3vuuUe6d++uf63ms9hbzcVRtwDU30MFCFsoT81du3bV4Ss/PCoHDhzQNVs72JS35sKtL9TX8MMPPyy2VJ6a1dy8ogEmP1DaasvFG/lcq5HpunXrSk5Ojnz77bd6Ppk96mzU+6BVpys7GbVkTy0dXLBggV7S9thjj+kle/lLSwcNGmQaN26c2VJwT09P03vvvaeXKE+cONGQpeCW1KyWoaol1eqhZuWrZeHq1wcPHrTbmt966y29lHLJkiVmy1HT09PttuYpU6aYVq5cqZdyquvV14j6Wpk7d67d1lyUEaulLK150qRJeomv+jzHxsaaBgwYoJfNqiW39lpzXFycXmn09NNPm/bv32/6+eefTbVq1TK98cYbdltzvptvvlmvhDGCpTWr78fq86zaXqjlyurfY6NGjfQqV3uu+88//zR9++23+mv6jz/+0Cu+GjRoYDp//rxN6k1PTy94j1ARYtq0afrXx48f18+rWlXNRZeCP//88/p9ULXtYCm4HVJr9OvVq6ffTNUSPvWFlu+WW27R3/AL+/rrr01NmzbV16vlcKo3gT3XrPo9qC/Yog91nb3WrJasF1ez+uZlrzWrPhGNGzfWb7RVq1Y1de7c
WX+Ts/evZ6PDjaU1P/vsswXXBgcH694xtuoHUt6alfXr1+u2EepNTy0Lf/PNN/UyfHuued++ffrfnQoJRrGk5uzsbNNrr72mA436dxgWFmZ66qmnbBYSylu36nnUokUL/bVRvXp1HSQSEhJsVuuaNWuK/X6bX6P6WPT9Qr0mKipK//3U17Mt+h+5qf9Yd2wIAADAdphzAwAAnArhBgAAOBXCDQAAcCqEGwAA4FQINwAAwKkQbgAAgFMh3AAAAKdCuAEAAE6FcAMAAJwK4QaA01A7qj/77LNGlwHAYIQbAADgVNhbCoBTeOSRR2ThwoVm544ePSrh4eGG1QTAGIQbAE4hNTVV7rrrLomIiJDJkyfrczVr1hQPDw+jSwNgY562/gMBwBoCAwPF29tb/Pz8JCQkxOhyABiIOTcAAMCpEG4AAIBTIdwAcBrqtlRubq7RZQAwGOEGgNNQK6M2btwox44dk+TkZMnLyzO6JAAGINwAcBpjx47Vq6NatmypV0rFxcUZXRIAA7AUHAAAOBVGbgAAgFMh3AAAAKdCuAEAAE6FcAMAAJwK4QYAADgVwg0AAHAqhBsAAOBUCDcAAMCpEG4AAIBTIdwAAACnQrgBAADiTP4fowMfOlTCG08AAAAASUVORK5CYII=",
296 | "text/plain": [
297 | ""
298 | ]
299 | },
300 | "metadata": {},
301 | "output_type": "display_data"
302 | }
303 | ],
304 | "source": [
305 | "import pandas as pd\n",
306 | "import seaborn as sns\n",
307 | "\n",
308 | "results_df = pd.DataFrame(\n",
309 | " data=[(t, r, norm) for ((t, r), norm) in results.items()],\n",
310 | " columns=[\"t\", \"r\", \"norm\"],\n",
311 | ")\n",
312 | "sns.pointplot(results_df, x=\"t\", y=\"norm\");"
313 | ]
314 | }
315 | ],
316 | "metadata": {
317 | "kernelspec": {
318 | "display_name": ".venv",
319 | "language": "python",
320 | "name": "python3"
321 | },
322 | "language_info": {
323 | "codemirror_mode": {
324 | "name": "ipython",
325 | "version": 3
326 | },
327 | "file_extension": ".py",
328 | "mimetype": "text/x-python",
329 | "name": "python",
330 | "nbconvert_exporter": "python",
331 | "pygments_lexer": "ipython3",
332 | "version": "3.11.11"
333 | }
334 | },
335 | "nbformat": 4,
336 | "nbformat_minor": 2
337 | }
338 |
--------------------------------------------------------------------------------
/examples/lorenz63/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
2 | *.png
--------------------------------------------------------------------------------
/examples/lorenz63/main.py:
--------------------------------------------------------------------------------
1 | """Lorenz 63 example."""
2 |
3 | import functools
4 | import json
5 | import sys
6 | from collections import defaultdict
7 | from pathlib import Path
8 | from time import perf_counter
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import scipy.integrate
13 | from loguru import logger
14 | from scipy.spatial.distance import pdist
15 | from sklearn.gaussian_process.kernels import RBF
16 |
17 | import linear_operator_learning as lol
18 |
# Configure loguru: drop the default handler and install a single
# timestamped handler on stdout used by the benchmark loop below.
logger.remove()
logger.add(
    sys.stdout,  # Log to standard output
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
    colorize=True,  # Enable colorization
)

# Directory containing this script; the JSON results and the benchmark
# plot are written next to it.
this_folder = Path(__file__).parent.resolve()
28 |
29 |
30 | # Adapted from https://realpython.com/python-timer/#creating-a-python-timer-decorator
def timer(func):
    """A decorator that times the execution of a function.

    Args:
        func: Callable to be timed.

    Returns:
        Tuple containing the output of `func` and the time it took to execute.
    """

    @functools.wraps(func)
    def timed(*args, **kwargs):
        # Wall-clock timing via perf_counter (monotonic, high resolution).
        start = perf_counter()
        result = func(*args, **kwargs)
        elapsed = perf_counter() - start
        return result, elapsed

    return timed
50 |
51 |
def L63(x: np.ndarray, t: int = 1, sigma=10, mu=28, beta=8 / 3, dt=0.01):
    """Simulates the Lorenz '63 system using given initial conditions and parameters.

    Args:
        x (np.ndarray): Initial state vector of the system.
        t (int, optional): Total number of time steps for the simulation. Default is 1.
        sigma (float, optional): Prandtl number. Default is 10.
        mu (float, optional): Rayleigh number minus 1. Default is 28.
        beta (float, optional): Geometric factor. Default is 8/3.
        dt (float, optional): Time step size. Default is 0.01.

    Returns:
        np.ndarray: Array of shape (t + 1, 3) containing the state of the system
        at each evaluation time.
    """
    # Linear part of the vector field; the quadratic couplings are added in D.
    M_lin = np.array([[-sigma, sigma, 0], [mu, 0, 0], [0, 0, -beta]])

    def D(_, x):
        # Full Lorenz '63 right-hand side: linear term, minus x*z on the
        # second component, plus x*y on the third.
        dx = M_lin @ x
        dx[1] -= x[2] * x[0]
        dx[2] += x[0] * x[1]
        return dx

    # NOTE(review): linspace(0, dt*(t+1), t+1) gives an effective step of
    # dt*(t+1)/t, slightly larger than dt — confirm whether the intended
    # endpoint is dt*t (exact step dt) before changing anything.
    sim_time = dt * (t + 1)
    t_eval = np.linspace(0, sim_time, t + 1, endpoint=True)
    t_span = (0, t_eval[-1])
    sol = scipy.integrate.solve_ivp(D, t_span, x, t_eval=t_eval, method="RK45")
    return sol.y.T
79 |
80 |
def make_dataset():
    """Generates a dataset of the Lorenz '63 system, with the given parameters.

    Returns:
        dict: A dictionary containing the train and test datasets.
    """
    train_samples = 10000
    test_samples = 100
    # Simulate enough steps for train + a 1000-step gap + test.
    total_steps = train_samples + 1000 + test_samples
    raw_data = L63(np.ones(3), t=total_steps)
    # Center and rescale each coordinate by the max absolute value of the
    # raw (uncentered) trajectory.
    mean = np.mean(raw_data, axis=0)
    norm = np.max(np.abs(raw_data), axis=0)
    data = raw_data - mean
    data /= norm

    return {
        "train": data[: train_samples + 1],
        "test": data[-test_samples - 1 :],
    }
102 |
103 |
104 | def center_selection(num_pts: int, num_centers: int, rng_seed: int | None = None):
105 | """Randomly selects a specified number of center indices from a given number of points.
106 |
107 | Args:
108 | num_pts (int): Total number of available points to select from.
109 | num_centers (int): Number of center indices to select.
110 | rng_seed (int | None, optional): Seed for the random number generator. Defaults to None.
111 |
112 | Returns:
113 | ndarray: Array of selected center indices.
114 | """
115 | rng = np.random.default_rng(rng_seed)
116 | rand_indices = rng.choice(num_pts, num_centers)
117 | return rand_indices
118 |
119 |
def main():
    """Main entry point for the Lorenz 63 benchmarks.

    Trains reduced rank regression and principal component regression models
    (vanilla, Nystroem and randomized variants) with different training set
    sizes, saves the training times and root mean squared errors (rMSEs) to
    JSON, and plots them side by side.
    """
    data = make_dataset()
    train_data = data["train"]
    test_data = data["test"]
    # Length scale of the kernel: median of the pairwise distances of the dataset
    data_pdist = pdist(train_data)
    kernel = RBF(length_scale=np.quantile(data_pdist, 0.5))

    rank = 25
    num_centers = 250
    tikhonov_reg = 1e-6

    train_stops = np.logspace(3, 4, 10).astype(int)
    timings = defaultdict(list)
    rMSEs = defaultdict(list)

    def record(name: str, fit_time: float, Y_pred: np.ndarray, Y_test: np.ndarray) -> None:
        """Store fit time and prediction rMSE for estimator `name`, and log the fit time."""
        rMSE = np.sqrt(np.mean((Y_pred - Y_test) ** 2))
        timings[name].append(fit_time)
        logger.info(f"{name} {fit_time:.1e}s")
        rMSEs[name].append(rMSE.item())

    for stop in train_stops:
        logger.info(f"######## {stop} Training points ########")
        assert stop < len(train_data)
        # X/Y are one-step-ahead pairs: Y[i] is the state following X[i].
        X = train_data[:stop]
        Y = train_data[1 : stop + 1]
        kernel_X = kernel(X)
        kernel_Y = kernel(Y)
        kernel_YX = kernel(Y, X)
        X_test = test_data[:-1]
        kernel_Xtest_X = kernel(X_test, X)
        Y_test = test_data[1:]

        # Vanilla Reduced Rank Regression
        fit_results, fit_time = timer(lol.kernel.reduced_rank)(
            kernel_X, kernel_Y, tikhonov_reg, rank
        )
        Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y)
        record("Vanilla RRR", fit_time, Y_pred, Y_test)

        # Vanilla Principal Component Regression
        fit_results, fit_time = timer(lol.kernel.pcr)(kernel_X, tikhonov_reg, rank)
        Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y)
        record("Vanilla PCR", fit_time, Y_pred, Y_test)

        # Nystroem variants: fit and predict on a random subset of inducing points.
        center_idxs = center_selection(len(X), num_centers, rng_seed=42)
        fit_results, fit_time = timer(lol.kernel.nystroem_reduced_rank)(
            kernel(X[center_idxs]),
            kernel(Y[center_idxs]),
            kernel_X[:, center_idxs],
            kernel_Y[:, center_idxs],
            tikhonov_reg,
            rank,
        )
        Y_pred = lol.kernel.predict(
            1,
            fit_results,
            kernel(Y[center_idxs], X[center_idxs]),
            kernel_Xtest_X[:, center_idxs],
            Y[center_idxs],
        )
        record("Nystroem RRR", fit_time, Y_pred, Y_test)

        center_idxs = center_selection(len(X), num_centers, rng_seed=42)
        fit_results, fit_time = timer(lol.kernel.nystroem_pcr)(
            kernel(X[center_idxs]),
            kernel(Y[center_idxs]),
            kernel_X[:, center_idxs],
            kernel_Y[:, center_idxs],
            tikhonov_reg,
            rank,
        )
        Y_pred = lol.kernel.predict(
            1,
            fit_results,
            kernel(Y[center_idxs], X[center_idxs]),
            kernel_Xtest_X[:, center_idxs],
            Y[center_idxs],
        )
        record("Nystroem PCR", fit_time, Y_pred, Y_test)

        # Randomized Reduced Rank Regression
        fit_results, fit_time = timer(lol.kernel.rand_reduced_rank)(
            kernel_X, kernel_Y, tikhonov_reg, rank
        )
        Y_pred = lol.kernel.predict(1, fit_results, kernel_YX, kernel_Xtest_X, Y)
        record("Randomized RRR", fit_time, Y_pred, Y_test)

    # Save results to json
    with open(this_folder / "timings.json", "w") as f:
        json.dump(timings, f)
    with open(this_folder / "rMSEs.json", "w") as f:
        json.dump(rMSEs, f)

    # Plot results: rMSE on the left, training time (log scale) on the right.
    fig, axes = plt.subplots(ncols=2, figsize=(7, 3), dpi=144)

    for name in [
        "Vanilla RRR",
        "Vanilla PCR",
        "Nystroem RRR",
        "Nystroem PCR",
        "Randomized RRR",
    ]:
        axes[0].plot(train_stops, rMSEs[name], ".-", label=name)
        axes[1].plot(train_stops, timings[name], ".-", label=name)

    axes[0].set_title("rMSE")
    axes[1].set_title("Training time (s)")
    axes[1].legend(frameon=False, loc="upper left")
    axes[1].set_yscale("log")
    for ax in axes:
        ax.set_xscale("log")
        ax.set_xlabel("Training points")
    plt.tight_layout()
    # Save plot
    plt.savefig(this_folder / "reduced_rank_benchmarks.png")
253 |
--------------------------------------------------------------------------------
/linear_operator_learning/__init__.py:
--------------------------------------------------------------------------------
1 | """Main entry point for linear operator learning."""
2 |
3 | import sys
4 |
5 |
6 | def __getattr__(attr):
7 | """Lazy-import submodules."""
8 | submodules = {
9 | "kernel": "linear_operator_learning.kernel",
10 | "nn": "linear_operator_learning.nn",
11 | }
12 |
13 | if attr in submodules:
14 | module = __import__(submodules[attr], fromlist=[""])
15 | setattr(sys.modules[__name__], attr, module) # Attach to the module namespace
16 | return module
17 |
18 | raise AttributeError(f"Unknown submodule {attr}")
19 |
--------------------------------------------------------------------------------
/linear_operator_learning/kernel/__init__.py:
--------------------------------------------------------------------------------
1 | """Kernel methods entry point."""
2 |
3 | from linear_operator_learning.kernel.regressors import * # noqa
4 |
--------------------------------------------------------------------------------
/linear_operator_learning/kernel/linalg.py:
--------------------------------------------------------------------------------
1 | """Linear algebra utilities for the `kernel` algorithms."""
2 |
3 | from warnings import warn
4 |
5 | import numpy as np
6 | from numpy import ndarray
7 |
8 | from linear_operator_learning.kernel.utils import topk
9 |
10 |
def add_diagonal_(M: ndarray, alpha: float):
    """Add alpha to the diagonal of M inplace.

    Args:
        M (ndarray): The matrix to modify inplace.
        alpha (float): The value to add to the diagonal of M.
    """
    # Read the current diagonal, shift it, and write it back in place.
    shifted_diagonal = M.diagonal() + alpha
    np.fill_diagonal(M, shifted_diagonal)
19 |
20 |
def stable_topk(
    vec: ndarray,
    k_max: int,
    rcond: float | None = None,
    ignore_warnings: bool = True,
):
    """Takes up to k_max indices of the top k_max values of vec. If the values are below rcond, they are discarded.

    Args:
        vec (ndarray): Vector to extract the top k indices from.
        k_max (int): Number of indices to extract.
        rcond (float, optional): Value below which the values are discarded. Defaults to None, in which case it is set according to the machine precision of vec's dtype.
        ignore_warnings (bool): If False, raise a warning when some elements are discarded for being below the requested numerical precision.

    Returns:
        tuple: ``(top_values, top_indices)`` — the retained top values (at most
        ``k_max``, sorted in decreasing order) and their indices in ``vec``.
    """
    if rcond is None:
        # Default tolerance scales with the problem size and the dtype's precision.
        rcond = 10.0 * vec.shape[0] * np.finfo(vec.dtype).eps

    top_vec, top_idxs = topk(vec, k_max)

    valid = top_vec > rcond
    if valid.all():
        return top_vec, top_idxs

    # In the case of multiple occurrences of the maximum vec, the indices corresponding to the first occurrence are returned.
    # top_vec is sorted in decreasing order, so the first invalid entry is the
    # largest discarded value. (Previously this indexed the unsorted input
    # `vec`, and the count used `k_max - vec.shape[0]`, which is wrong.)
    first_invalid = np.argmax(np.logical_not(valid))
    num_discarded = top_vec.shape[0] - int(valid.sum())
    _first_discarded_val = np.abs(top_vec[first_invalid])

    if not ignore_warnings:
        warn(
            f"Warning: Discarded {num_discarded} dimensions of the {k_max} requested due to numerical instability. Consider decreasing the k. The largest discarded value is: {_first_discarded_val:.3e}."
        )
    return top_vec[valid], top_idxs[valid]
54 |
55 |
56 | def weighted_norm(A: ndarray, M: ndarray | None = None):
57 | r"""Weighted norm of the columns of A.
58 |
59 | Args:
60 | A (ndarray): 1D or 2D array. If 2D, the columns are treated as vectors.
61 | M (ndarray or LinearOperator, optional): Weigthing matrix. the norm of the vector :math:`a` is given by :math:`\langle a, Ma \rangle` . Defaults to None, corresponding to the Identity matrix. Warning: no checks are
62 | performed on M being a PSD operator.
63 |
64 | Returns:
65 | (ndarray or float): If ``A.ndim == 2`` returns 1D array of floats corresponding to the norms of
66 | the columns of A. Else return a float.
67 | """
68 | assert A.ndim <= 2, "'A' must be a vector or a 2D array"
69 | if M is None:
70 | norm = np.linalg.norm(A, axis=0)
71 | else:
72 | _A = np.dot(M, A)
73 | _A_T = np.dot(M.T, A)
74 | norm = np.real(np.sum(0.5 * (np.conj(A) * _A + np.conj(A) * _A_T), axis=0))
75 | rcond = 10.0 * A.shape[0] * np.finfo(A.dtype).eps
76 | norm = np.where(norm < rcond, 0.0, norm)
77 | return np.sqrt(norm)
78 |
--------------------------------------------------------------------------------
/linear_operator_learning/kernel/regressors.py:
--------------------------------------------------------------------------------
1 | """Kernel-based regressors for linear operators."""
2 |
3 | from math import sqrt
4 | from typing import Literal
5 | from warnings import warn
6 |
7 | import numpy as np
8 | import scipy.linalg
9 | from numpy import ndarray
10 | from scipy.sparse.linalg import eigs, eigsh
11 |
12 | from linear_operator_learning.kernel.linalg import (
13 | add_diagonal_,
14 | stable_topk,
15 | weighted_norm,
16 | )
17 | from linear_operator_learning.kernel.structs import EigResult, FitResult
18 | from linear_operator_learning.kernel.utils import sanitize_complex_conjugates
19 |
20 | __all__ = [
21 | "predict",
22 | "eig",
23 | "evaluate_eigenfunction",
24 | "pcr",
25 | "nystroem_pcr",
26 | "reduced_rank",
27 | "nystroem_reduced_rank",
28 | "rand_reduced_rank",
29 | ]
30 |
31 |
def eig(
    fit_result: FitResult,
    K_X: ndarray,  # Kernel matrix of the input data
    K_YX: ndarray,  # Kernel matrix between the output data and the input data
) -> EigResult:
    """Computes the eigendecomposition of a regressor.

    Args:
        fit_result (FitResult): Fit result as defined in ``linear_operator_learning.kernel.structs``.
        K_X (ndarray): Kernel matrix of the input data.
        K_YX (ndarray): Kernel matrix between the output data and the input data.


    Shape:
        ``K_X``: :math:`(N, N)`, where :math:`N` is the sample size.

        ``K_YX``: :math:`(N, N)`, where :math:`N` is the sample size.

        Output: ``U, V`` of shape :math:`(N, R)`, ``svals`` of shape :math:`R`
        where :math:`N` is the sample size and :math:`R` is the rank of the regressor.
    """
    # SUV.TZ -> V.T K_YX U (right ev = SUvr, left ev = ZVvl)
    U = fit_result["U"]
    V = fit_result["V"]
    # 1/N normalization applied to both reduced matrices below.
    r_dim = (K_X.shape[0]) ** (-1)

    # Reduced (R x R) matrices of the operator in the fitted subspace.
    W_YX = np.linalg.multi_dot([V.T, r_dim * K_YX, U])
    W_X = np.linalg.multi_dot([U.T, r_dim * K_X, U])

    values, vl, vr = scipy.linalg.eig(W_YX, left=True, right=True)  # Left -> V, Right -> U
    values = sanitize_complex_conjugates(values)
    # Sort right eigenvectors by eigenvalue; left ones by the conjugate values
    # so that conjugate pairs line up between the two sets.
    r_perm = np.argsort(values)
    vr = vr[:, r_perm]
    l_perm = np.argsort(values.conj())
    vl = vl[:, l_perm]
    values = values[r_perm]

    rcond = 1000.0 * np.finfo(U.dtype).eps
    # Normalization in RKHS
    norm_r = weighted_norm(vr, W_X)
    # Norms below rcond are replaced by inf so the division zeroes-out the
    # corresponding (numerically degenerate) eigenvectors.
    norm_r = np.where(norm_r < rcond, np.inf, norm_r)
    vr = vr / norm_r

    # Bi-orthogonality of left eigenfunctions
    norm_l = np.diag(np.linalg.multi_dot([vl.T, W_YX, vr]))
    norm_l = np.where(np.abs(norm_l) < rcond, np.inf, norm_l)
    vl = vl / norm_l
    # Map reduced-space eigenvectors back to sample space through V and U.
    result: EigResult = {"values": values, "left": V @ vl, "right": U @ vr}
    return result
81 |
82 |
def evaluate_eigenfunction(
    eig_result: EigResult,
    which: Literal["left", "right"],
    K_Xin_X_or_Y: ndarray,
):
    """Evaluates left or right eigenfunctions of a regressor.

    Args:
        eig_result: EigResult object containing eigendecomposition results
        which: String indicating "left" or "right" eigenfunctions
        K_Xin_X_or_Y: Kernel matrix between initial conditions and input data (for right
            eigenfunctions) or output data (for left eigenfunctions)


    Shape:
        ``eig_result``: ``U, V`` of shape :math:`(N, R)`, ``svals`` of shape :math:`R`
        where :math:`N` is the sample size and :math:`R` is the rank of the regressor.

        ``K_Xin_X_or_Y``: :math:`(N_0, N)`, where :math:`N_0` is the number of inputs to
        predict and :math:`N` is the sample size.

        Output: :math:`(N_0, R)`
    """
    eigenvectors = eig_result[which]
    # Rescale by 1/sqrt(N) to match the training-time normalization.
    num_train = K_Xin_X_or_Y.shape[1]
    return (K_Xin_X_or_Y @ eigenvectors) * (num_train ** (-0.5))
109 |
110 |
def predict(
    num_steps: int,
    fit_result: FitResult,
    kernel_YX: ndarray,
    kernel_Xin_X: ndarray,
    obs_train_Y: ndarray,
) -> ndarray:
    """Predicts future states given initial values using a fitted regressor.

    Args:
        num_steps (int): Number of steps to predict forward (returns the last prediction)
        fit_result (FitResult): FitResult object containing fitted U and V matrices
        kernel_YX (ndarray): Kernel matrix between output data and input data (or inducing points for Nystroem)
        kernel_Xin_X (ndarray): Kernel matrix between initial conditions and input data (or inducing points for Nystroem)
        obs_train_Y (ndarray): Observable evaluated on output training data (or inducing points for Nystroem)

    Shape:
        ``kernel_YX``: :math:`(N, N)`, where :math:`N` is the number of training data, or inducing points for Nystroem.

        ``kernel_Xin_X``: :math:`(N_0, N)`, where :math:`N_0` is the number of inputs to predict.

        ``obs_train_Y``: :math:`(N, *)`, where :math:`*` is the shape of the observable.

        Output: :math:`(N_0, *)`.
    """
    # The regressor is G = S U V.T Z, so the n-step operator factors as
    # G^n = (S U) (V.T K_YX U / N)^(n-1) (V.T Z).
    U = fit_result["U"]
    V = fit_result["V"]
    num_train = U.shape[0]
    sqrt_n = sqrt(num_train)

    left_factor = (kernel_Xin_X @ U) / sqrt_n
    right_factor = (V.T @ obs_train_Y) / sqrt_n
    transfer = np.linalg.multi_dot([V.T, kernel_YX, U]) / num_train
    propagated = np.linalg.matrix_power(transfer, num_steps - 1)
    return np.linalg.multi_dot([left_factor, propagated, right_factor])
146 |
147 |
def pcr(
    kernel_X: ndarray,
    tikhonov_reg: float = 0.0,
    rank: int | None = None,
    svd_solver: Literal["arnoldi", "full"] = "arnoldi",
) -> FitResult:
    """Fits the Principal Components estimator.

    Args:
        kernel_X (ndarray): Kernel matrix of the input data.
        tikhonov_reg (float, optional): Tikhonov (ridge) regularization parameter. Defaults to 0.0.
        rank (int | None, optional): Rank of the estimator. Defaults to None, meaning full rank.
        svd_solver (Literal[ "arnoldi", "full" ], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi".

    Shape:
        ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data.
    """
    npts = kernel_X.shape[0]
    # The documented default rank=None previously crashed on the default
    # "arnoldi" path (None + 5); treat None as full rank.
    if rank is None:
        rank = npts
    add_diagonal_(kernel_X, npts * tikhonov_reg)
    if svd_solver == "arnoldi":
        # Oversample by 5 for stability; eigsh requires k < n, hence npts - 1.
        _num_arnoldi_eigs = min(rank + 5, npts - 1)
        values, vectors = eigsh(kernel_X, k=_num_arnoldi_eigs)
    elif svd_solver == "full":
        values, vectors = scipy.linalg.eigh(kernel_X)
    else:
        raise ValueError(f"Unknown svd_solver {svd_solver}")
    # Undo the in-place regularization shift so the caller's matrix is unchanged.
    add_diagonal_(kernel_X, -npts * tikhonov_reg)

    # Keep only numerically stable top-`rank` eigenpairs.
    values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False)
    vectors = vectors[:, stable_values_idxs]
    Q = sqrt(npts) * vectors / np.sqrt(values)
    kernel_X_eigvalsh = np.sqrt(np.abs(values)) / npts
    result: FitResult = {"U": Q, "V": Q, "svals": kernel_X_eigvalsh}
    return result
182 |
183 |
def nystroem_pcr(
    kernel_X: ndarray,  # Kernel matrix of the input inducing points
    kernel_Y: ndarray,  # Kernel matrix of the output inducing points
    kernel_Xnys: ndarray,  # Kernel matrix between the input data and the input inducing points
    kernel_Ynys: ndarray,  # Kernel matrix between the output data and the output inducing points
    tikhonov_reg: float = 0.0,  # Tikhonov (ridge) regularization parameter (can be 0)
    rank: int | None = None,  # Rank of the estimator
    svd_solver: Literal["arnoldi", "full"] = "arnoldi",
) -> FitResult:
    """Fits the Principal Components estimator using the Nyström method from :footcite:t:`Meanti2023`.

    Args:
        kernel_X (ndarray): Kernel matrix of the input inducing points. Temporarily
            modified in place (regularized diagonal) but restored before returning.
        kernel_Y (ndarray): Kernel matrix of the output inducing points.
        kernel_Xnys (ndarray): Kernel matrix between the input data and the input inducing points.
        kernel_Ynys (ndarray): Kernel matrix between the output data and the output inducing points.
        tikhonov_reg (float, optional): Tikhonov (ridge) regularization parameter. Defaults to 0.0.
        rank (int | None, optional): Rank of the estimator. If None, the number of
            inducing points is used. Defaults to None.
        svd_solver (Literal["arnoldi", "full"], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi".

    Shape:
        ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data.

        ``kernel_Y``: :math:`(N, N)`.

        ``kernel_Xnys``: :math:`(N, M)`, where :math:`M` is the number of Nystroem centers (inducing points).

        ``kernel_Ynys``: :math:`(N, M)`.
    """
    ncenters = kernel_X.shape[0]
    npts = kernel_Xnys.shape[0]
    # Bug fix: the documented default ``rank=None`` crashed below
    # (``np.sqrt(rank)`` / ``rank + _oversampling``); interpret None as full rank.
    if rank is None:
        rank = ncenters
    # Floor the regularization at a dtype-dependent epsilon to keep the
    # generalized eigenproblem well-posed even when tikhonov_reg == 0.
    eps = 1000 * np.finfo(kernel_X.dtype).eps
    reg = max(eps, tikhonov_reg)
    kernel_Xnys_sq = kernel_Xnys.T @ kernel_Xnys
    add_diagonal_(kernel_X, reg * ncenters)
    if svd_solver == "full":
        values, vectors = scipy.linalg.eigh(
            kernel_Xnys_sq, kernel_X
        )  # normalization leads to needing to invert evals
    elif svd_solver == "arnoldi":
        _oversampling = max(10, 4 * int(np.sqrt(rank)))
        _num_arnoldi_eigs = min(rank + _oversampling, ncenters)
        values, vectors = eigsh(
            kernel_Xnys_sq,
            M=kernel_X,
            k=_num_arnoldi_eigs,
            which="LM",
        )
    else:
        raise ValueError(f"Unknown svd_solver {svd_solver}")
    # Undo the in-place regularization so the caller's matrix is unchanged.
    add_diagonal_(kernel_X, -reg * ncenters)

    values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False)
    vectors = vectors[:, stable_values_idxs]

    U = sqrt(ncenters) * vectors / np.sqrt(values)
    # V maps through the output inducing points; lstsq instead of an explicit
    # inverse of kernel_Y for numerical stability.
    V = np.linalg.multi_dot([kernel_Ynys.T, kernel_Xnys, vectors])
    V = scipy.linalg.lstsq(kernel_Y, V)[0]
    V = sqrt(ncenters) * V / np.sqrt(values)

    kernel_X_eigvalsh = np.sqrt(np.abs(values)) / npts
    result: FitResult = {"U": U, "V": V, "svals": kernel_X_eigvalsh}
    return result
247 |
248 |
def reduced_rank(
    kernel_X: ndarray,  # Kernel matrix of the input data
    kernel_Y: ndarray,  # Kernel matrix of the output data
    tikhonov_reg: float,  # Tikhonov (ridge) regularization parameter, can be 0
    rank: int,  # Rank of the estimator
    svd_solver: Literal["arnoldi", "full"] = "arnoldi",
) -> FitResult:
    """Fits the Reduced Rank estimator from :footcite:t:`Kostic2022`.

    Args:
        kernel_X (ndarray): Kernel matrix of the input data. Temporarily modified
            in place (regularized diagonal).
        kernel_Y (ndarray): Kernel matrix of the output data.
        tikhonov_reg (float): Tikhonov (ridge) regularization parameter.
        rank (int): Rank of the estimator.
        svd_solver (Literal["arnoldi", "full"], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi".

    Shape:
        ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data.

        ``kernel_Y``: :math:`(N, N)`.
    """
    # Number of data points
    npts = kernel_X.shape[0]
    # Floor the regularization at a dtype-dependent epsilon so the penalized
    # problem is well-posed even when tikhonov_reg == 0.
    eps = 1000.0 * np.finfo(kernel_X.dtype).eps
    penalty = max(eps, tikhonov_reg) * npts

    A = (kernel_Y / sqrt(npts)) @ (kernel_X / sqrt(npts))
    add_diagonal_(kernel_X, penalty)
    # Find U via Generalized eigenvalue problem equivalent to the SVD. If K is ill-conditioned might be slow.
    # Prefer svd_solver == 'randomized' in such a case.
    if svd_solver == "arnoldi":
        # Adding a small buffer to the Arnoldi-computed eigenvalues.
        num_arnoldi_eigs = min(rank + 5, npts)
        values, vectors = eigs(A, k=num_arnoldi_eigs, M=kernel_X)
    elif svd_solver == "full":  # 'full'
        # NOTE(review): overwrite_b=True lets LAPACK clobber kernel_X, so the
        # diagonal restore below and `V = kernel_X @ U` may operate on a
        # modified matrix — confirm this is intended.
        values, vectors = scipy.linalg.eig(A, kernel_X, overwrite_a=True, overwrite_b=True)
    else:
        raise ValueError(f"Unknown svd_solver: {svd_solver}")
    # Remove the penalty from kernel_X (inplace)
    add_diagonal_(kernel_X, -penalty)

    values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False)
    vectors = vectors[:, stable_values_idxs]
    # Warn if any surviving squared singular value falls below the regularization strength.
    if not np.all(np.abs(values) >= tikhonov_reg):
        warn(
            f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider reducing the regularization strength to avoid overfitting."
        )

    # Eigenvector normalization
    kernel_X_vecs = np.dot(kernel_X / sqrt(npts), vectors)
    vecs_norm = np.sqrt(
        np.sum(
            kernel_X_vecs**2 + tikhonov_reg * kernel_X_vecs * vectors * sqrt(npts),
            axis=0,
        )
    )

    norm_rcond = 1000.0 * np.finfo(values.dtype).eps
    # NOTE(review): this call overwrites `values` with the top-k entries of
    # `vecs_norm`, so the returned `svals` are derived from eigenvector norms,
    # not from the generalized eigenvalues — confirm this is intended.
    values, stable_values_idxs = stable_topk(vecs_norm, rank, rcond=norm_rcond)
    U = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs]

    # Ordering the results
    V = kernel_X @ U
    svals = np.sqrt(np.abs(values))
    result: FitResult = {"U": U.real, "V": V.real, "svals": svals}
    return result
316 |
317 |
def nystroem_reduced_rank(
    kernel_X: ndarray,  # Kernel matrix of the input inducing points
    kernel_Y: ndarray,  # Kernel matrix of the output inducing points
    kernel_Xnys: ndarray,  # Kernel matrix between the input data and the input inducing points
    kernel_Ynys: ndarray,  # Kernel matrix between the output data and the output inducing points
    tikhonov_reg: float,  # Tikhonov (ridge) regularization parameter
    rank: int,  # Rank of the estimator
    svd_solver: Literal["arnoldi", "full"] = "arnoldi",
) -> FitResult:
    """Fits the Nyström Reduced Rank estimator from :footcite:t:`Meanti2023`.

    Args:
        kernel_X (ndarray): Kernel matrix of the input inducing points.
        kernel_Y (ndarray): Kernel matrix of the output inducing points.
        kernel_Xnys (ndarray): Kernel matrix between the input data and the input inducing points.
        kernel_Ynys (ndarray): Kernel matrix between the output data and the output inducing points.
        tikhonov_reg (float): Tikhonov (ridge) regularization parameter.
        rank (int): Rank of the estimator.
        svd_solver (Literal["arnoldi", "full"], optional): Solver for the generalized eigenvalue problem. Defaults to "arnoldi".

    Shape:
        ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data.

        ``kernel_Y``: :math:`(N, N)`.

        ``kernel_Xnys``: :math:`(N, M)`, where :math:`M` is the number of Nystroem centers (inducing points).

        ``kernel_Ynys``: :math:`(N, M)`.
    """
    num_points = kernel_Xnys.shape[0]
    num_centers = kernel_X.shape[0]

    # Floor the regularization at a dtype- and size-dependent epsilon.
    eps = 1000 * np.finfo(kernel_X.dtype).eps * num_centers
    reg = max(eps, tikhonov_reg)

    # LHS of the generalized eigenvalue problem
    sqrt_Mn = sqrt(num_centers * num_points)
    kernel_YX_nys = (kernel_Ynys.T / sqrt_Mn) @ (kernel_Xnys / sqrt_Mn)

    # lstsq instead of an explicit inverse of kernel_Y for numerical stability.
    _tmp_YX = scipy.linalg.lstsq(kernel_Y * (num_centers**-1), kernel_YX_nys)[0]
    kernel_XYX = kernel_YX_nys.T @ _tmp_YX

    # RHS of the generalized eigenvalue problem
    kernel_Xnys_sq = (kernel_Xnys.T / sqrt_Mn) @ (kernel_Xnys / sqrt_Mn) + reg * kernel_X * (
        num_centers**-1
    )

    # Small diagonal shift to stabilize the solves below; removed afterwards.
    add_diagonal_(kernel_Xnys_sq, eps)
    A = scipy.linalg.lstsq(kernel_Xnys_sq, kernel_XYX)[0]
    if svd_solver == "full":
        values, vectors = scipy.linalg.eigh(
            kernel_XYX, kernel_Xnys_sq
        )  # normalization leads to needing to invert evals
    elif svd_solver == "arnoldi":
        _oversampling = max(10, 4 * int(np.sqrt(rank)))
        _num_arnoldi_eigs = min(rank + _oversampling, num_centers)
        values, vectors = eigs(kernel_XYX, k=_num_arnoldi_eigs, M=kernel_Xnys_sq)
    else:
        raise ValueError(f"Unknown svd_solver {svd_solver}")
    add_diagonal_(kernel_Xnys_sq, -eps)

    values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False)
    vectors = vectors[:, stable_values_idxs]
    # Warn if any surviving squared singular value falls below the regularization strength.
    if not np.all(np.abs(values) >= tikhonov_reg):
        warn(
            f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider reducing the regularization strength to avoid overfitting."
        )
    # Eigenvector normalization
    vecs_norm = np.sqrt(np.abs(np.sum(vectors.conj() * (kernel_XYX @ vectors), axis=0)))
    norm_rcond = 1000.0 * np.finfo(values.dtype).eps
    # NOTE(review): this call overwrites `values` with the top-k entries of
    # `vecs_norm`, so the returned `svals` are derived from eigenvector norms —
    # confirm this is intended.
    values, stable_values_idxs = stable_topk(vecs_norm, rank, rcond=norm_rcond)
    vectors = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs]
    U = A @ vectors
    V = _tmp_YX @ vectors
    svals = np.sqrt(np.abs(values))
    result: FitResult = {"U": U.real, "V": V.real, "svals": svals}
    return result
396 |
397 |
def rand_reduced_rank(
    kernel_X: ndarray,
    kernel_Y: ndarray,
    tikhonov_reg: float,
    rank: int,
    n_oversamples: int = 5,
    optimal_sketching: bool = False,
    iterated_power: int = 1,
    rng_seed: int | None = None,
    precomputed_cholesky=None,
) -> FitResult:
    """Fits the Randomized Reduced Rank Estimator from :footcite:t:`Turri2023`.

    Args:
        kernel_X (ndarray): Kernel matrix of the input data. Temporarily modified in place (regularized diagonal), restored before returning.
        kernel_Y (ndarray): Kernel matrix of the output data.
        tikhonov_reg (float): Tikhonov (ridge) regularization parameter.
        rank (int): Rank of the estimator.
        n_oversamples (int, optional): Number of oversamples for the sketch. Defaults to 5.
        optimal_sketching (bool, optional): Whether to use optimal sketching (slower but more accurate) or not. Defaults to False.
        iterated_power (int, optional): Number of iterations of the power method. Defaults to 1.
        rng_seed (int | None, optional): Random number generator seed. Defaults to None.
        precomputed_cholesky (optional): Precomputed Cholesky decomposition. Should be the output of cho_factor evaluated on the regularized kernel matrix. Defaults to None.

    Shape:
        ``kernel_X``: :math:`(N, N)`, where :math:`N` is the number of training data.

        ``kernel_Y``: :math:`(N, N)`.
    """
    rng = np.random.default_rng(rng_seed)
    npts = kernel_X.shape[0]

    penalty = npts * tikhonov_reg
    # Regularize kernel_X in place, factorize it, then restore it below.
    add_diagonal_(kernel_X, penalty)
    if precomputed_cholesky is None:
        cholesky_decomposition = scipy.linalg.cho_factor(kernel_X)
    else:
        # Assumed to be cho_factor output for the regularized kernel_X — not validated here.
        cholesky_decomposition = precomputed_cholesky
    add_diagonal_(kernel_X, -penalty)

    sketch_dimension = rank + n_oversamples

    if optimal_sketching:
        # Sample sketch columns from N(0, kernel_Y / npts) instead of standard normals.
        cov = kernel_Y / npts
        sketch = rng.multivariate_normal(
            np.zeros(npts, dtype=kernel_Y.dtype), cov, size=sketch_dimension
        ).T
    else:
        sketch = rng.standard_normal(size=(npts, sketch_dimension))

    for _ in range(iterated_power):
        # Powered randomized rangefinder
        sketch = (kernel_Y / npts) @ (
            sketch - penalty * scipy.linalg.cho_solve(cholesky_decomposition, sketch)
        )
        sketch, _ = scipy.linalg.qr(sketch, mode="economic")  # QR re-orthogonalization

    kernel_X_sketch = scipy.linalg.cho_solve(cholesky_decomposition, sketch)
    _M = sketch - penalty * kernel_X_sketch

    F_0 = sketch.T @ sketch - penalty * (sketch.T @ kernel_X_sketch)  # Symmetric
    # Symmetrize explicitly to cancel round-off asymmetry.
    F_0 = 0.5 * (F_0 + F_0.T)
    F_1 = _M.T @ ((kernel_Y @ _M) / npts)

    # Solve the (small) sketched eigenproblem F_0^+ F_1 v = lambda v.
    values, vectors = scipy.linalg.eig(scipy.linalg.lstsq(F_0, F_1)[0])
    values, stable_values_idxs = stable_topk(values, rank, ignore_warnings=False)
    vectors = vectors[:, stable_values_idxs]

    # Remove elements in the kernel of F_0
    relative_norm_sq = np.abs(
        np.sum(vectors.conj() * (F_0 @ vectors), axis=0) / np.linalg.norm(vectors, axis=0) ** 2
    )
    norm_rcond = 1000.0 * np.finfo(values.dtype).eps
    # NOTE(review): this call overwrites `values` with the top-k entries of
    # `relative_norm_sq`, so `svals` below is derived from those norms rather
    # than the eigenvalues computed above — confirm this is intended.
    values, stable_values_idxs = stable_topk(relative_norm_sq, rank, rcond=norm_rcond)
    vectors = vectors[:, stable_values_idxs]

    # Normalize eigenvectors in the F_0 inner product.
    vecs_norms = (np.sum(vectors.conj() * (F_0 @ vectors), axis=0).real) ** 0.5
    vectors = vectors / vecs_norms

    U = sqrt(npts) * kernel_X_sketch @ vectors
    V = sqrt(npts) * _M @ vectors
    svals = np.sqrt(values)
    result: FitResult = {"U": U, "V": V, "svals": svals}
    return result
482 |
--------------------------------------------------------------------------------
/linear_operator_learning/kernel/structs.py:
--------------------------------------------------------------------------------
1 | """Structs used by the `kernel` algorithms."""
2 |
3 | from typing import TypedDict
4 |
5 | import numpy as np
6 | from numpy import ndarray
7 |
8 |
class FitResult(TypedDict):
    """Return type for kernel regressors."""

    # Coefficients applied on the input-kernel side of the fitted estimator.
    U: ndarray
    # Coefficients applied on the output-kernel side of the fitted estimator.
    V: ndarray
    # Estimated singular values; None when the regressor does not provide them.
    svals: ndarray | None
16 |
class EigResult(TypedDict):
    """Return type for eigenvalue decompositions of kernel regressors."""

    # Eigenvalues of the decomposition.
    values: ndarray
    # Left eigenvectors; None when not computed.
    left: ndarray | None
    # Right eigenvectors.
    right: ndarray
23 |
--------------------------------------------------------------------------------
/linear_operator_learning/kernel/utils.py:
--------------------------------------------------------------------------------
1 | """Generic Utilities."""
2 |
3 | from math import sqrt
4 |
5 | import numpy as np
6 | from numpy import ndarray
7 | from scipy.spatial.distance import pdist
8 |
9 |
def topk(vec: ndarray, k: int):
    """Get the top k values from a Numpy array.

    Args:
        vec (ndarray): A 1D numpy array.
        k (int): Number of elements to keep.

    Returns:
        values, indices: The k largest values (descending) and their indices.
    """
    assert np.ndim(vec) == 1, "'vec' must be a 1D array"
    assert k > 0, "k should be greater than 0"
    # Ascending argsort, reversed, gives a descending ordering of the entries.
    order_desc = np.flip(np.argsort(vec))
    top_idxs = order_desc[:k]
    return vec[top_idxs], top_idxs
26 |
27 |
def sanitize_complex_conjugates(vec: ndarray, tol: float = 10.0):
    """Sanitize a 1D complex vector for numerical noise.

    If the real parts of two elements are closer than ``tol`` times the machine
    epsilon, both are replaced by their average. Imaginary parts smaller (in
    absolute value) than the same threshold are set to 0.

    Args:
        vec (ndarray): A 1D complex vector to sanitize.
        tol (float, optional): Tolerance for comparisons, in units of machine epsilon. Defaults to 10.0.

    Returns:
        ndarray: The sanitized vector (the input is not modified).
    """
    assert issubclass(vec.dtype.type, np.complexfloating), "The input element should be complex"
    assert vec.ndim == 1
    # Absolute tolerance: `tol` units of machine epsilon for this dtype.
    rcond = tol * np.finfo(vec.dtype).eps
    # Pairwise distances between real parts, in scipy's condensed form.
    pdist_real_part = pdist(vec.real[:, None])
    # Set the same element whenever pdist is smaller than eps*tol
    condensed_idxs = np.argwhere(pdist_real_part < rcond)[:, 0]
    fuzzy_real = vec.real.copy()
    if condensed_idxs.shape[0] >= 1:
        for idx in condensed_idxs:
            # Map the condensed pdist index back to the (i, j) matrix entry.
            i, j = _row_col_from_condensed_index(vec.real.shape[0], idx)
            # NOTE(review): pairs are averaged sequentially, so with more than
            # two mutually-close elements the result depends on pair order.
            avg = 0.5 * (fuzzy_real[i] + fuzzy_real[j])
            fuzzy_real[i] = avg
            fuzzy_real[j] = avg
    fuzzy_imag = vec.imag.copy()
    fuzzy_imag[np.abs(fuzzy_imag) < rcond] = 0.0
    return fuzzy_real + 1j * fuzzy_imag
52 |
53 |
54 | def _row_col_from_condensed_index(d, index):
55 | # Credits to: https://stackoverflow.com/a/14839010
56 | b = 1 - (2 * d)
57 | i = (-b - sqrt(b**2 - 8 * index)) // 2
58 | j = index + i * (b + i + 2) // 2 + 1
59 | return (int(i), int(j))
60 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Neural network methods entry point."""
2 |
3 | import linear_operator_learning.nn.functional as functional
4 | import linear_operator_learning.nn.stats as stats
5 | from linear_operator_learning.nn.modules import * # noqa: F403
6 | from linear_operator_learning.nn.regressors import eig, evaluate_eigenfunction, ridge_least_squares
7 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/functional.py:
--------------------------------------------------------------------------------
1 | """Functional interface."""
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from linear_operator_learning.nn.linalg import sqrtmh
7 | from linear_operator_learning.nn.stats import cov_norm_squared_unbiased, covariance
8 |
9 | # Losses_____________________________________________________________________________________________
10 |
11 |
def vamp_loss(
    x: Tensor, y: Tensor, schatten_norm: int = 2, center_covariances: bool = True
) -> Tensor:
    """See :class:`linear_operator_learning.nn.VampLoss` for details."""
    cov_x = covariance(x, center=center_covariances)
    cov_y = covariance(y, center=center_covariances)
    cov_xy = covariance(x, y, center=center_covariances)
    if schatten_norm == 1:
        # VAMP-1: nuclear norm of the whitened cross-covariance.
        inv_sqrt_x = torch.linalg.pinv(sqrtmh(cov_x), hermitian=True)
        inv_sqrt_y = torch.linalg.pinv(sqrtmh(cov_y), hermitian=True)
        whitened = torch.linalg.multi_dot([inv_sqrt_x, cov_xy, inv_sqrt_y])
        return -torch.linalg.matrix_norm(whitened, "nuc")
    if schatten_norm == 2:
        # VAMP-2: least squares in place of pinv for numerical stability.
        M_x = torch.linalg.lstsq(cov_x, cov_xy).solution
        M_y = torch.linalg.lstsq(cov_y, cov_xy.T).solution
        return -torch.trace(M_x @ M_y)
    raise NotImplementedError(f"Schatten norm {schatten_norm} not implemented")
39 |
40 |
def dp_loss(
    x: Tensor,
    y: Tensor,
    relaxed: bool = True,
    center_covariances: bool = True,
) -> Tensor:
    """See :class:`linear_operator_learning.nn.DPLoss` for details."""
    cov_x = covariance(x, center=center_covariances)
    cov_y = covariance(y, center=center_covariances)
    cov_xy = covariance(x, y, center=center_covariances)
    if relaxed:
        # Relaxed score: squared Frobenius norm of the cross-covariance,
        # normalized by the spectral norms of the marginal covariances.
        numerator = torch.linalg.matrix_norm(cov_xy, ord="fro") ** 2
        denominator = torch.linalg.matrix_norm(cov_x, ord=2) * torch.linalg.matrix_norm(
            cov_y, ord=2
        )
        score = numerator / denominator
    else:
        # Exact score: least squares in place of pinv for numerical stability.
        M_x = torch.linalg.lstsq(cov_x, cov_xy).solution
        M_y = torch.linalg.lstsq(cov_y, cov_xy.T).solution
        score = torch.trace(M_x @ M_y)
    return -score
62 |
63 |
def l2_contrastive_loss(x: Tensor, y: Tensor) -> Tensor:
    """See :class:`linear_operator_learning.nn.L2ContrastiveLoss` for details."""
    assert x.shape == y.shape
    assert x.ndim == 2

    n_samples, n_feats = x.shape
    # Positive (paired) term: mean inner product of matched rows.
    positive = 2 * torch.mean(x * y) * n_feats
    # Negative term: squared inner products of mismatched rows only.
    gram_sq = torch.matmul(x, y.T) ** 2
    off_diag_only = torch.triu(gram_sq, diagonal=1) + torch.tril(gram_sq, diagonal=-1)
    negative = torch.mean(off_diag_only) * n_samples / (n_samples - 1)
    return negative - positive
78 |
79 |
def kl_contrastive_loss(X: Tensor, Y: Tensor) -> Tensor:
    """See :class:`linear_operator_learning.nn.KLContrastiveLoss` for details."""
    assert X.shape == Y.shape
    assert X.ndim == 2

    n_samples, n_feats = X.shape
    # Positive (paired) term: mean log inner product of matched rows.
    positive = torch.mean(torch.log(X * Y)) * n_feats
    # Negative term: inner products of mismatched rows only.
    gram = torch.matmul(X, Y.T)
    off_diag_only = torch.triu(gram, diagonal=1) + torch.tril(gram, diagonal=-1)
    negative = torch.mean(off_diag_only) * n_samples / (n_samples - 1)
    return negative - positive
94 |
95 |
96 | # Regularizers______________________________________________________________________________________
97 |
98 |
def orthonormal_fro_reg(x: Tensor) -> Tensor:
    r"""Orthonormality regularization with Frobenious norm of covariance of `x`.

    Given a batch of realizations of `x`, this term penalizes:

    1. Orthogonality: Linear dependencies among dimensions,
    2. Normality: Deviations of each dimension’s variance from 1,
    3. Centering: Deviations of each dimension’s mean from 0.

    .. math::

        \frac{1}{D} \| \mathbf{C}_{X} - I \|_F^2 + 2 \| \mathbb{E}_{X} x \|^2 = \frac{1}{D} (\text{tr}(\mathbf{C}^2_{X}) - 2 \text{tr}(\mathbf{C}_{X}) + D + 2 \| \mathbb{E}_{X} x \|^2)

    Args:
        x (Tensor): Input features.

    Shape:
        ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
    """
    batch_mean = x.mean(dim=0, keepdim=True)
    centered = x - batch_mean
    # ||Cx||_F^2 = tr(Cx^2) involves a product of covariances, so its unbiased
    # estimation requires U-statistics (handled by cov_norm_squared_unbiased).
    cov_fro_sq = cov_norm_squared_unbiased(centered)
    # tr(Cx) ≈ mean squared norm of the centered samples.
    cov_trace = (centered * centered).sum() / x.shape[0]
    mean_norm_sq = (batch_mean**2).sum()  # ||E_p(x) x||^2
    n_feats = x.shape[-1]  # ||I||_F^2 = D
    return (cov_fro_sq - 2 * cov_trace + n_feats + 2 * mean_norm_sq) / n_feats
129 |
130 |
def orthonormal_logfro_reg(x: Tensor) -> Tensor:
    r"""Orthonormality regularization with log-Frobenious norm of covariance of x by :footcite:t:`Kostic2023DPNets`.

    .. math::

        \frac{1}{D}\text{Tr}(C_X^{2} - C_X -\ln(C_X)).

    Args:
        x (Tensor): Input features.

    Shape:
        ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
    """
    cov_mat = covariance(x)  # shape: (D, D)
    # Floor tiny/negative eigenvalues so the log below stays finite.
    floor = torch.finfo(cov_mat.dtype).eps * cov_mat.shape[0]
    spectrum = torch.linalg.eigvalsh(cov_mat)
    spectrum = torch.where(spectrum > floor, spectrum, floor)
    spectral_penalty = torch.mean(spectrum * (spectrum - 1.0) - torch.log(spectrum))
    # TODO: Centering like this?
    mean_norm_sq = (x.mean(0, keepdim=True) ** 2).sum()  # ||E_p(x) x||^2
    return spectral_penalty + 2 * mean_norm_sq
153 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/linalg.py:
--------------------------------------------------------------------------------
1 | """Linear Algebra."""
2 |
3 | from typing import NamedTuple
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 |
def sqrtmh(A: Tensor) -> Tensor:
    """Compute the square root of a Symmetric or Hermitian positive definite matrix or batch of matrices.

    Used code from `this issue `_.

    Args:
        A (Tensor): Symmetric or Hermitian positive definite matrix or batch of matrices.

    Shape:
        ``A``: :math:`(N, N)`

        Output: :math:`(N, N)`
    """
    evals, evecs = torch.linalg.eigh(A)
    # Zero out numerically-negligible eigenvalues before taking the root.
    cutoff = evals.max(-1).values * evals.size(-1) * torch.finfo(evals.dtype).eps
    zero = torch.zeros((), device=evals.device, dtype=evals.dtype)
    evals = evals.where(evals > cutoff.unsqueeze(-1), zero)
    # Reassemble: Q diag(sqrt(L)) Q^H.
    return (evecs * evals.sqrt().unsqueeze(-2)) @ evecs.mH
27 |
28 |
29 | ####################################################################################################
30 | # TODO: THIS IS JUST COPY AND PASTE FROM OLD NCP
31 | # Should topk and filter_reduced_rank_svals be in utils? They look like linalg to me, specially the
32 | # filter
33 | ####################################################################################################
34 |
35 |
36 | # Sorting and parsing
class TopKReturnType(NamedTuple):
    """Named result of :func:`topk`: the selected values and their indices."""

    values: Tensor
    indices: Tensor


def topk(vec: Tensor, k: int):
    """Return the ``k`` largest entries of 1D tensor ``vec`` (descending) with their indices."""
    assert vec.ndim == 1, "'vec' must be a 1D array"
    assert k > 0, "k should be greater than 0"
    # Ascending argsort, flipped, gives a descending ordering of the entries.
    order_desc = torch.flip(torch.argsort(vec), dims=[0])
    top_idxs = order_desc[:k]
    return TopKReturnType(vec[top_idxs], top_idxs)
49 |
50 |
def filter_reduced_rank_svals(values, vectors):
    """Filter out numerically invalid eigenpairs and sort the rest in descending order.

    An eigenvalue is invalid when its real part is not strictly positive (up to
    machine precision) or, for complex inputs, when its imaginary part is nonzero.
    Both values and vectors are returned as real tensors.
    """
    eps = 2 * torch.finfo(torch.get_default_dtype()).eps
    # Filtering procedure.
    # Mask is True where the real part is <= eps or (for complex inputs) the imaginary part is nonzero.
    is_invalid = torch.logical_or(
        torch.real(values) <= eps,
        torch.imag(values) != 0
        if torch.is_complex(values)
        else torch.zeros(len(values), device=values.device),
    )
    # Drop every invalid eigenpair (not just a trailing range).
    if torch.any(is_invalid):
        values = values[~is_invalid].real
        vectors = vectors[:, ~is_invalid]

    # Sort the surviving eigenpairs in descending order of eigenvalue.
    sort_perm = topk(values, len(values)).indices
    values = values[sort_perm]
    vectors = vectors[:, sort_perm]

    # Assert that the eigenvectors do not have any imaginary part
    assert torch.all(
        torch.imag(vectors) == 0 if torch.is_complex(values) else torch.ones(len(values))
    ), "The eigenvectors should be real. Decrease the rank or increase the regularization strength."

    # Take the real part of the eigenvectors
    vectors = torch.real(vectors)
    values = torch.real(values)
    return values, vectors
79 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules entry point."""
2 |
3 | from .ema_covariance import EMACovariance
4 | from .loss import * # noqa: F403
5 | from .mlp import MLP
6 | from .resnet import * # noqa: F403
7 | from .simnorm import SimNorm
8 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/ema_covariance.py:
--------------------------------------------------------------------------------
1 | """Exponential moving average of the covariance matrices."""
2 |
3 | import torch
4 | import torch.distributed
5 | from torch import Tensor
6 |
7 |
class EMACovariance(torch.nn.Module):
    r"""Exponential moving average of the covariance matrices.

    Gives an online estimate of the covariances and means :math:`C` adding the batch covariance :math:`\hat{C}` via the following update forumla

    .. math::

        C \leftarrow (1 - m)C + m \hat{C}

    Args:
        feature_dim: The number of features in the input and output tensors.
        momentum: The momentum for the exponential moving average.
        center: Whether to center the data before computing the covariance matrices.
    """

    def __init__(self, feature_dim: int, momentum: float = 0.01, center: bool = True):
        super().__init__()
        self.is_centered = center
        self.momentum = momentum
        # Buffers so the running statistics follow the module across devices
        # and are saved/restored with its state_dict.
        self.register_buffer("mean_X", torch.zeros(feature_dim))
        self.register_buffer("cov_X", torch.eye(feature_dim))
        self.register_buffer("mean_Y", torch.zeros(feature_dim))
        self.register_buffer("cov_Y", torch.eye(feature_dim))
        self.register_buffer("cov_XY", torch.eye(feature_dim))
        self.register_buffer("is_initialized", torch.tensor(False, dtype=torch.bool))

    @torch.no_grad()
    def forward(self, X: Tensor, Y: Tensor):
        """Update the exponential moving average of the covariance matrices.

        No-op in eval mode. The first training call initializes the statistics
        from the batch; subsequent calls blend them in with momentum.

        Args:
            X: Input tensor.
            Y: Output tensor.

        Shape:
            ``X``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``Y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        if not self.training:
            return
        assert X.ndim == 2
        assert X.shape == Y.shape
        assert X.shape[1] == self.mean_X.shape[0]
        if not self.is_initialized.item():
            self._first_forward(X, Y)
        else:
            mean_X = X.mean(dim=0, keepdim=True)
            mean_Y = Y.mean(dim=0, keepdim=True)
            # Update means
            self._inplace_EMA(mean_X[0], self.mean_X)
            self._inplace_EMA(mean_Y[0], self.mean_Y)

            if self.is_centered:
                # Center against the (freshly updated) running means.
                X = X - self.mean_X
                Y = Y - self.mean_Y

            cov_X = torch.mm(X.T, X) / X.shape[0]
            cov_Y = torch.mm(Y.T, Y) / Y.shape[0]
            cov_XY = torch.mm(X.T, Y) / X.shape[0]
            # Update covariances
            self._inplace_EMA(cov_X, self.cov_X)
            self._inplace_EMA(cov_Y, self.cov_Y)
            self._inplace_EMA(cov_XY, self.cov_XY)

    def _first_forward(self, X: torch.Tensor, Y: torch.Tensor):
        # Initialize the running statistics directly from the first batch.
        mean_X = X.mean(dim=0, keepdim=True)
        self._inplace_set(mean_X[0], self.mean_X)
        mean_Y = Y.mean(dim=0, keepdim=True)
        self._inplace_set(mean_Y[0], self.mean_Y)
        if self.is_centered:
            X = X - self.mean_X
            Y = Y - self.mean_Y

        cov_X = torch.mm(X.T, X) / X.shape[0]
        cov_Y = torch.mm(Y.T, Y) / Y.shape[0]
        cov_XY = torch.mm(X.T, Y) / X.shape[0]
        self._inplace_set(cov_X, self.cov_X)
        self._inplace_set(cov_Y, self.cov_Y)
        self._inplace_set(cov_XY, self.cov_XY)
        # Bug fix: assigning a fresh CPU tensor here replaced the registered
        # buffer and dropped its device placement; fill it in place instead.
        self.is_initialized.fill_(True)

    def _inplace_set(self, update, current):
        # Average the batch statistic across processes, then overwrite `current`.
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(update, op=torch.distributed.ReduceOp.SUM)
            update /= torch.distributed.get_world_size()
        current.copy_(update)

    def _inplace_EMA(self, update, current):
        # current <- (1 - momentum) * current + momentum * update
        alpha = 1 - self.momentum
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(update, op=torch.distributed.ReduceOp.SUM)
            update /= torch.distributed.get_world_size()

        current.mul_(alpha).add_(update, alpha=self.momentum)
103 |
104 |
def test_EMACovariance():  # noqa: D103
    torch.manual_seed(0)

    n_feats = 5
    batch_X = torch.randn(10, n_feats)
    batch_Y = torch.randn(10, n_feats)
    ema = EMACovariance(feature_dim=n_feats)

    # In eval mode no statistic should move from its initial value.
    ema.eval()
    ema(batch_X, batch_Y)
    assert torch.allclose(ema.cov_X, torch.eye(n_feats))
    assert torch.allclose(ema.cov_Y, torch.eye(n_feats))
    assert torch.allclose(ema.cov_XY, torch.eye(n_feats))

    assert torch.allclose(ema.mean_X, torch.zeros(n_feats))
    assert torch.allclose(ema.mean_Y, torch.zeros(n_feats))

    # The first training step must initialize the statistics from the batch.
    ema.train()
    assert not ema.is_initialized.item()
    ema(batch_X, batch_Y)
    assert ema.is_initialized.item()
    assert torch.allclose(ema.mean_X, batch_X.mean(dim=0))
    assert torch.allclose(ema.mean_Y, batch_Y.mean(dim=0))
    if ema.is_centered:
        assert torch.allclose(ema.cov_X, torch.cov(batch_X.T, correction=0))
        assert torch.allclose(ema.cov_Y, torch.cov(batch_Y.T, correction=0))
133 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/loss.py:
--------------------------------------------------------------------------------
1 | """Loss functions for representation learning."""
2 |
3 | from typing import Literal
4 |
5 | from torch import Tensor
6 | from torch.nn import Module
7 |
8 | from linear_operator_learning.nn import functional as F
9 |
10 | __all__ = ["VampLoss", "L2ContrastiveLoss", "KLContrastiveLoss", "DPLoss"]
11 |
12 | # Losses_____________________________________________________________________________________________
13 |
14 |
15 | class _RegularizedLoss(Module):
16 | """Base class for regularized losses.
17 |
18 | Args:
19 | gamma (float, optional): Regularization strength.
20 | regularizer (literal, optional): Regularizer. Either :func:`orthn_fro ` or :func:`orthn_logfro `. Defaults to :func:`orthn_fro `.
21 | """
22 |
23 | def __init__(
24 | self, gamma: float, regularizer: Literal["orthn_fro", "orthn_logfro"]
25 | ) -> None: # TODO: Automatically determine 'gamma' from dim_x and dim_y
26 | super().__init__()
27 | self.gamma = gamma
28 |
29 | if regularizer == "orthn_fro":
30 | self.regularizer = F.orthonormal_fro_reg
31 | elif regularizer == "orthn_logfro":
32 | self.regularizer = F.orthonormal_logfro_reg
33 | else:
34 | raise NotImplementedError(f"Regularizer {regularizer} not supported!")
35 |
36 |
class VampLoss(_RegularizedLoss):
    r"""Variational Approach for learning Markov Processes (VAMP) score by :footcite:t:`Wu2019`.

    .. math::

        \mathcal{L}(x, y) = -\sum_{i} \sigma_{i}(A)^{p} \qquad \text{where}~A = \big(x^{\top}x\big)^{\dagger/2}x^{\top}y\big(y^{\top}y\big)^{\dagger/2}.

    Args:
        schatten_norm (int, optional): Computes the VAMP-p score with ``p = schatten_norm``. Defaults to 2.
        center_covariances (bool, optional): Use centered covariances to compute the VAMP score. Defaults to True.
        gamma (float, optional): Regularization strength. Defaults to 1e-3.
        regularizer (literal, optional): Either ``"orthn_fro"`` or ``"orthn_logfro"``. Defaults to ``"orthn_fro"``.
    """

    def __init__(
        self,
        schatten_norm: int = 2,
        center_covariances: bool = True,
        gamma: float = 1e-3,
        regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro",
    ) -> None:
        super().__init__(gamma, regularizer)
        self.schatten_norm = schatten_norm
        self.center_covariances = center_covariances

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass of VAMP loss.

        Args:
            x (Tensor): Features for x.
            y (Tensor): Features for y.

        Raises:
            NotImplementedError: If ``schatten_norm`` is not 1 or 2.

        Shape:
            ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        penalty = self.gamma * (self.regularizer(x) + self.regularizer(y))
        score = F.vamp_loss(x, y, self.schatten_norm, self.center_covariances)
        return score + penalty
83 |
84 |
class L2ContrastiveLoss(_RegularizedLoss):
    r"""NCP/Contrastive/Mutual Information Loss based on the :math:`L^{2}` error by :footcite:t:`Kostic2024NCP`.

    .. math::

        \mathcal{L}(x, y) = \frac{1}{N(N-1)}\sum_{i \neq j}\langle x_{i}, y_{j} \rangle^2 - \frac{2}{N}\sum_{i=1}\langle x_{i}, y_{i} \rangle.

    Args:
        gamma (float, optional): Regularization strength. Defaults to 1e-3.
        regularizer (literal, optional): Either ``"orthn_fro"`` or ``"orthn_logfro"``. Defaults to ``"orthn_fro"``.
    """

    def __init__(
        self,
        gamma: float = 1e-3,
        regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro",
    ) -> None:
        super().__init__(gamma, regularizer)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass of the L2 contrastive loss.

        Args:
            x (Tensor): Input features.
            y (Tensor): Output features.

        Shape:
            ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        penalty = self.gamma * (self.regularizer(x) + self.regularizer(y))
        return F.l2_contrastive_loss(x, y) + penalty
120 |
121 |
class DPLoss(_RegularizedLoss):
    r"""Deep Projection Loss by :footcite:t:`Kostic2023DPNets`.

    .. math::

        \mathcal{L}(x, y) = -\frac{\|x^{\top}y\|^{2}_{{\rm F}}}{\|x^{\top}x\|^{2}\|y^{\top}y\|^{2}}.

    Args:
        relaxed (bool, optional): Whether to use the relaxed (more numerically stable) or the full deep-projection loss. Defaults to True.
        center_covariances (bool, optional): Use centered covariances to compute the Deep Projection loss. Defaults to True.
        gamma (float, optional): Regularization strength. Defaults to 1e-3.
        regularizer (literal, optional): Either ``"orthn_fro"`` or ``"orthn_logfro"``. Defaults to ``"orthn_fro"``.
    """

    def __init__(
        self,
        relaxed: bool = True,
        center_covariances: bool = True,
        gamma: float = 1e-3,
        regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro",
    ) -> None:
        super().__init__(gamma, regularizer)
        self.relaxed = relaxed
        self.center_covariances = center_covariances

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass of DPLoss.

        Args:
            x (Tensor): Features for x.
            y (Tensor): Features for y.

        Shape:
            ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        penalty = self.gamma * (self.regularizer(x) + self.regularizer(y))
        score = F.dp_loss(x, y, self.relaxed, self.center_covariances)
        return score + penalty
165 |
166 |
class KLContrastiveLoss(_RegularizedLoss):
    r"""NCP/Contrastive/Mutual Information Loss based on the KL divergence.

    .. math::

        \mathcal{L}(x, y) = \frac{1}{N(N-1)}\sum_{i \neq j}\langle x_{i}, y_{j} \rangle - \frac{2}{N}\sum_{i=1}\log\big(\langle x_{i}, y_{i} \rangle\big).

    Args:
        gamma (float, optional): Regularization strength. Defaults to 1e-3.
        regularizer (literal, optional): Either ``"orthn_fro"`` or ``"orthn_logfro"``. Defaults to ``"orthn_fro"``.
    """

    def __init__(
        self,
        gamma: float = 1e-3,
        regularizer: Literal["orthn_fro", "orthn_logfro"] = "orthn_fro",
    ) -> None:
        super().__init__(gamma, regularizer)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass of the KL contrastive loss.

        Args:
            x (Tensor): Input features.
            y (Tensor): Output features.

        Shape:
            ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

            ``y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
        """
        penalty = self.gamma * (self.regularizer(x) + self.regularizer(y))
        return F.kl_contrastive_loss(x, y) + penalty
204 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # TODO: Refactor the models, add docstrings, etc...
2 | """PyTorch Models."""
3 |
4 | import torch
5 | from torch.nn import Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential
6 |
7 |
8 | class _MLPBlock(Module):
9 | def __init__(self, input_size, output_size, dropout=0.0, activation=ReLU, bias=True):
10 | super(_MLPBlock, self).__init__()
11 | self.linear = Linear(input_size, output_size, bias=bias)
12 | self.dropout = Dropout(dropout)
13 | self.activation = activation()
14 |
15 | def forward(self, x):
16 | out = self.linear(x)
17 | out = self.dropout(out)
18 | out = self.activation(out)
19 | return out
20 |
21 |
class MLP(Module):
    """Multi Layer Perceptron.

    Args:
        input_shape (int): Input shape of the MLP.
        n_hidden (int): Number of hidden layers.
        layer_size (int or list of ints): Number of neurons in each hidden layer. An int
            is broadcast to all hidden layers; a list specifies each layer individually.
        output_shape (int): Output shape of the MLP.
        dropout (float): Dropout probability between layers. Defaults to 0.0.
        activation (torch.nn.Module): Activation function. Defaults to ReLU.
        iterative_whitening (bool): Whether to add an IterNorm layer at the end of the
            network. Defaults to False.
        bias (bool): Whether to include bias in the hidden layers. Defaults to False.
    """

    def __init__(
        self,
        input_shape,
        n_hidden,
        layer_size,
        output_shape,
        dropout=0.0,
        activation=ReLU,
        iterative_whitening=False,
        bias=False,
    ):
        super().__init__()
        if isinstance(layer_size, int):
            layer_size = [layer_size] * n_hidden

        if n_hidden == 0:
            # Degenerate case: a single (bias-free) linear map from input to output.
            layers = [Linear(input_shape, output_shape, bias=False)]
        else:
            layers = []
            in_features = input_shape
            for idx in range(n_hidden):
                out_features = layer_size[idx]
                layers.append(_MLPBlock(in_features, out_features, dropout, activation, bias=bias))
                in_features = out_features
            # Final projection is always bias-free.
            layers.append(Linear(layer_size[-1], output_shape, bias=False))
        if iterative_whitening:
            # layers.append(IterNorm(output_shape))
            raise NotImplementedError("IterNorm isn't implemented")
        self.model = Sequential(*layers)

    def forward(self, x):  # noqa: D102
        return self.model(x)
77 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/resnet.py:
--------------------------------------------------------------------------------
1 | """Resnet Module."""
2 |
3 | from typing import Any, Callable, List, Optional, Type, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
10 |
11 |
def conv3x3(
    in_planes: int,
    out_planes: int,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
    padding_mode: str = "zeros",
) -> nn.Conv2d:
    """3x3 convolution with padding."""
    # Padding equals the dilation so the spatial size is preserved at stride 1.
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
        padding_mode=padding_mode,
    )
32 |
33 |
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, stride=stride, kernel_size=1, bias=False)
37 |
38 |
class BasicBlock(nn.Module):
    """Two-convolution residual block used by the shallower ResNets (18/34)."""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and the downsample path) performs the downsampling when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride, padding_mode=padding_mode)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, padding_mode=padding_mode)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """conv-bn-relu-conv-bn, add the (possibly downsampled) input, then relu."""
        shortcut = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))

        h += shortcut
        return self.relu(h)
87 |
88 |
class Bottleneck(nn.Module):
    # Torchvision-style bottleneck ("ResNet V1.5"): the stride sits on the 3x3
    # convolution (self.conv2) rather than on the first 1x1 convolution, unlike the
    # original "Deep residual learning for image recognition"
    # (https://arxiv.org/abs/1512.03385). This variant improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # conv2 (and the downsample path) performs the downsampling when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, padding_mode=padding_mode)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """1x1 reduce, 3x3, 1x1 expand, add the (possibly downsampled) input, relu."""
        shortcut = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        h += shortcut
        return self.relu(h)
146 |
147 |
class ResNet(nn.Module):
    """ResNet model from :footcite:t:`he2016deep`.

    Args:
        block (Type[Union[BasicBlock, Bottleneck]]): Residual block type used in all stages.
        layers (List[int]): Number of residual blocks in each of the four stages.
        channels_in (int): Number of input channels.
        num_features (int): Dimension of the output feature vector (size of the final linear layer).
        zero_init_residual (bool): Zero-initialize the last BN of each residual branch.
        groups (int): Number of convolution groups.
        width_per_group (int): Width per group.
        replace_stride_with_dilation (Optional[List[bool]]): For each of the last three
            stages, replace the stride-2 downsampling with dilation.
        padding_mode (str): Padding mode for the convolutional layers.
        norm_layer (Optional[Callable[..., nn.Module]]): Normalization layer. Defaults to
            ``nn.BatchNorm2d`` when None.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        channels_in: int = 3,
        num_features: int = 1024,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 convolution, BN, ReLU, then a stride-2 max-pool.
        self.conv1 = nn.Conv2d(
            channels_in,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            padding_mode=padding_mode,
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 downsample (or dilate) as configured.
        self.layer1 = self._make_layer(block, 64, layers[0], padding_mode=padding_mode)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            padding_mode=padding_mode,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            padding_mode=padding_mode,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            padding_mode=padding_mode,
        )
        # Head: global average pool followed by a bias-free linear projection.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_features, bias=False)

        # Standard He initialization for convolutions; unit/zero affine for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
        padding_mode: str = "zeros",
    ) -> nn.Sequential:
        """Build one residual stage of ``blocks`` blocks; only the first block may downsample."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation to keep spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut: match the residual branch's stride and channel count.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                padding_mode,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        # Remaining blocks keep stride 1 and the (possibly updated) dilation.
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    padding_mode=padding_mode,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the ResNet model."""
        return self._forward_impl(x)
323 |
324 |
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    **kwargs: Any,
) -> ResNet:
    """Instantiate a :class:`ResNet` with the given block type and per-stage depths."""
    return ResNet(block, layers, **kwargs)
333 |
334 |
def resnet18(*args, **kwargs: Any) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition `__."""
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet(BasicBlock, blocks_per_stage, **kwargs)
338 |
339 |
def resnet34(*args, **kwargs: Any) -> ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition `__."""
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(BasicBlock, blocks_per_stage, **kwargs)
343 |
344 |
def resnet50(*args, **kwargs: Any) -> ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition `__."""
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(Bottleneck, blocks_per_stage, **kwargs)
348 |
349 |
def resnet101(*args, **kwargs: Any) -> ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition `__."""
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet(Bottleneck, blocks_per_stage, **kwargs)
353 |
354 |
def resnet152(*args, **kwargs: Any) -> ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition `__."""
    # NOTE(review): positional *args are accepted for API symmetry but are ignored,
    # matching the other resnetNN factories in this module.
    return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
358 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/simnorm.py:
--------------------------------------------------------------------------------
"""Simplicial normalization."""
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 |
class SimNorm(nn.Module):
    """Simplicial normalization from :footcite:t:`lavoie2022simplicial`.

    The last dimension of the input is split into groups of size :code:`dim`; a softmax
    is applied within each group and the groups are concatenated back together.

    Args:
        dim (int): Dimension of the simplicial groups.
    """

    def __init__(self, dim: int = 8):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the simplicial normalization module."""
        original_shape = x.shape
        grouped = x.view(*original_shape[:-1], -1, self.dim)
        normalized = F.softmax(grouped, dim=-1)
        return normalized.view(*original_shape)

    def __repr__(self):
        """String representation of the simplicial norm module."""
        return f"SimNorm(dim={self.dim})"
31 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/regressors.py:
--------------------------------------------------------------------------------
1 | """NN regressors."""
2 |
3 | from typing import Literal
4 |
5 | import numpy as np
6 | import scipy.linalg
7 | import torch
8 | from torch import Tensor
9 |
10 | from linear_operator_learning.nn.structs import EigResult, FitResult
11 |
12 |
def ridge_least_squares(
    cov_X: Tensor,
    tikhonov_reg: float = 0.0,
) -> FitResult:
    """Fit the ridge least squares estimator for the transfer operator.

    Args:
        cov_X (Tensor): Covariance matrix of the input data, shape ``(D, D)``.
        tikhonov_reg (float, optional): Ridge regularization strength. Defaults to 0.0.

    Returns:
        FitResult: ``U`` and ``V`` both set to the whitening projector, ``svals`` to
        the eigenvalues of the regularized covariance.
    """
    dim = cov_X.shape[0]
    identity = torch.eye(dim, dtype=cov_X.dtype, device=cov_X.device)
    regularized_cov = cov_X + tikhonov_reg * identity
    evals, evecs = torch.linalg.eigh(regularized_cov)
    # Scale each eigenvector column by the inverse square root of its eigenvalue;
    # the small epsilon guards against division by zero.
    projector = evecs @ torch.diag(1.0 / torch.sqrt(evals + 1e-10))
    result: FitResult = FitResult({"U": projector, "V": projector, "svals": evals})
    return result
34 |
35 |
def eig(
    fit_result: FitResult,
    cov_XY: Tensor,
) -> EigResult:
    """Computes the eigendecomposition of a regressor.

    Args:
        fit_result (FitResult): Fit result as defined in ``linear_operator_learning.nn.structs``.
        cov_XY (Tensor): Cross covariance matrix between the input and output data.


    Shape:
        ``cov_XY``: :math:`(D, D)`, where :math:`D` is the number of features.

        Output: ``U, V`` of shape :math:`(D, R)`, ``svals`` of shape :math:`R`
        where :math:`D` is the number of features and :math:`R` is the rank of the regressor.
    """
    # Keep the original dtype/device so results can be cast back after the numpy round-trip.
    dtype_and_device = {
        "dtype": cov_XY.dtype,
        "device": cov_XY.device,
    }
    U = fit_result["U"]
    # Using the trick described in https://arxiv.org/abs/1905.11490
    M = torch.linalg.multi_dot([U.T, cov_XY, U])
    # Conversion to numpy: scipy is needed for the two-sided (left+right) eigendecomposition.
    M = M.numpy(force=True)
    values, lv, rv = scipy.linalg.eig(M, left=True, right=True)
    # Sort eigenvalues; left eigenvectors are sorted by the conjugate values so the
    # left/right pairings stay aligned after permutation.
    r_perm = torch.tensor(np.argsort(values), device=cov_XY.device)
    l_perm = torch.tensor(np.argsort(values.conj()), device=cov_XY.device)
    values = values[r_perm]
    # Back to torch, casting to appropriate dtype and device
    values = torch.complex(
        torch.tensor(values.real, **dtype_and_device), torch.tensor(values.imag, **dtype_and_device)
    )
    lv = torch.complex(
        torch.tensor(lv.real, **dtype_and_device), torch.tensor(lv.imag, **dtype_and_device)
    )
    rv = torch.complex(
        torch.tensor(rv.real, **dtype_and_device), torch.tensor(rv.imag, **dtype_and_device)
    )
    # Normalization in RKHS norm: lift right eigenvectors back to feature space and
    # normalize each column to unit norm.
    rv = U.to(rv.dtype) @ rv
    rv = rv[:, r_perm]
    rv = rv / torch.linalg.norm(rv, axis=0)
    # Biorthogonalization: rescale left eigenvectors so that <lv_i, rv_j> = delta_ij.
    lv = torch.linalg.multi_dot([cov_XY.T.to(lv.dtype), U.to(lv.dtype), lv])
    lv = lv[:, l_perm]
    l_norm = torch.sum(lv * rv, axis=0)
    lv = lv / l_norm
    result: EigResult = EigResult({"values": values, "left": lv, "right": rv})
    return result
87 |
88 |
def evaluate_eigenfunction(
    eig_result: EigResult,
    which: Literal["left", "right"],
    X: Tensor,
):
    """Evaluates left or right eigenfunctions of a regressor.

    Args:
        eig_result: EigResult object containing eigendecomposition results.
        which: String indicating "left" or "right" eigenfunctions.
        X: Feature map of the input data.


    Shape:
        ``eig_results``: ``U, V`` of shape :math:`(D, R)`, ``svals`` of shape :math:`R`
        where :math:`D` is the number of features and :math:`R` is the rank of the regressor.

        ``X``: :math:`(N_0, D)`, where :math:`N_0` is the number of inputs to predict and :math:`D` is the number of features.

        Output: :math:`(N_0, R)`
    """
    eigenvectors = eig_result[which]
    # Cast the features to the (complex) dtype of the eigenvectors before projecting.
    return X.to(eigenvectors.dtype) @ eigenvectors
112 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/stats.py:
--------------------------------------------------------------------------------
1 | """Statistics utilities for multi-variate random variables."""
2 |
3 | from math import sqrt
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from linear_operator_learning.nn.linalg import filter_reduced_rank_svals, sqrtmh
9 |
10 |
11 | def covariance(
12 | X: Tensor,
13 | Y: Tensor | None = None,
14 | center: bool = True,
15 | norm: float | None = None,
16 | ) -> Tensor:
17 | """Computes the covariance of X or cross-covariance between X and Y if Y is given.
18 |
19 | Args:
20 | X (Tensor): Input features.
21 | Y (Tensor | None, optional): Output features. Defaults to None.
22 | center (bool, optional): Whether to compute centered covariances. Defaults to True.
23 | norm (float | None, optional): Normalization factor. Defaults to None.
24 |
25 | Shape:
26 | ``X``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
27 |
28 | ``Y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
29 |
30 | Output: :math:`(D, D)`, where :math:`D` is the number of features.
31 | """
32 | assert X.ndim == 2
33 | if norm is None:
34 | norm = sqrt(X.shape[0])
35 | else:
36 | assert norm > 0
37 | norm = sqrt(norm)
38 | if Y is None:
39 | X = X / norm
40 | if center:
41 | X = X - X.mean(dim=0, keepdim=True)
42 | return torch.mm(X.T, X)
43 | else:
44 | assert Y.ndim == 2
45 | X = X / norm
46 | Y = Y / norm
47 | if center:
48 | X = X - X.mean(dim=0, keepdim=True)
49 | Y = Y - Y.mean(dim=0, keepdim=True)
50 | return torch.mm(X.T, Y)
51 |
52 |
def cross_cov_norm_squared_unbiased(x: Tensor, y: Tensor, permutation=None):
    r"""Compute the unbiased estimation of :math:`\|\mathbf{C}_{xy}\|_F^2` from a batch of samples, using U-statistics.

    Given the Covariance matrix :math:`\mathbf{C}_{xy} = \mathbb{E}_p(x,y) [x^{\top} y]`, this function computes an unbiased estimation
    of the Frobenius norm of the covariance matrix from two independent sampling sets (an effective samples size of :math:`N^2`).

    .. math::

        \begin{align}
        \|\mathbf{C}_{xy}\|_F^2 &= \text{tr}(\mathbf{C}_{xy}^{\top} \mathbf{C}_{xy})
        = \sum_i \sum_j (\mathbb{E}_{x,y \sim p(x,y)} [x_i y_j]) (\mathbb{E}_{x',y' \sim p(x,y)} [x_j y_i']) \\
        &= \mathbb{E}_{(x,y),(x',y') \sim p(x,y)} [(x^{\top} y') (x'^{T} y)] \\
        &\approx \frac{1}{N^2} \sum_n \sum_m [(x_{n}^{\top} y^{\prime}_m) (x^{\prime \top}_m y_n)]
        \end{align}

    .. note::
        The random variable is assumed to be centered.

    Args:
        x (Tensor): Centered realizations of a random variable `x` of shape (N, D_x).
        y (Tensor): Centered realizations of a random variable `y` of shape (N, D_y).
        permutation (Tensor, optional): Integer indices of shape (n_samples,) used to permute the samples.

    Returns:
        Tensor: Unbiased estimation of :math:`\|\mathbf{C}_{xy}\|_F^2` using U-statistics.
    """
    n_samples = x.shape[0]

    # A random permutation simulates an independent second sampling set (x', y').
    if permutation is None:
        permutation = torch.randperm(n_samples)
    assert permutation.shape == (n_samples,), f"Invalid permutation {permutation.shape}!=({n_samples},)"
    x_prime = x[permutation]
    y_prime = y[permutation]

    # 1/N^2 Σ_n Σ_m [(x_n.T y'_m) (x'_m.T y_n)]
    estimate = torch.einsum("nj,mj,mk,nk->", x, y_prime, x_prime, y)
    return estimate / (n_samples**2)
91 |
92 |
def cov_norm_squared_unbiased(x: Tensor, permutation=None):
    r"""Compute the unbiased estimation of :math:`\|\mathbf{C}_x\|_F^2` from a batch of samples.

    Given the Covariance matrix :math:`\mathbf{C}_x = \mathbb{E}_p(x) [x^{\top} x]`, this computes an unbiased estimation
    of its squared Frobenius norm from a single sampling set:

    .. math::

        \begin{align}
        \|\mathbf{C}_x\|_F^2 &= \text{tr}(\mathbf{C}_x^{\top} \mathbf{C}_x) = \sum_i \sum_j (\mathbb{E}_{x} [x_i x_j]) (\mathbb{E}_{x'} [x'_j x'_i]) \\
        &= \mathbb{E}_{x,x' \sim p(x)} [(x^{\top} x')^2] \\
        &\approx \frac{1}{N^2} \sum_n \sum_m [(x_n^{\top} x'_m)^2]
        \end{align}

    .. note::

        The random variable is assumed to be centered.

    Args:
        x (Tensor): (n_samples, r) Centered realizations of a random variable x = [x_1, ..., x_r].
        permutation (Tensor, optional): Integer indices of shape (n_samples,) used to permute the samples.

    Returns:
        Tensor: Unbiased estimation of :math:`\|\mathbf{C}_x\|_F^2` using U-statistics.
    """
    # Special case y = x of the cross-covariance estimator.
    return cross_cov_norm_squared_unbiased(x=x, y=x, permutation=permutation)
120 |
121 |
def whitening(u: Tensor, v: Tensor) -> tuple:
    """Computes whitening matrices for ``u`` and ``v``.

    Args:
        u (Tensor): Input features.
        v (Tensor): Output features.


    Shape:
        ``u``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

        ``v``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

        ``sqrt_cov_u_inv``: :math:`(D, D)`

        ``sqrt_cov_v_inv``: :math:`(D, D)`

        ``sing_val``: :math:`(D,)`

        ``sing_vec_l``: :math:`(D, D)`

        ``sing_vec_r``: :math:`(D, D)`
    """
    cov_u = covariance(u)
    cov_v = covariance(v)
    cov_uv = covariance(u, v)

    # Inverse matrix square roots of the marginal covariances.
    sqrt_cov_u_inv = torch.linalg.pinv(sqrtmh(cov_u))
    sqrt_cov_v_inv = torch.linalg.pinv(sqrtmh(cov_v))

    # SVD of the whitened cross-covariance, obtained via the eigendecomposition of M M^T.
    whitened_xcov = sqrt_cov_u_inv @ cov_uv @ sqrt_cov_v_inv
    evals, sing_vec_l = torch.linalg.eigh(whitened_xcov @ whitened_xcov.T)
    evals, sing_vec_l = filter_reduced_rank_svals(evals, sing_vec_l)
    sing_val = torch.sqrt(evals)
    # Recover the right singular vectors from the left ones.
    sing_vec_r = (whitened_xcov.T @ sing_vec_l) / sing_val

    return sqrt_cov_u_inv, sqrt_cov_v_inv, sing_val, sing_vec_l, sing_vec_r
159 |
--------------------------------------------------------------------------------
/linear_operator_learning/nn/structs.py:
--------------------------------------------------------------------------------
1 | """Structs used by the `nn` algorithms."""
2 |
3 | from typing import TypedDict
4 |
5 | from torch import Tensor
6 |
7 |
class FitResult(TypedDict):
    """Return type for nn regressors."""

    # NOTE(review): field semantics inferred from names only — U and V appear to be
    # the factor matrices produced by a regressor's fit, and svals their associated
    # singular values (None when a regressor does not compute them). Confirm
    # against the regressors in `linear_operator_learning.nn`.
    U: Tensor
    V: Tensor
    svals: Tensor | None
14 |
15 |
class EigResult(TypedDict):
    """Return type for eigenvalue decompositions of nn regressors."""

    # NOTE(review): field semantics inferred from names only — `values` are the
    # eigenvalues, `left`/`right` the corresponding eigenvectors, with `left`
    # optional (None when left eigenvectors are not computed). Confirm against
    # the eigendecomposition routines in `linear_operator_learning.nn`.
    values: Tensor
    left: Tensor | None
    right: Tensor
22 |
--------------------------------------------------------------------------------
/linear_operator_learning/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/linear_operator_learning/py.typed
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/linear_operator_learning/9be4c0ba4ea0f2cc1edfe206e0e682cde9054991/logo.png
--------------------------------------------------------------------------------
/logo.svg:
--------------------------------------------------------------------------------
1 |
26 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "linear-operator-learning"
3 | version = "0.2.5"
4 | description = "A package to learn linear operators"
5 | readme = "README.md"
6 | authors = [
7 | { name = "Alek Frohlich", email = "alek.frohlich@gmail.com" },
8 | { name = "Pietro Novelli", email = "pietronvll@gmail.com"}
9 | ]
10 | requires-python = ">=3.10"
11 | dependencies = [
12 | "numpy>=1.26",
13 | "scipy>=1.12",
14 | "torch>=2.0",
15 | ]
16 |
17 | [build-system]
18 | requires = ["hatchling"]
19 | build-backend = "hatchling.build"
20 |
21 | [dependency-groups]
22 | dev = [
23 | "bumpver>=2024.1130",
24 | "pre-commit>=4.1.0",
25 | "pytest>=8.3.4",
26 | "ruff>=0.9.4",
27 | "ty>=0.0.1a7",
28 | ]
29 | docs = [
30 | "myst-parser>=4.0.0",
31 | "setuptools>=75.8.0",
32 | "sphinx-design>=0.6.1",
33 | "sphinx>=8.1.3",
34 | "sphinx-autobuild>=2024.10.3",
35 | "sphinxawesome-theme>=5.3.2",
36 | "sphinxcontrib-applehelp==2.0.0",
37 | "sphinxcontrib-bibtex>=2.6.3",
38 | "sphinxcontrib-jsmath==1.0.1",
39 | "myst-nb>=1.2.0",
40 | ]
41 | examples = [
42 | "ipykernel>=6.29.5",
43 | "lightning>=2.5.0.post0",
44 | "loguru>=0.7.3",
45 | "matplotlib>=3.10.0",
46 | "scikit-learn>=1.6.1",
47 | "seaborn>=0.13.2",
48 | ]
49 |
50 | [tool.ruff]
51 | line-length = 100
52 | exclude = ["docs"]
53 |
54 | [tool.ruff.lint]
55 | select = [
56 | "D", # pydocstyle rules, limiting to those that adhere to the google convention
57 | "E4", # Errors
58 | "E7",
59 | "E9",
60 | "F", # Pyflakes
61 | "I"
62 | ]
63 |
64 | ignore = [
65 | "F401", # Don't remove unused imports
66 | "D107", # Document __init__ arguments inside class docstring
67 | ]
68 |
69 |
70 |
71 | [tool.ruff.lint.pydocstyle]
72 | convention = "google"
73 |
74 | [bumpver]
75 | current_version = "0.2.5"
76 | version_pattern = "MAJOR.MINOR.PATCH"
77 |
78 | [bumpver.file_patterns]
79 | "pyproject.toml" = [
80 | 'version = "{version}"'
81 | ]
82 |
--------------------------------------------------------------------------------