├── requirements.txt ├── doc ├── logo │ └── qndiag_logo.ico ├── doc-requirements.txt ├── api.rst ├── whats_new.rst ├── _templates │ └── layout.html ├── index.rst ├── Makefile └── conf.py ├── examples ├── README.txt ├── plot_toy_example.py ├── plot_tutorial.py └── article_figure.py ├── MANIFEST.in ├── .github └── workflows │ ├── tests_octave.yml │ ├── tests.yml │ └── deploy_ghpages.yml ├── qndiag ├── tests │ ├── test_pham.py │ └── test_qndiag.py ├── __init__.py ├── pham.py └── qndiag.py ├── matlab_octave ├── toy_example.m └── qndiag.m ├── LICENSE ├── Makefile ├── .gitignore ├── setup.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /doc/logo/qndiag_logo.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pierreablin/qndiag/HEAD/doc/logo/qndiag_logo.ico -------------------------------------------------------------------------------- /doc/doc-requirements.txt: -------------------------------------------------------------------------------- 1 | numpydoc 2 | pillow 3 | matplotlib 4 | sphinx-bootstrap-theme 5 | sphinx-gallery -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | Examples Gallery 4 | ================ 5 | 6 | .. contents:: Contents 7 | :local: 8 | :depth: 3 9 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | recursive-include examples *.py 3 | 4 | include Makefile 5 | recursive-include qndiag *.py 6 | 7 | recursive-exclude matlab_octave * 8 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | .. _api_documentation: 2 | 3 | ================= 4 | API Documentation 5 | ================= 6 | 7 | .. currentmodule:: qndiag 8 | 9 | QNDIAG 10 | ====== 11 | 12 | Functions 13 | 14 | .. autosummary:: 15 | :toctree: generated/ 16 | 17 | qndiag 18 | transform_set 19 | loss 20 | gradient 21 | ajd_pham 22 | -------------------------------------------------------------------------------- /.github/workflows/tests_octave.yml: -------------------------------------------------------------------------------- 1 | name: tests_octave 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' 7 | pull_request: 8 | branches: 9 | - '*' 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Install octave 19 | run: | 20 | sudo apt-get update 21 | sudo apt-get install octave 22 | - name: Test 23 | run: | 24 | cd matlab_octave 25 | octave toy_example.m 26 | -------------------------------------------------------------------------------- /doc/whats_new.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. _whats_new: 4 | 5 | What's new? 6 | =========== 7 | 8 | .. currentmodule:: autoreject 9 | 10 | .. 
_current: 11 | 12 | Current 13 | ------- 14 | 15 | Changelog 16 | ~~~~~~~~~ 17 | 18 | - Add ``ortho`` parameter to :func:`qndiag.qndiag` to impose orthogonality to estimated matrix, by `Hugo Richard`_ in `#11 `_ 19 | 20 | Bug 21 | ~~~ 22 | 23 | API 24 | ~~~ 25 | 26 | - The package now requires scipy. 27 | 28 | Authors 29 | ~~~~~~~~ 30 | 31 | .. _Hugo Richard: https://hugorichard.github.io/ 32 | -------------------------------------------------------------------------------- /examples/plot_toy_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Joint diagonalization on toy data 3 | ================================= 4 | 5 | """ 6 | 7 | # Authors: Pierre Ablin 8 | # 9 | # License: MIT 10 | 11 | import numpy as np 12 | from qndiag import qndiag 13 | 14 | n, p = 10, 3 15 | diagonals = np.random.uniform(size=(n, p)) 16 | A = np.random.randn(p, p) # mixing matrix 17 | C = np.array([A.dot(d[:, None] * A.T) for d in diagonals]) # dataset 18 | B, _ = qndiag(C) 19 | 20 | with np.printoptions(precision=3, suppress=True): 21 | print(B.dot(A)) # Should be a permutation + scale matrix 22 | -------------------------------------------------------------------------------- /qndiag/tests/test_pham.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_equal 3 | 4 | from qndiag import ajd_pham 5 | 6 | 7 | def test_ajd(): 8 | """Test approximate joint diagonalization.""" 9 | n, p = 10, 3 10 | rng = np.random.RandomState(42) 11 | diagonals = rng.uniform(size=(n, p)) 12 | A = rng.randn(p, p) # mixing matrix 13 | C = np.array([A.dot(d[:, None] * A.T) for d in diagonals]) # dataset 14 | B, _ = ajd_pham(C) 15 | BA = np.abs(B.dot(A)) # BA Should be a permutation + scale matrix 16 | BA /= np.max(BA, axis=1, keepdims=True) 17 | BA[np.abs(BA) < 1e-8] = 0. 18 | assert_array_equal(BA[np.lexsort(BA)], np.eye(p)) 19 | -------------------------------------------------------------------------------- /matlab_octave/toy_example.m: -------------------------------------------------------------------------------- 1 | % Authors: Pierre Ablin 2 | % Alexandre Gramfort 3 | % 4 | % License: MIT 5 | 6 | clc; clear 7 | 8 | rand('seed', 42); 9 | randn('seed', 42); 10 | 11 | n = 10; 12 | p = 3; 13 | 14 | diagonals = rand(n, p); 15 | A = randn(p, p); % mixing matrix 16 | 17 | C = zeros(n, p, p); 18 | for k=1:n 19 | C(k, :, :) = A * diag(diagonals(k, :)) * A'; 20 | end 21 | 22 | [D, B] = qndiag(C, 'max_iter', 100); 23 | 24 | B * A % Should be a permutation + scale matrix 25 | 26 | weights = rand(n, 1); 27 | 28 | [D, B] = qndiag(C, 'max_iter', 100, 'weights', weights); 29 | 30 | B * A % Should be a permutation + scale matrix 31 | -------------------------------------------------------------------------------- /qndiag/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: Pierre Ablin 2 | # 3 | # License: MIT 4 | """Joint diagonalization in Python""" 5 | 6 | # PEP0440 compatible formatted version, see: 7 | # https://www.python.org/dev/peps/pep-0440/ 8 | # 9 | # Generic release markers: 10 | # X.Y 11 | # X.Y.Z # For bugfix releases 12 | # 13 | # Admissible pre-release markers: 14 | # X.YaN # Alpha release 15 | # X.YbN # Beta release 16 | # X.YrcN # Release Candidate 17 | # X.Y # Final release 18 | # 19 | # Dev branch marker is: 'X.Y.devN' where N is an integer. 
20 | # 21 | 22 | __version__ = '0.2.dev' 23 | 24 | from .qndiag import qndiag, transform_set, loss, gradient # noqa 25 | from .pham import ajd_pham # noqa 26 | -------------------------------------------------------------------------------- /doc/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {# Import the theme's layout. #} 2 | {% extends "!layout.html" %} 3 | 4 | {# remove site and page menus #} 5 | {%- block sidebartoc %} 6 | {% endblock %} 7 | {%- block sidebarrel %} 8 | {% endblock %} 9 | 10 | {%- block content %} 11 | {{ navBar() }} 12 |
13 | {% block body %}{% endblock %} 14 |
15 | 16 | 17 | Fork me on GitHub 21 | 22 | {%- endblock %} 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # simple makefile to simplify repetetive build env management tasks under posix 2 | 3 | PYTHON ?= python 4 | PYTESTS ?= pytest 5 | 6 | CTAGS ?= ctags 7 | 8 | all: clean inplace test 9 | 10 | clean-pyc: 11 | find . -name "*.pyc" | xargs rm -f 12 | find . -name "__pycache__" | xargs rm -rf 13 | 14 | clean-build: 15 | rm -rf build 16 | 17 | clean-ctags: 18 | rm -f tags 19 | 20 | clean: clean-build clean-pyc clean-ctags 21 | 22 | in: inplace # just a shortcut 23 | inplace: 24 | $(PYTHON) setup.py build_ext -i 25 | 26 | test-code: 27 | rm -rf coverage .coverage 28 | $(PYTESTS) qndiag --cov=qndiag 29 | 30 | test-doc: 31 | $(PYTESTS) $(shell find doc -name '*.rst' | sort) 32 | 33 | test: test-code test-manifest 34 | 35 | trailing-spaces: 36 | find . 
-name "*.py" | xargs perl -pi -e 's/[ \t]*$$//' 37 | 38 | ctags: 39 | # make tags for symbol based navigation in emacs and vim 40 | # Install with: sudo apt-get install exuberant-ctags 41 | $(CTAGS) -R * 42 | 43 | .PHONY : doc-plot 44 | doc-plot: 45 | make -C doc html 46 | 47 | .PHONY : doc 48 | doc: 49 | make -C doc html-noplot 50 | 51 | test-manifest: 52 | check-manifest --ignore doc,qndiag/*/tests,matlab_octave; 53 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: 9 | - '*' 10 | pull_request: 11 | branches: 12 | - '*' 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [3.7, 3.9] 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m pip install flake8 pytest 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Lint with flake8 34 | run: | 35 | flake8 qndiag 36 | flake8 examples 37 | - name: Test with pytest 38 | run: | 39 | pip install -e . 40 | pytest 41 | -------------------------------------------------------------------------------- /qndiag/tests/test_qndiag.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_allclose 3 | 4 | import pytest 5 | 6 | from qndiag import qndiag 7 | 8 | 9 | @pytest.mark.parametrize('weights', [None, True]) 10 | @pytest.mark.parametrize('ortho', [False, True]) 11 | def test_qndiag(weights, ortho): 12 | n, p = 10, 3 13 | rng = np.random.RandomState(42) 14 | diagonals = rng.uniform(size=(n, p)) 15 | A = rng.randn(p, p) # mixing matrix 16 | if ortho: 17 | Ua, _, Va = np.linalg.svd(A, full_matrices=False) 18 | A = Ua.dot(Va) 19 | C = np.array([A.dot(d[:, None] * A.T) for d in diagonals]) # dataset 20 | if weights: 21 | weights = rng.rand(n) 22 | B, _ = qndiag(C, weights=weights, 23 | B0=np.eye(p), ortho=ortho) # use the algorithm 24 | BA = np.abs(B.dot(A)) # BA Should be a permutation + scale matrix 25 | if not ortho: 26 | BA /= np.max(BA, axis=1, keepdims=True) 27 | BA[np.abs(BA) < 1e-6] = 0. 
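# BA should now be a permutation matrix: one entry equal to 1 in each row, the rest 0, so reordering its rows should recover the identity, which is what the assertion below checks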
28 | assert_allclose(BA[np.lexsort(BA)], np.eye(p)) 29 | 30 | 31 | def test_errors(): 32 | n, p = 10, 2 33 | rng = np.random.RandomState(42) 34 | with pytest.raises(ValueError, match='3 dimensions'): 35 | x = rng.randn(n, p) 36 | qndiag(x) 37 | with pytest.raises(ValueError, match='last two dimensions'): 38 | x = rng.randn(n, p, p + 1) 39 | qndiag(x) 40 | with pytest.raises(ValueError, match='only symmetric'): 41 | x = rng.randn(n, p, p) 42 | qndiag(x) 43 | with pytest.raises(ValueError, match='positive'): 44 | x = rng.randn(n, p, p) 45 | x += x.swapaxes(1, 2) 46 | x[0] = np.array([[0, 1], [1, 0]]) 47 | qndiag(x) 48 | -------------------------------------------------------------------------------- /examples/plot_tutorial.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple tutorial on joint diagonalization 3 | ========================================== 4 | 5 | We generate some independent signals with different powers. 6 | The signals are then mixed, and their covariances are computed. 7 | Joint diagonalization recovers the mixing matrix. 8 | """ 9 | 10 | # Authors: Pierre Ablin 11 | # 12 | # License: MIT 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | from qndiag import qndiag 18 | 19 | 20 | rng = np.random.RandomState(0) 21 | 22 | ############################################################################### 23 | # We take 10 different bins, and 5 sources. We generate random powers for 24 | # each source and bin 25 | n_bins = 10 26 | n_sources = 5 27 | powers = rng.rand(n_bins, n_sources) 28 | 29 | ############################################################################### 30 | # Next, we generate a random minxing matrix A, and for each bin, we generate 31 | # sources s with the powers above, and observe the signals x = A.dot(s). 32 | # We then store the covariances of the signals 33 | 34 | n_samples = 100 35 | A = rng.randn(n_sources, n_sources) 36 | covariances = [] 37 | for power in powers: 38 | s = power[:, None] * rng.randn(n_sources, n_samples) 39 | x = np.dot(A, s) 40 | covariances.append(np.dot(x, x.T) / n_samples) 41 | 42 | covariances = np.array(covariances) 43 | ############################################################################### 44 | # We now use qndiag on 'covariances' to recover the unmixing matrix, i.e the 45 | # inverse of A 46 | 47 | B, _ = qndiag(covariances) 48 | 49 | unmixing_mixing = np.dot(B, A) 50 | plt.matshow(unmixing_mixing) # Should be ~ a permutation + scale matrix 51 | plt.show() 52 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | QNDIAG 2 | ====== 3 | 4 | This is a library to run the QNDIAG algorithm [1]. 5 | This algorithm exploits a state-of-the-art quasi-Newton strategy for 6 | approximate joint diagonalization of a list of matrices. 7 | 8 | Installation 9 | ------------ 10 | 11 | To install qndiag:: 12 | 13 | $ pip install qndiag 14 | 15 | If you do not have admin privileges on the computer, use the ``--user`` flag 16 | with `pip`. To upgrade, use the ``--upgrade`` flag provided by `pip`. 17 | 18 | To check if everything worked fine, you can do:: 19 | 20 | $ python -c 'import qndiag' 21 | 22 | and it should not give any error message. 23 | 24 | Quickstart 25 | ---------- 26 | 27 | The easiest way to get started is to copy the following lines of code 28 | in your script: 29 | 30 | .. 
code:: python 31 | 32 | >>> import numpy as np 33 | >>> from qndiag import qndiag 34 | >>> n, p = 10, 3 35 | >>> diagonals = np.random.uniform(size=(n, p)) 36 | >>> A = np.random.randn(p, p) # mixing matrix 37 | >>> C = np.array([A.dot(d[:, None] * A.T) for d in diagonals]) # dataset 38 | >>> B, _ = qndiag(C) 39 | >>> print(B.dot(A)) # Should be a permutation + scale matrix # doctest:+ELLIPSIS 40 | 41 | Bug reports 42 | ----------- 43 | 44 | Use the `github issue tracker `_ to report bugs. 45 | 46 | Cite 47 | ---- 48 | 49 | [1] P. Ablin, J.F. Cardoso and A. Gramfort. Beyond Pham's algorithm 50 | for joint diagonalization. Proc. ESANN 2019. 51 | https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2019-119.pdf 52 | https://hal.archives-ouvertes.fr/hal-01936887v1 53 | https://arxiv.org/abs/1811.11433 54 | 55 | API 56 | --- 57 | 58 | .. toctree:: 59 | :maxdepth: 1 60 | 61 | api.rst 62 | whats_new.rst 63 | -------------------------------------------------------------------------------- /.github/workflows/deploy_ghpages.yml: -------------------------------------------------------------------------------- 1 | name: Deploy GitHub pages 2 | 3 | on: [push, pull_request] 4 | 5 | 6 | jobs: 7 | build_docs: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Generate HTML docs 12 | uses: ammaraskar/sphinx-action@master 13 | with: 14 | docs-folder: "doc/" 15 | pre-build-command: | 16 | apt-get update 17 | pip install -e . 18 | pip install -r doc/doc-requirements.txt 19 | - name: Upload generated HTML as artifact 20 | uses: actions/upload-artifact@v2 21 | with: 22 | name: DocHTML 23 | path: doc/_build/html/ 24 | 25 | deploy_docs: 26 | if: github.ref == 'refs/heads/master' 27 | needs: 28 | build_docs 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: actions/checkout@v2 32 | - name: Download artifacts 33 | uses: actions/download-artifact@v2 34 | with: 35 | name: DocHTML 36 | path: doc/_build/html/ 37 | - name: Commit to documentation branch 38 | run: | 39 | git clone --no-checkout --depth 1 https://github.com/${{ github.repository_owner }}/qndiag.git --branch gh-pages --single-branch gh-pages 40 | cp -r doc/_build/html/* gh-pages/ 41 | cd gh-pages 42 | touch .nojekyll 43 | git config --local user.email "qndiag@github.com" 44 | git config --local user.name "qndiag GitHub Action" 45 | git add . 46 | git commit -m "Update documentation" -a || true 47 | - name: Push changes 48 | uses: ad-m/github-push-action@v0.6.0 49 | with: 50 | branch: gh-pages 51 | directory: gh-pages 52 | github_token: ${{ secrets.GITHUB_TOKEN }} 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | _build 3 | auto_examples 4 | gen_modules 5 | generated 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import os 4 | from setuptools import setup, find_packages 5 | 6 | descr = """Joint diagonalization in Python""" 7 | 8 | version = None 9 | with open(os.path.join('qndiag', '__init__.py'), 'r') as fid: 10 | for line in (line.strip() for line in fid): 11 | if line.startswith('__version__'): 12 | version = line.split('=')[1].strip().strip('\'') 13 | break 14 | if version is None: 15 | raise RuntimeError('Could not determine version') 16 | 17 | 18 | DISTNAME = 'qndiag' 19 | DESCRIPTION = descr 20 | MAINTAINER = 'Pierre Ablin' 21 | MAINTAINER_EMAIL = 'pierreablin@gmail.com' 22 | LICENSE = 'MIT' 23 | DOWNLOAD_URL = 'https://github.com/pierreablin/qndiag.git' 24 | VERSION = version 25 | URL = 'https://github.com/pierreablin/qndiag' 26 | 27 | if __name__ == "__main__": 28 | setup(name=DISTNAME, 29 | maintainer=MAINTAINER, 30 | maintainer_email=MAINTAINER_EMAIL, 31 | description=DESCRIPTION, 32 | license=LICENSE, 33 | version=VERSION, 34 | url=URL, 35 | download_url=DOWNLOAD_URL, 36 | long_description=open('README.md').read(), 37 | install_requires=[ 38 | 'numpy >=1.16.0', 39 | 'scipy >=1.2.0', 40 | ], 41 | classifiers=[ 42 | 'Intended Audience :: Science/Research', 43 | 'Intended Audience :: Developers', 44 | 'License :: OSI Approved', 45 | 'Programming Language :: Python', 46 | 'Topic :: Software Development', 47 | 'Topic :: Scientific/Engineering', 48 | 'Operating System :: Microsoft :: Windows', 49 | 'Operating System :: POSIX', 50 | 'Operating System :: Unix', 51 | 'Operating System :: MacOS', 52 | ], 53 | platforms='any', 54 | packages=find_packages(), 55 | ) 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quasi-Newton algorithm for joint-diagonalization 2 | 3 | 4 | ![Build](https://github.com/pierreablin/qndiag/workflows/tests/badge.svg) 5 | ![Codecov](https://codecov.io/gh/pierreablin/qndiag/branch/master/graph/badge.svg) 6 | 7 | ## Doc and website 8 | 9 | See here for the documentation and examples: https://pierreablin.github.io/qndiag/ 10 | 11 | ## Summary 12 | 13 | This Python package contains code for fast 
joint-diagonalization of a set of 14 | positive definite symmetric matrices. The main function is `qndiag`, 15 | which takes as input a set of matrices of size `(p, p)`, stored as a `(n, p, p)` 16 | array, `C`. It outputs a `(p, p)` array, `B`, such that the matrices 17 | `B @ C[i] @ B.T` (python), i.e. `B * C(i,:,:) * B'` (matlab/octave) 18 | are as diagonal as possible. 19 | 20 | ## Installation of Python package 21 | 22 | To install the package, simply do: 23 | 24 | `$ pip install qndiag` 25 | 26 | You can also simply clone it, and then do: 27 | 28 | `$ pip install -e .` 29 | 30 | To check that everything worked, the command 31 | 32 | `$ python -c 'import qndiag'` 33 | 34 | should not return any error. 35 | 36 | ## Use with Python 37 | 38 | Here is a toy example (also available at `examples/toy_example.py`) 39 | 40 | ```python 41 | import numpy as np 42 | from qndiag import qndiag 43 | 44 | n, p = 10, 3 45 | diagonals = np.random.uniform(size=(n, p)) 46 | A = np.random.randn(p, p) # mixing matrix 47 | C = np.array([A.dot(d[:, None] * A.T) for d in diagonals]) # dataset 48 | 49 | 50 | B, _ = qndiag(C) # use the algorithm 51 | 52 | print(B.dot(A)) # Should be a permutation + scale matrix 53 | ``` 54 | 55 | ## Use with Matlab or Octave 56 | 57 | See `qndiag.m` and `toy_example.m` in the folder `matlab_octave`. 58 | 59 | ## Cite 60 | 61 | If you use this code please cite: 62 | 63 | P. Ablin, J.F. Cardoso and A. Gramfort. Beyond Pham’s algorithm 64 | for joint diagonalization. Proc. ESANN 2019. 65 | https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2019-119.pdf 66 | https://hal.archives-ouvertes.fr/hal-01936887v1 67 | https://arxiv.org/abs/1811.11433 68 | -------------------------------------------------------------------------------- /examples/article_figure.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replicate figure from paper 3 | =========================== 4 | 5 | """ 6 | 7 | # Authors: Pierre Ablin 8 | # 9 | # License: MIT 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from matplotlib.ticker import LogLocator 14 | 15 | import mne 16 | from mne.datasets import sample 17 | 18 | from qndiag import qndiag, ajd_pham, gradient, transform_set 19 | 20 | rng = np.random.RandomState(0) 21 | 22 | fontsize = 5 23 | params = { 24 | 'axes.titlesize': 10, 25 | 'axes.labelsize': 10, 26 | 'font.size': 7, 27 | 'legend.fontsize': 8, 28 | 'xtick.labelsize': fontsize, 29 | 'ytick.labelsize': fontsize, 30 | 'text.usetex': True, 31 | 'ytick.major.pad': '0', 32 | 'ytick.minor.pad': '0'} 33 | plt.rcParams.update(params) 34 | 35 | 36 | def loss(D): 37 | n, p, _ = D.shape 38 | output = 0 39 | for i in range(n): 40 | Di = D[i] 41 | output += np.sum(np.log(np.diagonal(Di))) - np.linalg.slogdet(Di)[1] 42 | return output / (2 * n) 43 | 44 | 45 | n, p = 100, 40 46 | 47 | 48 | f, axes = plt.subplots(2, 3, figsize=(7, 3.04), sharex='col') 49 | expe_str = ['(a)', '(b)', '(c)'] 50 | axes = axes.T 51 | for j, (sigma, axe) in enumerate(zip([0., 0.1, 0], axes)): 52 | if j != 2: # Synthetic data 53 | # Generate diagonal matrices 54 | D = rng.uniform(size=(n, p)) 55 | # Generate a random mixing matrix 56 | A = rng.randn(p, p) 57 | C = np.zeros((n, p, p)) 58 | # Generate the dataset 59 | for i in range(n): 60 | R = rng.randn(p, p) 61 | C[i] = np.dot(A, D[i, :, None] * A.T) + sigma ** 2 * R.dot(R.T) 62 | else: # Real data 63 | data_path = sample.data_path() 64 | raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' 65 | raw = 
mne.io.read_raw_fif(raw_fname, preload=True) 66 | X = raw.get_data() 67 | # Reduce dimension of X by PCA: 68 | U, D, V = np.linalg.svd(X, full_matrices=False) 69 | X = V[-p:, :] 70 | C = np.array([np.dot(x, x.T) for x in np.split(X, n, axis=1)]) 71 | 72 | for algo in [qndiag, ajd_pham]: 73 | _, infos = algo(C, return_B_list=True) 74 | # For Pham, compute metrics after the algorithm is run 75 | B_list = infos['B_list'] 76 | infos['gradient_list'] =\ 77 | [np.linalg.norm(gradient(transform_set(B, C))) for B in B_list] 78 | infos['loss_list'] =\ 79 | [loss(transform_set(B, C)) for B in B_list] 80 | for i, (to_plot, name, ax) in enumerate( 81 | zip(['loss_list', 'gradient_list'], 82 | ['Objective function', 83 | 'Gradient norm'], 84 | axe)): 85 | ax.loglog(infos['t_list'], infos[to_plot], linewidth=2) 86 | if i == 1 and j == 1: 87 | ax.set_xlabel('Time (sec.)') 88 | if j == 0: 89 | ax.set_ylabel(name) 90 | if i == 1: 91 | art = ax.annotate(expe_str[j], (0, 0), (50, -30), 92 | xycoords='axes fraction', 93 | textcoords='offset points', va='top') 94 | ax.grid(True) 95 | ax.yaxis.set_major_locator(LogLocator(numticks=4, subs=(1.,))) 96 | ax.minorticks_off() 97 | 98 | lgd = plt.figlegend(ax.lines, ['Quasi-Newton (proposed)', 'Pham 01'], 99 | loc=(0.32, .9), ncol=2, labelspacing=0.) 100 | plt.savefig('expe.pdf', bbox_extra_artists=(art, lgd), bbox_inches='tight') 101 | -------------------------------------------------------------------------------- /qndiag/pham.py: -------------------------------------------------------------------------------- 1 | # Authors: Pierre Ablin 2 | # 3 | # License: MIT 4 | from time import time 5 | 6 | import numpy as np 7 | from .qndiag import transform_set 8 | 9 | 10 | def ajd_pham(X, tol=1e-14, max_iter=1000, return_B_list=False, verbose=False): 11 | """ 12 | This function comes from mne-python/decoding/csp.py 13 | Approximate joint diagonalization based on Pham's algorithm. 14 | 15 | This is a direct implementation of Pham's AJD algorithm [1]. 16 | 17 | Parameters 18 | ---------- 19 | X : ndarray, shape (n_epochs, n_channels, n_channels) 20 | A set of covariance matrices to diagonalize. 21 | tol : float, defaults to 1e-14 22 | The tolerance for the stopping criterion. 23 | max_iter : int, defaults to 1000 24 | The maximum number of iterations to reach convergence. 25 | 26 | Returns 27 | ------- 28 | V : ndarray, shape (n_channels, n_channels) 29 | The diagonalizer. 30 | infos : dict 31 | Dictionary of monitoring information: 't_list' holds the elapsed times and, if return_B_list is True, 'B_list' holds the list of iterates. 32 | 33 | References 34 | ---------- 35 | .. [1] Pham, Dinh Tuan. "Joint approximate diagonalization of positive 36 | definite Hermitian matrices." SIAM Journal on Matrix Analysis and 37 | Applications 22, no. 4 (2001): 1136-1152.
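Examples -------- Illustrative usage (a sketch mirroring qndiag/tests/test_pham.py): build a set of matrices A diag(d) A.T and check that the returned V unmixes A. >>> import numpy as np >>> from qndiag import ajd_pham >>> rng = np.random.RandomState(42) >>> A = rng.randn(3, 3) # mixing matrix >>> C = np.array([A.dot(np.diag(d)).dot(A.T) for d in rng.uniform(size=(10, 3))]) >>> V, infos = ajd_pham(C) # V.dot(A) should be ~ a permutation + scale matrix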
38 | 39 | """ 40 | # Adapted from http://github.com/alexandrebarachant/pyRiemann 41 | t0 = time() 42 | n_epochs = X.shape[0] 43 | C_mean = np.mean(X, axis=0) 44 | d, p = np.linalg.eigh(C_mean) 45 | V = p.T / np.sqrt(d[:, None]) 46 | X = transform_set(V, X) 47 | # Reshape input matrix 48 | A = np.concatenate(X, axis=0).T 49 | # Init variables 50 | n_times, n_m = A.shape 51 | epsilon = n_times * (n_times - 1) * tol 52 | t_list = [] 53 | if return_B_list: 54 | B_list = [] 55 | for it in range(max_iter): 56 | t_list.append(time() - t0) 57 | if return_B_list: 58 | B_list.append(V.copy()) 59 | decr = 0 60 | for ii in range(1, n_times): 61 | for jj in range(ii): 62 | Ii = np.arange(ii, n_m, n_times) 63 | Ij = np.arange(jj, n_m, n_times) 64 | 65 | c1 = A[ii, Ii] 66 | c2 = A[jj, Ij] 67 | c3 = A[ii, Ij] 68 | 69 | g12 = np.mean(c3 / c1) 70 | g21 = np.mean(c3 / c2) 71 | 72 | omega21 = np.mean(c1 / c2) 73 | omega12 = np.mean(c2 / c1) 74 | omega = np.sqrt(omega12 * omega21) 75 | 76 | tmp = np.sqrt(omega21 / omega12) 77 | tmp1 = (tmp * g12 + g21) / (omega + 1) 78 | tmp2 = (tmp * g12 - g21) / max(omega - 1, 1e-9) 79 | 80 | h12 = tmp1 + tmp2 81 | h21 = np.conj((tmp1 - tmp2) / tmp) 82 | 83 | decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0 84 | 85 | tmp = 1 + 1.j * 0.5 * np.imag(h12 * h21) 86 | tmp = np.real(tmp + np.sqrt(tmp ** 2 - h12 * h21)) 87 | tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) 88 | 89 | A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :]) 90 | tmp = np.c_[A[:, Ii], A[:, Ij]] 91 | tmp = np.reshape(tmp, (n_times * n_epochs, 2), order='F') 92 | tmp = np.dot(tmp, tau.T) 93 | 94 | tmp = np.reshape(tmp, (n_times, n_epochs * 2), order='F') 95 | A[:, Ii] = tmp[:, :n_epochs] 96 | A[:, Ij] = tmp[:, n_epochs:] 97 | V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :]) 98 | if verbose: 99 | print('Iteration %d, decr : %.2e' % (it, decr)) 100 | if decr < epsilon: 101 | break 102 | 103 | infos = {'t_list': t_list} 104 | if return_B_list: 105 | infos['B_list'] = B_list 106 | return V, infos 107 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | GITHUB_PAGES_BRANCH = gh-pages 11 | OUTPUTDIR = _build/html 12 | 13 | # User-friendly check for sphinx-build 14 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 15 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 16 | endif 17 | 18 | # Internal variables. 19 | PAPEROPT_a4 = -D latex_paper_size=a4 20 | PAPEROPT_letter = -D latex_paper_size=letter 21 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 22 | # the i18n builder cannot share the environment and doctrees with the others 23 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
24 | 25 | .PHONY: help 26 | help: 27 | @echo "Please use \`make ' where is one of" 28 | @echo " html-noplot to make standalone HTML files, without plotting anything" 29 | @echo " html to make standalone HTML files" 30 | @echo " dirhtml to make HTML files named index.html in directories" 31 | @echo " singlehtml to make a single large HTML file" 32 | @echo " pickle to make pickle files" 33 | @echo " htmlhelp to make HTML files and a HTML help project" 34 | @echo " qthelp to make HTML files and a qthelp project" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " changes to make an overview of all changed/added/deprecated items" 38 | @echo " linkcheck to check all external links for integrity" 39 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 40 | @echo " coverage to run coverage check of the documentation (if enabled)" 41 | @echo " install to make the html and push it online" 42 | 43 | .PHONY: clean 44 | 45 | clean: 46 | rm -rf $(BUILDDIR)/* 47 | rm -rf auto_examples/ 48 | rm -rf generated/* 49 | rm -rf modules/* 50 | 51 | html-noplot: 52 | $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 53 | @echo 54 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 55 | 56 | .PHONY: html 57 | html: 58 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 61 | 62 | .PHONY: dirhtml 63 | dirhtml: 64 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 65 | @echo 66 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 67 | 68 | .PHONY: singlehtml 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | .PHONY: pickle 75 | pickle: 76 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 77 | @echo 78 | @echo "Build finished; now you can process the pickle files." 79 | 80 | .PHONY: htmlhelp 81 | htmlhelp: 82 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 83 | @echo 84 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 85 | ".hhp project file in $(BUILDDIR)/htmlhelp." 86 | 87 | .PHONY: qthelp 88 | qthelp: 89 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 90 | @echo 91 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 92 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 93 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/smica.qhcp" 94 | @echo "To view the help file:" 95 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/smica.qhc" 96 | 97 | .PHONY: latex 98 | latex: 99 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 100 | @echo 101 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 102 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 103 | "(use \`make latexpdf' here to do that automatically)." 104 | 105 | .PHONY: latexpdf 106 | latexpdf: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo "Running LaTeX files through pdflatex..." 109 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 110 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 
111 | 112 | .PHONY: changes 113 | changes: 114 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 115 | @echo 116 | @echo "The overview file is in $(BUILDDIR)/changes." 117 | 118 | .PHONY: linkcheck 119 | linkcheck: 120 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 121 | @echo 122 | @echo "Link check complete; look for any errors in the above output " \ 123 | "or in $(BUILDDIR)/linkcheck/output.txt." 124 | 125 | .PHONY: doctest 126 | doctest: 127 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 128 | @echo "Testing of doctests in the sources finished, look at the " \ 129 | "results in $(BUILDDIR)/doctest/output.txt." 130 | 131 | .PHONY: coverage 132 | coverage: 133 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 134 | @echo "Testing of coverage in the sources finished, look at the " \ 135 | "results in $(BUILDDIR)/coverage/python.txt." 136 | 137 | install: 138 | touch $(OUTPUTDIR)/.nojekyll 139 | ghp-import -m "Generate Pelican site [ci skip]" -b $(GITHUB_PAGES_BRANCH) $(OUTPUTDIR) 140 | git push origin $(GITHUB_PAGES_BRANCH) 141 | -------------------------------------------------------------------------------- /matlab_octave/qndiag.m: -------------------------------------------------------------------------------- 1 | function [D, B, infos] = qndiag(C, varargin) 2 | % Joint diagonalization of matrices using the quasi-Newton method 3 | % 4 | % The algorithm is detailed in: 5 | % 6 | % P. Ablin, J.F. Cardoso and A. Gramfort. Beyond Pham’s algorithm 7 | % for joint diagonalization. Proc. ESANN 2019. 8 | % https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2019-119.pdf 9 | % https://hal.archives-ouvertes.fr/hal-01936887v1 10 | % https://arxiv.org/abs/1811.11433 11 | % 12 | % The function takes as input a set of matrices of size `(p, p)`, stored as 13 | % a `(n, p, p)` array, `C`. It outputs a `(p, p)` array, `B`, such that the 14 | % matrices `B * C(i,:,:) * B'` are as diagonal as possible. 15 | % 16 | % There are several optional parameters which can be provided in the 17 | % varargin variable. 18 | % 19 | % Optional parameters: 20 | % -------------------- 21 | % 'B0' Initial point for the algorithm. 22 | % If absent, a whitener is used. 23 | % 'weights' Weights for each matrix in the loss: 24 | % L = sum(weights * KL(C, C')). 25 | % No weighting (weights = 1) by default. 26 | % 'maxiter' (int) Maximum number of iterations to perform. 27 | % Default : 1000 28 | % 29 | % 'tol' (float) A positive scalar giving the tolerance at 30 | % which the algorithm is considered to have converged. 31 | % The algorithm stops when |gradient| < tol. 32 | % Default : 1e-6 33 | % 34 | % lambda_min (float) A positive regularization scalar. Each 35 | % eigenvalue of the Hessian approximation below 36 | % lambda_min is set to lambda_min. 37 | % 38 | % max_ls_tries (int), Maximum number of line-search tries to 39 | % perform. 40 | % 41 | % return_B_list (bool) Chooses whether or not to return the list 42 | % of iterates. 43 | % 44 | % verbose (bool) Prints informations about the state of the 45 | % algorithm if True. 46 | % 47 | % Returns 48 | % ------- 49 | % D : Set of matrices jointly diagonalized 50 | % B : Estimated joint diagonalizer matrix. 51 | % infos : structure containing monitoring informations, containing the times, 52 | % gradient norms and objective values. 
53 | % 54 | % Example: 55 | % -------- 56 | % 57 | % [D, B] = qndiag(C, 'maxiter', 100, 'tol', 1e-5) 58 | % 59 | % Authors: Pierre Ablin 60 | % Alexandre Gramfort 61 | % 62 | % License: MIT 63 | 64 | % First tests 65 | 66 | if nargin == 0, 67 | error('No signal provided'); 68 | end 69 | 70 | if length(size(C)) ~= 3, 71 | error('Input C should be 3 dimensional'); 72 | end 73 | 74 | if ~isa (C, 'double'), 75 | fprintf ('Converting input data to double...'); 76 | X = double(X); 77 | end 78 | 79 | % Default parameters 80 | 81 | C_mean = squeeze(mean(C, 1)); 82 | [p, d] = eigs(C_mean, size(C_mean, 1)); 83 | p = fliplr(p); 84 | d = flip(diag(d)); 85 | B = p' ./ repmat(sqrt(d), 1, size(p, 1)); 86 | 87 | max_iter = 1000; 88 | tol = 1e-6; 89 | lambda_min = 1e-4; 90 | max_ls_tries = 10; 91 | return_B_list = false; 92 | verbose = false; 93 | weights = []; 94 | 95 | % Read varargin 96 | 97 | if mod(length(varargin), 2) == 1, 98 | error('There should be an even number of optional parameters'); 99 | end 100 | 101 | for i = 1:2:length(varargin) 102 | param = lower(varargin{i}); 103 | value = varargin{i + 1}; 104 | switch param 105 | case 'B0' 106 | B = value; 107 | case 'max_iter' 108 | max_iter = value; 109 | case 'tol' 110 | tol = value; 111 | case 'weights' 112 | weights = value / mean(value(:)); 113 | case 'lambda_min' 114 | lambda_min = value; 115 | case 'max_ls_tries' 116 | max_ls_tries = value; 117 | case 'return_B_list' 118 | return_B_list = value; 119 | case 'verbose' 120 | verbose = value; 121 | otherwise 122 | error(['Parameter ''' param ''' unknown']) 123 | end 124 | end 125 | 126 | [n_samples, n_features, ~] = size(C); 127 | 128 | D = transform_set(B, C, false); 129 | current_loss = NaN; 130 | 131 | % Monitoring 132 | if return_B_list 133 | B_list = [] 134 | end 135 | 136 | t_list = []; 137 | gradient_list = []; 138 | loss_list = []; 139 | 140 | if verbose 141 | print('Running quasi-Newton for joint diagonalization'); 142 | print('iter | obj | gradient'); 143 | end 144 | 145 | for t=1:max_iter 146 | if return_B_list 147 | B_list(k) = B; 148 | end 149 | 150 | diagonals = zeros(n_samples, n_features); 151 | for k=1:n_samples 152 | diagonals(k, :) = diag(squeeze(D(k, :, :))); 153 | end 154 | 155 | % Gradient 156 | if isempty(weights) 157 | G = squeeze(mean(bsxfun(@rdivide, D, ... 158 | reshape(diagonals, n_samples, n_features, 1)), ... 159 | 1)) - eye(n_features); 160 | else 161 | G = squeeze(mean(... 162 | bsxfun(@times, ... 163 | reshape(weights, n_samples, 1, 1), ... 164 | bsxfun(@rdivide, D, ... 165 | reshape(diagonals, n_samples, n_features, 1))), ... 166 | 1)) - eye(n_features); 167 | end 168 | g_norm = norm(G); 169 | if g_norm < tol 170 | break 171 | end 172 | 173 | % Hessian coefficients 174 | if isempty(weights) 175 | h = mean(bsxfun(@rdivide, ... 176 | reshape(diagonals, n_samples, 1, n_features), ... 177 | reshape(diagonals, n_samples, n_features, 1)), 1); 178 | else 179 | h = mean(bsxfun(@times, ... 180 | reshape(weights, n_samples, 1, 1), ... 181 | bsxfun(@rdivide, ... 182 | reshape(diagonals, n_samples, 1, n_features), ... 183 | reshape(diagonals, n_samples, n_features, 1))), ... 184 | 1); 185 | end 186 | h = squeeze(h); 187 | 188 | % Quasi-Newton's direction 189 | dt = h .* h' - 1.; 190 | dt(dt < lambda_min) = lambda_min; % Regularize 191 | direction = -(G .* h' - G') ./ dt; 192 | 193 | % Line search 194 | [success, new_D, new_B, new_loss, direction] = ... 
195 | linesearch(D, B, direction, current_loss, max_ls_tries, weights); 196 | D = new_D; 197 | B = new_B; 198 | current_loss = new_loss; 199 | 200 | % Monitoring 201 | gradient_list(t) = g_norm; 202 | loss_list(t) = current_loss; 203 | if verbose 204 | print(sprintf('%d - %.2e - %.2e', t, current_loss, g_norm)) 205 | end 206 | end 207 | 208 | infos = struct(); 209 | infos.t_list = t_list; 210 | infos.gradient_list = gradient_list; 211 | infos.loss_list = loss_list; 212 | 213 | if return_B_list 214 | infos.B_list = B_list 215 | end 216 | 217 | end 218 | 219 | function [op] = transform_set(M, D, diag_only) 220 | [n, p, ~] = size(D); 221 | if ~diag_only 222 | op = zeros(n, p, p); 223 | for k=1:n 224 | op(k, :, :) = M * squeeze(D(k, :, :)) * M'; 225 | end 226 | else 227 | op = zeros(n, p); 228 | for k=1:n 229 | op(k, :) = sum(M .* (squeeze(D(k, :, :)) * M'), 1); 230 | end 231 | end 232 | end 233 | 234 | function [v] = slogdet(A) 235 | v = log(abs(det(A))); 236 | end 237 | 238 | function [out] = loss(B, D, is_diag, weights) 239 | [n, p, ~] = size(D); 240 | if ~is_diag 241 | diagonals = zeros(n, p); 242 | for k=1:n 243 | diagonals(k, :) = diag(squeeze(D(k, :, :))); 244 | end 245 | else 246 | diagonals = D; 247 | end 248 | logdet = -slogdet(B); 249 | if ~isempty(weights) 250 | diagonals = bsxfun(@times, diagonals, reshape(weights, n, 1)); 251 | end 252 | out = logdet + 0.5 * sum(log(diagonals(:))) / n; 253 | end 254 | 255 | function [success, new_D, new_B, new_loss, delta] = linesearch(D, B, direction, current_loss, n_ls_tries, weights) 256 | [n, p, ~] = size(D); 257 | step = 1.; 258 | if current_loss == NaN 259 | current_loss = loss(B, D, false); 260 | end 261 | success = false; 262 | for n=1:n_ls_tries 263 | M = eye(p) + step * direction; 264 | new_D = transform_set(M, D, true); 265 | new_B = M * B; 266 | new_loss = loss(new_B, new_D, true, weights); 267 | 268 | if new_loss < current_loss 269 | success = true; 270 | break 271 | end 272 | step = step / 2; 273 | end 274 | new_D = transform_set(M, D, false); 275 | delta = step * direction; 276 | end 277 | -------------------------------------------------------------------------------- /qndiag/qndiag.py: -------------------------------------------------------------------------------- 1 | # Authors: Pierre Ablin 2 | # 3 | # License: MIT 4 | 5 | from time import time 6 | 7 | import numpy as np 8 | from scipy.linalg import expm 9 | 10 | 11 | def qndiag(C, B0=None, weights=None, max_iter=1000, tol=1e-6, 12 | lambda_min=1e-4, max_ls_tries=10, diag_only=False, 13 | return_B_list=False, check_sympos=True, verbose=False, ortho=False): 14 | """Joint diagonalization of matrices using the quasi-Newton method 15 | 16 | Parameters 17 | ---------- 18 | C : array-like, shape (n_samples, n_features, n_features) 19 | Set of matrices to be jointly diagonalized. C[0] is the first matrix, 20 | etc... 21 | 22 | B0 : None | array-like, shape (n_features, n_features) 23 | Initial point for the algorithm. If None, a whitener is used. 24 | 25 | weights : None | array-like, shape (n_samples,) 26 | Weights for each matrix in the loss: 27 | L = sum(weights * KL(C, C')) / sum(weights). 28 | No weighting (weights = 1) by default. 29 | 30 | max_iter : int, optional 31 | Maximum number of iterations to perform. 32 | 33 | tol : float, optional 34 | A positive scalar giving the tolerance at which the 35 | algorithm is considered to have converged. The algorithm stops when 36 | `|gradient| < tol`. 37 | 38 | lambda_min : float, optional 39 | A positive regularization scalar. 
Each eigenvalue of the Hessian 40 | approximation below lambda_min is set to lambda_min. 41 | 42 | max_ls_tries : int, optional 43 | Maximum number of line-search tries to perform. 44 | 45 | diag_only : bool, optional 46 | If true, the line search is done by computing only the diagonals of the 47 | dataset. The dataset is then computed after the line search. 48 | Taking diag_only = True might be faster than diag_only=False 49 | when the matrices are large (n_features > 200) 50 | 51 | return_B_list : bool, optional 52 | Chooses whether or not to return the list of iterates. 53 | 54 | check_sympos : bool, optional 55 | Chooses whether to check that the provided dataset contains only 56 | symmetric positive definite matrices, as should be. 57 | 58 | verbose : bool, optional 59 | Prints informations about the state of the algorithm if True. 60 | 61 | ortho : bool, optional 62 | If true, performs joint diagonlization under orthogonal constrains. 63 | 64 | Returns 65 | ------- 66 | D : array-like, shape (n_samples, n_features, n_features) 67 | Set of matrices jointly diagonalized 68 | 69 | B : array, shape (n_features, n_features) 70 | Estimated joint diagonalizer matrix. 71 | 72 | infos : dict 73 | Dictionnary of monitoring informations, containing the times, 74 | gradient norms and objective values. 75 | 76 | References 77 | ---------- 78 | P. Ablin, J.F. Cardoso and A. Gramfort. Beyond Pham's algorithm 79 | for joint diagonalization. Proc. ESANN 2019. 80 | https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2019-119.pdf 81 | https://hal.archives-ouvertes.fr/hal-01936887v1 82 | https://arxiv.org/abs/1811.11433 83 | """ 84 | t0 = time() 85 | if not isinstance(C, np.ndarray): 86 | raise TypeError('Input tensor C should be a numpy array.') 87 | if C.ndim != 3: 88 | raise ValueError( 89 | f'Input tensor C should have 3 dimensions (Got {C.ndim})' 90 | ) 91 | if C.shape[1] != C.shape[2]: 92 | raise ValueError('The last two dimensions of C should be the same') 93 | if check_sympos: 94 | if not np.allclose(C, C.swapaxes(-1, -2)): 95 | raise ValueError('C does not contain only symmetric matrices') 96 | try: # Check positivity using Cholesky 97 | np.linalg.cholesky(C) 98 | except np.linalg.LinAlgError: 99 | raise ValueError('C contains some non positive matrices') 100 | 101 | n_samples, n_features, _ = C.shape 102 | if B0 is None: 103 | C_mean = np.mean(C, axis=0) 104 | d, p = np.linalg.eigh(C_mean) 105 | B = p.T / np.sqrt(d[:, None]) 106 | else: 107 | B = B0 108 | if weights is not None: # normalize 109 | weights_ = weights / np.mean(weights) 110 | else: 111 | weights_ = None 112 | 113 | Bu, _, Bv = np.linalg.svd(B, full_matrices=False) 114 | B = Bu.dot(Bv) 115 | D = transform_set(B, C) 116 | current_loss = None 117 | 118 | # Monitoring 119 | if return_B_list: 120 | B_list = [] 121 | t_list = [] 122 | gradient_list = [] 123 | loss_list = [] 124 | if verbose: 125 | print('Running quasi-Newton for joint diagonalization') 126 | print(' | '.join([name.center(8) for name in 127 | ["iter", "obj", "gradient"]])) 128 | 129 | for t in range(max_iter): 130 | if return_B_list: 131 | B_list.append(B.copy()) 132 | t_list.append(time() - t0) 133 | diagonals = np.diagonal(D, axis1=1, axis2=2) 134 | # Gradient 135 | G = np.average(D / diagonals[:, :, None], weights=weights_, 136 | axis=0) - np.eye(n_features) 137 | if ortho: 138 | G = G - G.T 139 | g_norm = np.linalg.norm(G) 140 | if g_norm < tol * np.sqrt(n_features): # rescale by identity 141 | break 142 | 143 | # Hessian coefficients 144 | h = np.average( 145 | 
diagonals[:, None, :] / diagonals[:, :, None], 146 | weights=weights_, axis=0) 147 | if ortho: 148 | det = h + h.T - 2 149 | det[det < lambda_min] = lambda_min # Regularize 150 | direction = -G / det 151 | else: 152 | # Quasi-Newton's direction 153 | det = h * h.T - 1. 154 | det[det < lambda_min] = lambda_min # Regularize 155 | direction = -(G * h.T - G.T) / det 156 | 157 | # Line search 158 | success, new_D, new_B, new_loss, direction =\ 159 | _linesearch(D, B, direction, current_loss, max_ls_tries, diag_only, 160 | weights_, ortho=ortho) 161 | D = new_D 162 | B = new_B 163 | current_loss = new_loss 164 | 165 | # Monitoring 166 | gradient_list.append(g_norm) 167 | loss_list.append(current_loss) 168 | if verbose: 169 | print(' | '.join([("%d" % (t + 1)).rjust(8), 170 | ("%.2e" % current_loss).rjust(8), 171 | ("%.2e" % g_norm).rjust(8)])) 172 | infos = {'t_list': t_list, 'gradient_list': gradient_list, 173 | 'loss_list': loss_list} 174 | if return_B_list: 175 | infos['B_list'] = B_list 176 | return B, infos 177 | 178 | 179 | def transform_set(M, D, diag_only=False): 180 | """Transform a set of matrices 181 | 182 | Returns matrices D' such that 183 | D'[i] = M x D[i] x M.T 184 | 185 | Parameters 186 | ---------- 187 | M : array-like, shape (n_features, n_features) 188 | The transform matrix 189 | 190 | D : array-like, shape (n_samples, n_features, n_features) 191 | The set of covariance matrices 192 | 193 | diag_only : bool, optional 194 | Whether to return the diagonal of the dataset only 195 | 196 | Returns 197 | ------- 198 | op : array-like 199 | Array of shape (n_samples, n_features, n_features) 200 | if diag_only is False, else (n_samples, n_features) 201 | The transformed set of covariances 202 | 203 | """ 204 | n, p, _ = D.shape 205 | if not diag_only: 206 | op = np.zeros((n, p, p)) 207 | for i, d in enumerate(D): 208 | op[i] = M.dot(d.dot(M.T)) 209 | else: 210 | op = np.zeros((n, p)) 211 | for i, d in enumerate(D): 212 | op[i] = np.sum(M * d.dot(M.T), axis=0) 213 | return op 214 | 215 | 216 | def loss(B, D, is_diag=False, weights=None): 217 | n, p = D.shape[:2] 218 | if not is_diag: 219 | diagonals = np.diagonal(D, axis1=1, axis2=2) 220 | else: 221 | diagonals = D 222 | logdet = -np.linalg.slogdet(B)[1] 223 | if weights is None: 224 | return logdet + 0.5 * np.sum(np.log(diagonals)) / n 225 | else: 226 | return logdet + 0.5 * np.sum(weights[:, None] * np.log(diagonals)) / n 227 | 228 | 229 | def gradient(D, weights=None): 230 | n, p, _ = D.shape 231 | diagonals = np.diagonal(D, axis1=1, axis2=2) 232 | grad = np.average(D / diagonals[:, :, None], weights=weights, axis=0) 233 | grad.flat[::p + 1] -= 1 # equivalent to - np.eye(p) 234 | return grad 235 | 236 | 237 | def _linesearch(D, B, direction, current_loss, n_ls_tries, diag_only, 238 | weights, ortho): 239 | n, p, _ = D.shape 240 | step = 1. 241 | if current_loss is None: 242 | current_loss = loss(B, D) 243 | for n in range(n_ls_tries): 244 | if ortho: 245 | M = expm(step * direction) 246 | else: 247 | M = np.eye(p) + step * direction 248 | 249 | new_D = transform_set(M, D, diag_only=diag_only) 250 | new_B = np.dot(M, B) 251 | new_loss = loss(new_B, new_D, diag_only, weights) 252 | if new_loss < current_loss: 253 | success = True 254 | break 255 | step /= 2. 
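# for/else idiom: the else branch just below runs only if the loop above never hit `break`, i.e. when none of the tested step sizes decreased the loss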
256 | else: 257 | success = False 258 | # Compute new value of D if only its diagonal was computed 259 | if diag_only: 260 | new_D = transform_set(M, D, diag_only=False) 261 | return success, new_D, new_B, new_loss, step * direction 262 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # qndiag documentation build configuration file, created by 4 | # sphinx-quickstart on Mon May 23 16:22:52 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | # import os 17 | import sphinx_gallery # noqa 18 | import sphinx_bootstrap_theme 19 | from numpydoc import numpydoc, docscrape # noqa 20 | 21 | # If extensions (or modules to document with autodoc) are in another directory, 22 | # add these directories to sys.path here. If the directory is relative to the 23 | # documentation root, use os.path.abspath to make it absolute, like shown here. 24 | #sys.path.insert(0, os.path.abspath('.')) 25 | 26 | # -- General configuration ------------------------------------------------ 27 | 28 | # If your documentation needs a minimal Sphinx version, state it here. 29 | #needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.autodoc', 36 | 'sphinx.ext.autosummary', 37 | 'sphinx.ext.doctest', 38 | 'sphinx.ext.intersphinx', 39 | 'sphinx.ext.mathjax', 40 | 'sphinx_gallery.gen_gallery', 41 | 'numpydoc', 42 | ] 43 | 44 | # generate autosummary even if no references 45 | autosummary_generate = True 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # The suffix(es) of source filenames. 51 | # You can specify multiple suffix as a list of string: 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = '.rst' 54 | 55 | # The encoding of source files. 56 | #source_encoding = 'utf-8-sig' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # General information about the project. 62 | project = u'qndiag' 63 | copyright = u'2020-2021, Pierre Ablin' 64 | author = u'Pierre Ablin' 65 | 66 | # The version info for the project you're documenting, acts as replacement for 67 | # |version| and |release|, also used in various other places throughout the 68 | # built documents. 69 | # 70 | # The short X.Y version. 71 | from qndiag import __version__ as version # noqa 72 | print(version) 73 | # version = u'0.1.dev0' 74 | # The full version, including alpha/beta/rc tags. 75 | release = version 76 | # release = u'0.1.dev0' 77 | 78 | # The language for content autogenerated by Sphinx. Refer to documentation 79 | # for a list of supported languages. 80 | # 81 | # This is also used if you do content translation via gettext catalogs. 82 | # Usually you set "language" from the command line for these cases. 
83 | language = None 84 | 85 | # There are two options for replacing |today|: either, you set today to some 86 | # non-false value, then it is used: 87 | #today = '' 88 | # Else, today_fmt is used as the format for a strftime call. 89 | #today_fmt = '%B %d, %Y' 90 | 91 | # List of patterns, relative to source directory, that match files and 92 | # directories to ignore when looking for source files. 93 | exclude_patterns = ['_build'] 94 | 95 | # The reST default role (used for this markup: `text`) to use for all 96 | # documents. 97 | #default_role = None 98 | 99 | # If true, '()' will be appended to :func: etc. cross-reference text. 100 | #add_function_parentheses = True 101 | 102 | # If true, the current module name will be prepended to all description 103 | # unit titles (such as .. function::). 104 | #add_module_names = True 105 | 106 | # If true, sectionauthor and moduleauthor directives will be shown in the 107 | # output. They are ignored by default. 108 | #show_authors = False 109 | 110 | # The name of the Pygments (syntax highlighting) style to use. 111 | pygments_style = 'sphinx' 112 | 113 | # A list of ignored prefixes for module index sorting. 114 | #modindex_common_prefix = [] 115 | 116 | # If true, keep warnings as "system message" paragraphs in the built documents. 117 | #keep_warnings = False 118 | 119 | # If true, `todo` and `todoList` produce output, else they produce nothing. 120 | todo_include_todos = False 121 | 122 | 123 | # -- Options for HTML output ---------------------------------------------- 124 | 125 | # The theme to use for HTML and HTML Help pages. See the documentation for 126 | # a list of builtin themes. 127 | html_theme = 'bootstrap' 128 | 129 | # Theme options are theme-specific and customize the look and feel of a theme 130 | # further. For a list of options available for each theme, see the 131 | # documentation. 132 | html_theme_options = { 133 | 'navbar_sidebarrel': False, 134 | 'navbar_links': [ 135 | ("Examples", "auto_examples/index"), 136 | ("API", "api"), 137 | ("What's new", "whats_new"), 138 | ("GitHub", "https://github.com/pierreablin/qndiag", True) 139 | ], 140 | 'bootswatch_theme': "united" 141 | } 142 | 143 | # Add any paths that contain custom themes here, relative to this directory. 144 | html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() 145 | 146 | # The name for this set of Sphinx documents. If None, it defaults to 147 | # " v documentation". 148 | #html_title = None 149 | 150 | # A shorter title for the navigation bar. Default is the same as html_title. 151 | #html_short_title = None 152 | 153 | # The name of an image file (relative to this directory) to place at the top 154 | # of the sidebar. 155 | #html_logo = None 156 | 157 | # The name of an image file (within the static path) to use as favicon of the 158 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 159 | # pixels large. 160 | html_favicon = 'logo/qndiag_logo.ico' 161 | 162 | # Add any paths that contain custom static files (such as style sheets) here, 163 | # relative to this directory. They are copied after the builtin static files, 164 | # so a file named "default.css" will overwrite the builtin "default.css". 165 | # html_static_path = ['_static'] 166 | 167 | # Add any extra paths that contain custom files (such as robots.txt or 168 | # .htaccess) here, relative to this directory. These files are copied 169 | # directly to the root of the documentation. 
170 | #html_extra_path = []
171 | 
172 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
173 | # using the given strftime format.
174 | #html_last_updated_fmt = '%b %d, %Y'
175 | 
176 | # If true, SmartyPants will be used to convert quotes and dashes to
177 | # typographically correct entities.
178 | #html_use_smartypants = True
179 | 
180 | # Custom sidebar templates, maps document names to template names.
181 | #html_sidebars = {}
182 | 
183 | # Additional templates that should be rendered to pages, maps page names to
184 | # template names.
185 | #html_additional_pages = {}
186 | 
187 | # If false, no module index is generated.
188 | #html_domain_indices = True
189 | 
190 | # If false, no index is generated.
191 | #html_use_index = True
192 | 
193 | # If true, the index is split into individual pages for each letter.
194 | #html_split_index = False
195 | 
196 | # If true, links to the reST sources are added to the pages.
197 | html_show_sourcelink = False
198 | 
199 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
200 | #html_show_sphinx = True
201 | 
202 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
203 | #html_show_copyright = True
204 | 
205 | # If true, an OpenSearch description file will be output, and all pages will
206 | # contain a <link> tag referring to it. The value of this option must be the
207 | # base URL from which the finished HTML is served.
208 | #html_use_opensearch = ''
209 | 
210 | # This is the file name suffix for HTML files (e.g. ".xhtml").
211 | #html_file_suffix = None
212 | 
213 | # Language to be used for generating the HTML full-text search index.
214 | # Sphinx supports the following languages:
215 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
216 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
217 | #html_search_language = 'en'
218 | 
219 | # A dictionary with options for the search language support, empty by default.
220 | # Now only 'ja' uses this config value
221 | #html_search_options = {'type': 'default'}
222 | 
223 | # The name of a javascript file (relative to the configuration directory) that
224 | # implements a search results scorer. If empty, the default will be used.
225 | #html_search_scorer = 'scorer.js'
226 | 
227 | # Output file base name for HTML help builder.
228 | htmlhelp_basename = 'qndiagdoc'
229 | 
230 | # -- Options for LaTeX output ---------------------------------------------
231 | 
232 | latex_elements = {
233 |     # The paper size ('letterpaper' or 'a4paper').
234 |     #'papersize': 'letterpaper',
235 | 
236 |     # The font size ('10pt', '11pt' or '12pt').
237 |     #'pointsize': '10pt',
238 | 
239 |     # Additional stuff for the LaTeX preamble.
240 |     #'preamble': '',
241 | 
242 |     # Latex figure (float) alignment
243 |     #'figure_align': 'htbp',
244 | }
245 | 
246 | # Grouping the document tree into LaTeX files. List of tuples
247 | # (source start file, target name, title,
248 | #  author, documentclass [howto, manual, or own class]).
249 | latex_documents = [
250 |     (master_doc, 'qndiag.tex', u'qndiag Documentation',
251 |      u'Pierre Ablin', 'manual'),
252 | ]
253 | 
254 | # The name of an image file (relative to this directory) to place at the top of
255 | # the title page.
256 | #latex_logo = None
257 | 
258 | # For "manual" documents, if this is true, then toplevel headings are parts,
259 | # not chapters.
260 | #latex_use_parts = False
261 | 
262 | # If true, show page references after internal links.
263 | #latex_show_pagerefs = False
264 | 
265 | # If true, show URL addresses after external links.
266 | #latex_show_urls = False
267 | 
268 | # Documents to append as an appendix to all manuals.
269 | #latex_appendices = []
270 | 
271 | # If false, no module index is generated.
272 | #latex_domain_indices = True
273 | 
274 | 
275 | # -- Options for manual page output ---------------------------------------
276 | 
277 | # One entry per manual page. List of tuples
278 | # (source start file, name, description, authors, manual section).
279 | man_pages = [
280 |     (master_doc, 'qndiag', u'QNDIAG Documentation',
281 |      [author], 1)
282 | ]
283 | 
284 | # If true, show URL addresses after external links.
285 | #man_show_urls = False
286 | 
287 | 
288 | # -- Options for Texinfo output -------------------------------------------
289 | 
290 | # Grouping the document tree into Texinfo files. List of tuples
291 | # (source start file, target name, title, author,
292 | #  dir menu entry, description, category)
293 | texinfo_documents = [
294 |     (master_doc, 'qndiag', u'qndiag Documentation',
295 |      author, 'qndiag', 'One line description of project.',
296 |      'Miscellaneous'),
297 | ]
298 | 
299 | # Documents to append as an appendix to all manuals.
300 | #texinfo_appendices = []
301 | 
302 | # If false, no module index is generated.
303 | #texinfo_domain_indices = True
304 | 
305 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
306 | #texinfo_show_urls = 'footnote'
307 | 
308 | # If true, do not generate a @detailmenu in the "Top" node's menu.
309 | #texinfo_no_detailmenu = False
310 | 
311 | 
312 | # Example configuration for intersphinx: refer to the Python standard library.
313 | intersphinx_mapping = {
314 |     'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
315 |     'numpy': ('https://docs.scipy.org/doc/numpy/', None),
316 |     # 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
317 |     'matplotlib': ('https://matplotlib.org/', None),
318 |     # 'sklearn': ('http://scikit-learn.org/stable', None),
319 | }
320 | 
321 | sphinx_gallery_conf = {
322 |     'backreferences_dir': 'gen_modules/backreferences',
323 |     'doc_module': ('qndiag', 'numpy'),
324 |     'examples_dirs': '../examples',
325 |     'gallery_dirs': 'auto_examples',
326 |     'reference_url': {
327 |         'qndiag': None,
328 |     }
329 | }
330 | 
--------------------------------------------------------------------------------