├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── .zenodo.json ├── CHANGELOG.rst ├── CONTRIBUTING.rst ├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.rst ├── doc ├── changelog.rst ├── conf.py ├── contributing.rst ├── index.rst ├── reference │ ├── features.rst │ ├── filters.rst │ ├── graphs.rst │ ├── index.rst │ ├── learning.rst │ ├── optimization.rst │ ├── plotting.rst │ ├── reduction.rst │ └── utils.rst ├── references.bib ├── references.rst └── tutorials │ ├── index.rst │ ├── intro.rst │ ├── optimization.rst │ ├── pyramid.rst │ └── wavelet.rst ├── examples ├── README.txt ├── eigenvalue_concentration.py ├── eigenvector_localization.py ├── filtering.py ├── fourier_basis.py ├── fourier_transform.py ├── heat_diffusion.py ├── kernel_localization.py ├── playground.ipynb ├── random_walk.py └── wave_propagation.py ├── postBuild ├── pygsp ├── __init__.py ├── data │ ├── pointclouds │ │ ├── airfoil.mat │ │ ├── bunny.mat │ │ ├── david500.mat │ │ ├── david64.mat │ │ ├── logogsp.mat │ │ ├── minnesota.mat │ │ └── two_moons.mat │ ├── readme_example_filter.png │ └── readme_example_graph.png ├── features.py ├── filters │ ├── __init__.py │ ├── abspline.py │ ├── approximations.py │ ├── expwin.py │ ├── filter.py │ ├── gabor.py │ ├── halfcosine.py │ ├── heat.py │ ├── held.py │ ├── itersine.py │ ├── mexicanhat.py │ ├── meyer.py │ ├── modulation.py │ ├── papadakis.py │ ├── rectangular.py │ ├── regular.py │ ├── simoncelli.py │ ├── simpletight.py │ └── wave.py ├── graphs │ ├── __init__.py │ ├── _io.py │ ├── _layout.py │ ├── airfoil.py │ ├── barabasialbert.py │ ├── comet.py │ ├── community.py │ ├── davidsensornet.py │ ├── difference.py │ ├── erdosrenyi.py │ ├── fourier.py │ ├── fullconnected.py │ ├── graph.py │ ├── grid2d.py │ ├── linegraph.py │ ├── logo.py │ ├── lowstretchtree.py │ ├── minnesota.py │ ├── nngraphs │ │ ├── __init__.py │ │ ├── bunny.py │ │ ├── cube.py │ │ ├── grid2dimgpatches.py │ │ ├── imgpatches.py │ │ ├── nngraph.py │ │ ├── sensor.py │ │ ├── sphere.py │ │ └── 
twomoons.py │ ├── path.py │ ├── randomregular.py │ ├── randomring.py │ ├── ring.py │ ├── star.py │ ├── stochasticblockmodel.py │ ├── swissroll.py │ └── torus.py ├── learning.py ├── optimization.py ├── plotting.py ├── reduction.py ├── tests │ ├── __init__.py │ ├── test_docstrings.py │ ├── test_filters.py │ ├── test_graphs.py │ ├── test_learning.py │ ├── test_plotting.py │ └── test_utils.py └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | # Packages 4 | *.egg 5 | *.egg-info 6 | dist 7 | build 8 | eggs 9 | parts 10 | bin 11 | var 12 | sdist 13 | develop-eggs 14 | .installed.cfg 15 | 16 | # Installer logs 17 | pip-log.txt 18 | 19 | # Coverage reports 20 | .coverage 21 | htmlcov 22 | 23 | # Complexity 24 | output/*.html 25 | output/*/index.html 26 | 27 | # Sphinx documentation 28 | /doc/_build/ 29 | /doc/examples/ 30 | /doc/backrefs/ 31 | 32 | # Vim swap files 33 | .*.swp 34 | 35 | # Mac OS garbage 36 | .DS_Store 37 | 38 | # Jupyter notebook 39 | .ipynb_checkpoints/ 40 | 41 | # Visual Studio Code 42 | .vscode/ 43 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: 4 | - htmlzip 5 | 6 | sphinx: 7 | builder: html 8 | configuration: doc/conf.py 9 | 10 | python: 11 | install: 12 | - method: pip 13 | path: . 
14 | extra_requirements: 15 | - dev 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | language: python 3 | 4 | python: 5 | - 3.7 6 | - 3.8 7 | - 3.9 8 | 9 | addons: 10 | apt: 11 | sources: 12 | - sourceline: 'deb https://downloads.skewed.de/apt bionic main' 13 | key_url: 'https://keys.openpgp.org/vks/v1/by-keyid/612DEFB798507F25' 14 | packages: 15 | - python3-graph-tool 16 | - libqt5gui5 # pyqt5>5.11 fails to load the xcb platform plugin without it 17 | 18 | install: 19 | - pip install --upgrade pip setuptools wheel # install with latest tools 20 | - pip install --upgrade --upgrade-strategy eager .[dev] # get latest deps 21 | 22 | before_script: 23 | # As graph-tool cannot be installed by pip, link to the system installation 24 | # from the virtual environment. 25 | - ln -s "/usr/lib/python3/dist-packages/graph_tool" $(python -c "import site; print(site.getsitepackages()[0])") 26 | 27 | script: 28 | # - make lint 29 | - make test 30 | - make doc 31 | 32 | after_success: 33 | - coveralls 34 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "PyGSP: Graph Signal Processing in Python", 3 | "description": "The PyGSP facilitates a wide variety of operations on graphs, like computing their Fourier basis, filtering or interpolating signals, plotting graphs, signals, and filters.", 4 | "upload_type": "software", 5 | "license": "BSD-3-Clause", 6 | "access_right": "open", 7 | "creators": [ 8 | { 9 | "name": "Micha\u00ebl Defferrard", 10 | "affiliation": "EPFL", 11 | "orcid": "0000-0002-6028-9024" 12 | }, 13 | { 14 | "name": "Lionel Martin", 15 | "affiliation": "EPFL" 16 | }, 17 | { 18 | "name": "Rodrigo Pena", 19 | "affiliation": "EPFL" 20 | }, 21 | { 22 | "name": "Nathana\u00ebl 
Perraudin", 23 | "affiliation": "EPFL", 24 | "orcid": "0000-0001-8285-1308" 25 | } 26 | ], 27 | "related_identifiers": [ 28 | { 29 | "scheme": "url", 30 | "identifier": "https://github.com/epfl-lts2/pygsp", 31 | "relation": "isSupplementTo" 32 | }, 33 | { 34 | "scheme": "doi", 35 | "identifier": "10.5281/zenodo.1003157", 36 | "relation": "isPartOf" 37 | } 38 | ] 39 | } 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | Contributions are welcome, and they are greatly appreciated! The development of 6 | this package takes place on `GitHub `_. 7 | Issues, bugs, and feature requests should be reported `there 8 | `_. 9 | Code and documentation can be improved by submitting a `pull request 10 | `_. Please add documentation and 11 | tests for any new code. 12 | 13 | The package can be set up (ideally in a fresh virtual environment) for local 14 | development with the following:: 15 | 16 | $ git clone https://github.com/epfl-lts2/pygsp.git 17 | $ pip install --upgrade --editable pygsp[dev] 18 | 19 | The ``dev`` "extras requirement" ensures that dependencies required for 20 | development (to run the test suite and build the documentation) are installed. 21 | Only `graph-tool `_ will be missing: install it 22 | manually as it cannot be installed by pip. 23 | 24 | You can improve or add functionality in the ``pygsp`` folder, along with 25 | corresponding unit tests in ``pygsp/tests/test_*.py`` (with reasonable 26 | coverage). 27 | If you have a nice example to demonstrate the use of the introduced 28 | functionality, please consider adding a tutorial in ``doc/tutorials`` or a 29 | short example in ``examples``. 30 | 31 | Update ``README.rst`` and ``CHANGELOG.rst`` if applicable. 
32 | 33 | After making any change, please check the style, run the tests, and build the 34 | documentation with the following (enforced by Travis CI):: 35 | 36 | $ make lint 37 | $ make test 38 | $ make doc 39 | 40 | Check the generated coverage report at ``htmlcov/index.html`` to make sure the 41 | tests reasonably cover the changes you've introduced. 42 | 43 | To iterate faster, you can partially run the test suite, at various degrees of 44 | granularity, as follows:: 45 | 46 | $ python -m unittest pygsp.tests.test_docstrings.suite_reference 47 | $ python -m unittest pygsp.tests.test_graphs.TestImportExport 48 | $ python -m unittest pygsp.tests.test_graphs.TestImportExport.test_save_load 49 | 50 | Making a release 51 | ---------------- 52 | 53 | #. Update the version number and release date in ``setup.py``, 54 | ``pygsp/__init__.py`` and ``CHANGELOG.rst``. 55 | #. Create a git tag with ``git tag -a v0.5.0 -m "PyGSP v0.5.0"``. 56 | #. Push the tag to GitHub with ``git push github v0.5.0``. The tag should now 57 | appear in the releases and tags tab. 58 | #. `Create a release `_ on 59 | GitHub and select the created tag. A DOI should then be issued by Zenodo. 60 | #. Go on Zenodo and fix the metadata if necessary. 61 | #. Build the distribution with ``make dist`` and check that the 62 | ``dist/PyGSP-0.5.0.tar.gz`` source archive contains all required files. The 63 | binary wheel should be found as ``dist/PyGSP-0.5.0-py2.py3-none-any.whl``. 64 | #. Test the upload and installation process:: 65 | 66 | $ twine upload --repository-url https://test.pypi.org/legacy/ dist/* 67 | $ pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple pygsp 68 | 69 | Log in as the LTS2 user. 70 | #. Build and upload the distribution to the real PyPI with ``make release``. 71 | #. Update the conda feedstock (at least the version number and sha256 in 72 | ``recipe/meta.yaml``) by sending a PR to 73 | `conda-forge `_. 
74 | 75 | Repository organization 76 | ----------------------- 77 | 78 | :: 79 | 80 | LICENSE.txt Project license 81 | *.rst Important documentation 82 | Makefile Targets for make 83 | setup.py Meta information about package (published on PyPI) 84 | .gitignore Files ignored by the git revision control system 85 | .travis.yml Defines testing on Travis continuous integration 86 | 87 | pygsp/ Contains the modules (the actual toolbox implementation) 88 | __init__.py Load modules at package import 89 | *.py One file per module 90 | 91 | pygsp/tests/ Contains the test suites (will be distributed to end user) 92 | __init__.py Load modules at package import 93 | test_*.py One test suite per module 94 | test_docstrings.py Test the examples in the docstrings (reference doc) 95 | test_tutorials.py Test the tutorials in doc/tutorials 96 | test_all.py Launch all the tests (docstrings, tutorials, modules) 97 | 98 | doc/ Package documentation 99 | conf.py Sphinx configuration 100 | index.rst Documentation entry page 101 | *.rst Include doc files from root directory 102 | 103 | doc/reference/ Reference documentation 104 | index.rst Reference entry page 105 | *.rst Only directives, the actual doc is alongside the code 106 | 107 | doc/tutorials/ 108 | index.rst Tutorials entry page 109 | *.rst One file per tutorial 110 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, EPFL LTS2 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer.
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of PyGSP nor the names of its contributors may be used 15 | to endorse or promote products derived from this software without specific 16 | prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.rst 3 | include CHANGELOG.rst 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NB = $(sort $(wildcard examples/*.ipynb)) 2 | .PHONY: help clean lint test doc dist release 3 | 4 | help: 5 | @echo "clean remove non-source files and clean source files" 6 | @echo "lint check style" 7 | @echo "test run tests and check coverage" 8 | @echo "doc generate HTML documentation and check links" 9 | @echo "dist package (source & wheel)" 10 | @echo "release package and upload to PyPI" 11 | 12 | clean: 13 | git clean -Xdf 14 | jupyter nbconvert --inplace --ClearOutputPreprocessor.enabled=True $(NB) 15 | 16 | lint: 17 | flake8 --doctests --exclude=doc 18 | 19 | # Matplotlib doesn't print to screen. Also faster. 20 | export MPLBACKEND = agg 21 | # Virtual framebuffer nonetheless needed for the pyqtgraph backend. 
22 | export DISPLAY = :99 23 | 24 | test: 25 | Xvfb $$DISPLAY -screen 0 800x600x24 & 26 | coverage run --branch --source pygsp setup.py test 27 | coverage report 28 | coverage html 29 | killall Xvfb 30 | 31 | doc: 32 | sphinx-build -b html -d doc/_build/doctrees doc doc/_build/html 33 | sphinx-build -b linkcheck -d doc/_build/doctrees doc doc/_build/linkcheck 34 | 35 | dist: clean 36 | python setup.py sdist 37 | python setup.py bdist_wheel --universal 38 | ls -lh dist/* 39 | twine check dist/* 40 | 41 | release: dist 42 | twine upload dist/* 43 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======================================== 2 | PyGSP: Graph Signal Processing in Python 3 | ======================================== 4 | 5 | The PyGSP is a Python package to ease 6 | `Signal Processing on Graphs `_. 7 | The documentation is available on 8 | `Read the Docs `_ 9 | and development takes place on 10 | `GitHub `_. 11 | A (mostly unmaintained) `Matlab version `_ exists. 12 | 13 | +-----------------------------------+ 14 | | |doc| |pypi| |conda| |binder| | 15 | +-----------------------------------+ 16 | | |zenodo| |license| |pyversions| | 17 | +-----------------------------------+ 18 | | |travis| |coveralls| |github| | 19 | +-----------------------------------+ 20 | 21 | .. |doc| image:: https://readthedocs.org/projects/pygsp/badge/?version=latest 22 | :target: https://pygsp.readthedocs.io 23 | .. |pypi| image:: https://img.shields.io/pypi/v/pygsp.svg 24 | :target: https://pypi.org/project/PyGSP 25 | .. |zenodo| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.1003157.svg 26 | :target: https://doi.org/10.5281/zenodo.1003157 27 | .. |license| image:: https://img.shields.io/pypi/l/pygsp.svg 28 | :target: https://github.com/epfl-lts2/pygsp/blob/master/LICENSE.txt 29 | .. 
|pyversions| image:: https://img.shields.io/pypi/pyversions/pygsp.svg 30 | :target: https://pypi.org/project/PyGSP 31 | .. |travis| image:: https://img.shields.io/travis/com/epfl-lts2/pygsp.svg 32 | :target: https://app.travis-ci.com/github/epfl-lts2/pygsp 33 | .. |coveralls| image:: https://img.shields.io/coveralls/github/epfl-lts2/pygsp.svg 34 | :target: https://coveralls.io/github/epfl-lts2/pygsp 35 | .. |github| image:: https://img.shields.io/github/stars/epfl-lts2/pygsp.svg?style=social 36 | :target: https://github.com/epfl-lts2/pygsp 37 | .. |binder| image:: https://static.mybinder.org/badge_logo.svg 38 | :target: https://mybinder.org/v2/gh/epfl-lts2/pygsp/master?urlpath=lab/tree/examples/playground.ipynb 39 | .. |conda| image:: https://img.shields.io/conda/vn/conda-forge/pygsp.svg 40 | :target: https://anaconda.org/conda-forge/pygsp 41 | 42 | The PyGSP facilitates a wide variety of operations on graphs, like computing 43 | their Fourier basis, filtering or interpolating signals, plotting graphs, 44 | signals, and filters. Its core is spectral graph theory, and many of the 45 | provided operations scale to very large graphs. The package includes a wide 46 | range of graphs, from point clouds like the Stanford bunny and the Swiss roll; 47 | to networks like the Minnesota road network; to models for generating random 48 | graphs like stochastic block models, sensor networks, Erdős–Rényi model, 49 | Barabási-Albert model; to simple graphs like the path, the ring, and the grid. 50 | Many filter banks are also provided, e.g. various wavelets like the Mexican 51 | hat, Meyer, Half Cosine; some low-pass filters like the heat kernel and the 52 | exponential window; and Gabor filters. Despite all the pre-defined models, you 53 | can easily use a custom graph by defining its adjacency matrix, and a custom 54 | filter bank by defining a set of functions in the spectral domain. 
55 | 56 | While NetworkX_ and graph-tool_ are tools to analyze the topology of graphs, 57 | the aim of the PyGSP is to analyze graph signals, also known as features or 58 | properties (i.e., not the graph itself). 59 | Those three tools are complementary and work well together with the provided 60 | import / export facility. 61 | 62 | .. _NetworkX: https://networkx.org 63 | .. _graph-tool: https://graph-tool.skewed.de 64 | 65 | The following demonstrates how to instantiate a graph and a filter, the two 66 | main objects of the package. 67 | 68 | >>> from pygsp import graphs, filters 69 | >>> G = graphs.Logo() 70 | >>> G.compute_fourier_basis() # Fourier to plot the eigenvalues. 71 | >>> # G.estimate_lmax() is otherwise sufficient. 72 | >>> g = filters.Heat(G, scale=50) 73 | >>> fig, ax = g.plot() 74 | 75 | .. image:: ../pygsp/data/readme_example_filter.png 76 | :alt: 77 | .. image:: pygsp/data/readme_example_filter.png 78 | :alt: 79 | 80 | Let's now create a graph signal: a set of three Kronecker deltas for that 81 | example. We can now look at one step of heat diffusion by filtering the deltas 82 | with the above defined filter. Note how the diffusion follows the local 83 | structure! 84 | 85 | >>> import numpy as np 86 | >>> DELTAS = [20, 30, 1090] 87 | >>> s = np.zeros(G.N) 88 | >>> s[DELTAS] = 1 89 | >>> s = g.filter(s) 90 | >>> fig, ax = G.plot(s, highlight=DELTAS) 91 | 92 | .. image:: ../pygsp/data/readme_example_graph.png 93 | :alt: 94 | .. image:: pygsp/data/readme_example_graph.png 95 | :alt: 96 | 97 | You can 98 | `try it online `_, 99 | look at the 100 | `tutorials `_ 101 | to learn how to use it, or look at the 102 | `reference guide `_ 103 | for an exhaustive documentation of the API. Enjoy! 
104 | 105 | Installation 106 | ------------ 107 | 108 | The PyGSP is available on PyPI:: 109 | 110 | $ pip install pygsp 111 | 112 | The PyGSP is available on `conda-forge `_:: 113 | 114 | $ conda install -c conda-forge pygsp 115 | 116 | The PyGSP is available in the `Arch User Repository `_:: 117 | 118 | $ git clone https://aur.archlinux.org/python-pygsp.git 119 | $ cd python-pygsp 120 | $ makepkg -csi 121 | 122 | Contributing 123 | ------------ 124 | 125 | See the guidelines for contributing in ``CONTRIBUTING.rst``. 126 | 127 | Acknowledgments 128 | --------------- 129 | 130 | The PyGSP was started in 2014 as an academic open-source project for 131 | research purpose at the `EPFL LTS2 laboratory `_. 132 | This project has been partly funded by the Swiss National Science Foundation 133 | under grant 200021_154350 "Towards Signal Processing on Graphs". 134 | 135 | It is released under the terms of the BSD 3-Clause license. 136 | 137 | If you are using the library for your research, for the sake of 138 | reproducibility, please cite the version you used as indexed by 139 | `Zenodo `_. 140 | Or cite the generic concept as:: 141 | 142 | @misc{pygsp, 143 | title = {PyGSP: Graph Signal Processing in Python}, 144 | author = {Defferrard, Micha\"el and Martin, Lionel and Pena, Rodrigo and Perraudin, Nathana\"el}, 145 | doi = {10.5281/zenodo.1003157}, 146 | url = {https://github.com/epfl-lts2/pygsp/}, 147 | } 148 | -------------------------------------------------------------------------------- /doc/changelog.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../CHANGELOG.rst 2 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pygsp 4 | 5 | extensions = [ 6 | 'sphinx.ext.viewcode', 7 | 'sphinx.ext.autosummary', 8 | 'sphinx.ext.mathjax', 9 | 'sphinx.ext.inheritance_diagram', 10 | ] 11 | 12 | extensions.append('sphinx.ext.autodoc') 13 | autodoc_default_options = { 14 | 'members': True, 15 | 'undoc-members': True, 16 | 'member-order': 'groupwise', # alphabetical, groupwise, bysource 17 | } 18 | 19 | extensions.append('sphinx.ext.intersphinx') 20 | intersphinx_mapping = { 21 | 'python': ('https://docs.python.org/3', None), 22 | 'numpy': ('https://numpy.org/doc/stable', None), 23 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), 24 | 'matplotlib': ('https://matplotlib.org/stable', None), 25 | 'pyunlocbox': ('https://pyunlocbox.readthedocs.io/en/stable', None), 26 | 'networkx': ('https://networkx.org/documentation/stable', None), 27 | 'graph_tool': ('https://graph-tool.skewed.de/static/doc', None), 28 | } 29 | 30 | extensions.append('numpydoc') 31 | numpydoc_show_class_members = False 32 | numpydoc_use_plots = True # Add the plot directive whenever mpl is imported. 33 | 34 | extensions.append('matplotlib.sphinxext.plot_directive') 35 | plot_include_source = True 36 | plot_html_show_source_link = False 37 | plot_html_show_formats = False 38 | plot_working_directory = '.' 
39 | plot_rcparams = { 40 | 'figure.figsize': (10, 4) 41 | } 42 | plot_pre_code = """ 43 | import numpy as np 44 | from pygsp import graphs, filters, utils, plotting 45 | """ 46 | 47 | extensions.append('sphinx_gallery.gen_gallery') 48 | sphinx_gallery_conf = { 49 | 'examples_dirs': '../examples', 50 | 'gallery_dirs': 'examples', 51 | 'filename_pattern': '/', 52 | 'reference_url': {'pygsp': None}, 53 | 'backreferences_dir': 'backrefs', 54 | 'doc_module': 'pygsp', 55 | 'show_memory': True, 56 | } 57 | 58 | extensions.append('sphinx_copybutton') 59 | copybutton_prompt_text = ">>> " 60 | 61 | extensions.append('sphinxcontrib.bibtex') 62 | bibtex_bibfiles = ['references.bib'] 63 | 64 | exclude_patterns = ['_build'] 65 | source_suffix = '.rst' 66 | master_doc = 'index' 67 | 68 | project = 'PyGSP' 69 | version = pygsp.__version__ 70 | release = pygsp.__version__ 71 | copyright = 'EPFL LTS2' 72 | 73 | pygments_style = 'sphinx' 74 | html_theme = 'sphinx_rtd_theme' 75 | html_theme_options = { 76 | 'navigation_depth': 2, 77 | } 78 | latex_elements = { 79 | 'papersize': 'a4paper', 80 | 'pointsize': '10pt', 81 | } 82 | latex_documents = [ 83 | ('index', 'pygsp.tex', 'PyGSP documentation', 84 | 'EPFL LTS2', 'manual'), 85 | ] 86 | -------------------------------------------------------------------------------- /doc/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. 
toctree:: 4 | :hidden: 5 | 6 | Home 7 | tutorials/index 8 | examples/index 9 | reference/index 10 | changelog 11 | contributing 12 | references 13 | -------------------------------------------------------------------------------- /doc/reference/features.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Features 3 | ======== 4 | 5 | .. automodule:: pygsp.features 6 | -------------------------------------------------------------------------------- /doc/reference/filters.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Filters 3 | ======= 4 | 5 | .. automodule:: pygsp.filters 6 | -------------------------------------------------------------------------------- /doc/reference/graphs.rst: -------------------------------------------------------------------------------- 1 | ====== 2 | Graphs 3 | ====== 4 | 5 | .. automodule:: pygsp.graphs 6 | :exclude-members: Graph 7 | 8 | .. autoclass:: Graph 9 | :inherited-members: 10 | -------------------------------------------------------------------------------- /doc/reference/index.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | API reference 3 | ============= 4 | 5 | .. automodule:: pygsp 6 | 7 | .. toctree:: 8 | :hidden: 9 | 10 | graphs 11 | filters 12 | plotting 13 | learning 14 | reduction 15 | features 16 | optimization 17 | utils 18 | -------------------------------------------------------------------------------- /doc/reference/learning.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Learning 3 | ======== 4 | 5 | .. automodule:: pygsp.learning 6 | -------------------------------------------------------------------------------- /doc/reference/optimization.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Optimization 3 | ============ 4 | 5 | .. 
automodule:: pygsp.optimization 6 | -------------------------------------------------------------------------------- /doc/reference/plotting.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Plotting 3 | ======== 4 | 5 | .. automodule:: pygsp.plotting 6 | -------------------------------------------------------------------------------- /doc/reference/reduction.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Reduction 3 | ========= 4 | 5 | .. automodule:: pygsp.reduction 6 | -------------------------------------------------------------------------------- /doc/reference/utils.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Utilities 3 | ========= 4 | 5 | .. automodule:: pygsp.utils 6 | -------------------------------------------------------------------------------- /doc/references.bib: -------------------------------------------------------------------------------- 1 | @UNPUBLISHED{gleich, 2 | author={D. Gleich}, 3 | title={The {MatlabBGL} {Matlab} Library}, 4 | note={http://www.cs.purdue.edu/homes/dgleich/packages/matlab{\_}bgl/index.html}, 5 | } 6 | 7 | @BOOK{chung1997spectral, 8 | title = {Spectral Graph Theory}, 9 | publisher = {Vol. 92 of the {CBMS} Regional Conference Series in Mathematics, {American Mathematical Society}}, 10 | year = {1997}, 11 | author = {F. R. K. Chung}, 12 | } 13 | 14 | @BOOK{parlett1998symmetric, 15 | title = {The Symmetric Eigenvalue Problem}, 16 | publisher = {SIAM}, 17 | year = {1998}, 18 | author = {B. N. Parlett}, 19 | } 20 | 21 | @UNPUBLISHED{davis2011ldl, 22 | title={User Guide for {LDL}, a concise sparse {Cholesky} package}, 23 | author={Timothy A. 
Davis}, 24 | month={Jan.}, 25 | year={2011}, 26 | note={http://www.cise.ufl.edu/research/sparse/ldl/} 27 | } 28 | 29 | @ARTICLE{davis2005ldl, 30 | title={{Algorithm 849: A concise sparse Cholesky factorization package}}, 31 | author={T. A. Davis}, 32 | journal={ACM Trans. Mathem. Software}, 33 | volume={31}, 34 | number={4}, 35 | month={Dec.}, 36 | year={2005}, 37 | pages={587-591}, 38 | } 39 | 40 | @Article{hammond2011wavelets, 41 | author={D. K. Hammond and P. Vandergheynst and R. Gribonval}, 42 | journal={Appl. Comput. Harmon. Anal.}, 43 | title={Wavelets on graphs via spectral graph theory}, 44 | month={Mar.}, 45 | year=2011, 46 | volume={30}, 47 | number={2}, 48 | pages={129-150}, 49 | } 50 | 51 | @article{shuman2016vertexfrequency, 52 | title={Vertex-frequency analysis on graphs}, 53 | author={Shuman, David I and Ricaud, Benjamin and Vandergheynst, Pierre}, 54 | journal={Applied and Computational Harmonic Analysis}, 55 | year={2016}, 56 | } 57 | 58 | @ARTICLE{leonardi2013tight, 59 | author = {N. Leonardi and D. {Van De Ville}}, 60 | title = {Tight Wavelet Frames on Multislice Graphs}, 61 | journal = {IEEE Trans. 
Signal Process.}, 62 | month={Jul.}, 63 | volume={61}, 64 | number={13}, 65 | year = {2013}, 66 | pages={3357-3367}, 67 | } 68 | 69 | @article{klein1993resistance, 70 | title={Resistance distance}, 71 | author={Klein, Douglas J and Randi{\'c}, M}, 72 | journal={Journal of Mathematical Chemistry}, 73 | volume={12}, 74 | number={1}, 75 | pages={81--95}, 76 | year={1993}, 77 | publisher={Springer} 78 | } 79 | 80 | @article{strang1999dct, 81 | title={The discrete cosine transform}, 82 | author={Strang, Gilbert}, 83 | journal={SIAM review}, 84 | volume={41}, 85 | number={1}, 86 | pages={135--147}, 87 | year={1999}, 88 | publisher={SIAM}, 89 | url={https://sci-hub.se/10.1137/S0036144598336745}, 90 | } 91 | 92 | @article{shuman2013spectrum, 93 | title={Spectrum-adapted tight graph wavelet and vertex-frequency frames}, 94 | author={Shuman, David I and Wiesmeyr, Christoph and Holighaus, Nicki and Vandergheynst, Pierre}, 95 | journal={arXiv preprint arXiv:1311.0897}, 96 | year={2013} 97 | } 98 | 99 | 100 | @article{spielman2011graph, 101 | title={Graph sparsification by effective resistances}, 102 | author={Spielman, Daniel A and Srivastava, Nikhil}, 103 | journal={SIAM Journal on Computing}, 104 | volume={40}, 105 | number={6}, 106 | pages={1913--1926}, 107 | year={2011}, 108 | publisher={SIAM} 109 | } 110 | 111 | @article{rudelson1999random, 112 | title={Random vectors in the isotropic position}, 113 | author={Rudelson, Mark}, 114 | journal={Journal of Functional Analysis}, 115 | volume={164}, 116 | number={1}, 117 | pages={60--72}, 118 | year={1999}, 119 | publisher={Elsevier} 120 | } 121 | 122 | @article{rudelson2007sampling, 123 | title={Sampling from large matrices: An approach through geometric functional analysis}, 124 | author={Rudelson, Mark and Vershynin, Roman}, 125 | journal={Journal of the ACM (JACM)}, 126 | volume={54}, 127 | number={4}, 128 | pages={21}, 129 | year={2007}, 130 | publisher={ACM} 131 | } 132 | 133 | 134 | @article{dorfler2013kron, 135 | 
title={Kron reduction of graphs with applications to electrical networks}, 136 | author={Dorfler, Florian and Bullo, Francesco}, 137 | journal={Circuits and Systems I: Regular Papers, IEEE Transactions on}, 138 | volume={60}, 139 | number={1}, 140 | pages={150--163}, 141 | year={2013}, 142 | publisher={IEEE} 143 | } 144 | 145 | @article{pesenson2009variational, 146 | title={Variational splines and Paley--Wiener spaces on combinatorial graphs}, 147 | author={Pesenson, Isaac}, 148 | journal={Constructive Approximation}, 149 | volume={29}, 150 | number={1}, 151 | pages={1--21}, 152 | year={2009}, 153 | publisher={Springer} 154 | } 155 | 156 | 157 | @article{shuman2013framework, 158 | title={A framework for multiscale transforms on graphs}, 159 | author={Shuman, David I and Faraji, Mohammad Javad and Vandergheynst, Pierre}, 160 | journal={arXiv preprint arXiv:1308.4942}, 161 | year={2013} 162 | } 163 | 164 | @inproceedings{turk1994zippered, 165 | title={Zippered polygon meshes from range images}, 166 | author={Turk, Greg and Levoy, Marc}, 167 | booktitle={Proceedings of the 21st annual conference on Computer graphics and interactive techniques}, 168 | pages={311--318}, 169 | year={1994}, 170 | organization={ACM} 171 | } 172 | 173 | @ARTICLE{perraudin2014gspbox, 174 | author = {{Perraudin}, Nathana{\"e}l and {Paratte}, Johan and {Shuman}, David and {Kalofolias}, Vassilis and 175 | {Vandergheynst}, Pierre and {Hammond}, David K. 
}, 176 | title = "{GSPBOX: A toolbox for signal processing on graphs}", 177 | journal = {ArXiv e-prints}, 178 | archivePrefix = "arXiv", 179 | eprint = {1408.5781}, 180 | primaryClass = "cs.IT", 181 | keywords = {Computer Science - Information Theory}, 182 | year = 2014, 183 | month = aug, 184 | adsurl = {http://arxiv.org/abs/1408.5781}, 185 | } 186 | 187 | @article{tremblay2016compressive, 188 | title={Compressive spectral clustering}, 189 | author={Tremblay, Nicolas and Puy, Gilles and Gribonval, R{\'e}mi and Vandergheynst, Pierre}, 190 | journal={arXiv preprint arXiv:1602.02018}, 191 | year={2016} 192 | } 193 | 194 | @inproceedings{kim2003randomregulargraphs, 195 | title={Generating random regular graphs}, 196 | author={Kim, Jeong Han and Vu, Van H}, 197 | booktitle={Proceedings of the thirty-fifth annual ACM symposium on Theory of computing}, 198 | year={2003} 199 | } 200 | 201 | @inproceedings{leonardi2011wavelet, 202 | title={Wavelet frames on graphs defined by fMRI functional connectivity}, 203 | author={Leonardi, Nora and Van De Ville, Dimitri}, 204 | booktitle={Biomedical Imaging: From Nano to Macro, 2011 IEEE International Symposium on}, 205 | pages={2136--2139}, 206 | year={2011}, 207 | organization={IEEE} 208 | } 209 | 210 | @inproceedings{grassi2016timevertex, 211 | title={Tracking time-vertex propagation using dynamic graph wavelets}, 212 | author={Grassi, Francesco and Perraudin, Nathanael and Ricaud, Benjamin}, 213 | year={2016}, 214 | booktitle={Signal and Information Processing (GlobalSIP), 2016 IEEE Global Conference on}, 215 | } 216 | 217 | @article{grassi2018timevertex, 218 | title={A time-vertex signal processing framework: Scalable processing and meaningful representations for time-series on graphs}, 219 | author={Grassi, Francesco and Loukas, Andreas and Perraudin, Nathanael and Ricaud, Benjamin}, 220 | year={2018}, 221 | journal={IEEE Transactions on Signal Processing}, 222 | } 223 | 
-------------------------------------------------------------------------------- /doc/references.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | References 3 | ========== 4 | 5 | .. bibliography:: references.bib 6 | :cited: 7 | :style: alpha 8 | -------------------------------------------------------------------------------- /doc/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Tutorials 3 | ========= 4 | 5 | The following are some tutorials which explain how to use the toolbox and show 6 | how to use it to solve some problems. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | intro 12 | wavelet 13 | optimization 14 | pyramid 15 | -------------------------------------------------------------------------------- /doc/tutorials/optimization.rst: -------------------------------------------------------------------------------- 1 | =========================================================== 2 | Optimization problems: graph TV vs. Tikhonov regularization 3 | =========================================================== 4 | 5 | Description 6 | ----------- 7 | 8 | Modern signal processing often involves solving an optimization problem. Graph signal processing (GSP) consists roughly of working with linear operators defined by a graph (e.g., the graph Laplacian). The setting up of an optimization problem in the graph context is often then simply a matter of identifying which graph-defined linear operator is relevant to be used in a regularization and/or fidelity term. 9 | 10 | This tutorial focuses on the problem of recovering a label signal on a graph from subsampled and noisy data, but most information should be fairly generally applied to other problems as well. 11 | 12 | .. 
plot:: 13 | :context: reset 14 | 15 | >>> import numpy as np 16 | >>> from pygsp import graphs, plotting 17 | >>> 18 | >>> # Create a random sensor graph 19 | >>> G = graphs.Sensor(N=256, distributed=True, seed=42) 20 | >>> G.compute_fourier_basis() 21 | >>> 22 | >>> # Create label signal 23 | >>> label_signal = np.copysign(np.ones(G.N), G.U[:, 3]) 24 | >>> 25 | >>> fig, ax = G.plot(label_signal) 26 | 27 | The first figure shows a plot of the original label signal, that we wish to recover, on the graph. 28 | 29 | .. plot:: 30 | :context: close-figs 31 | 32 | >>> rng = np.random.default_rng(42) 33 | >>> 34 | >>> # Create the mask 35 | >>> M = rng.uniform(size=G.N) 36 | >>> M = (M > 0.6).astype(float) # Probability of having no label on a vertex. 37 | >>> 38 | >>> # Applying the mask to the data 39 | >>> sigma = 0.1 40 | >>> subsampled_noisy_label_signal = M * (label_signal + sigma * rng.standard_normal(G.N)) 41 | >>> 42 | >>> fig, ax = G.plot(subsampled_noisy_label_signal) 43 | 44 | This figure shows the label signal on the graph after the application of the subsampling mask and the addition of noise. The label of more than half of the vertices has been set to :math:`0`. 45 | 46 | Since the problem is ill-posed, we will use some regularization to reach a solution that is more in tune with what we expect a label signal to look like. We will compare two approaches, but they are both based on measuring local differences on the label signal. Those differences are essentially an edge signal: to each edge we can associate the difference between the label signals of its associated nodes. The linear operator that does such a mapping is called the graph gradient :math:`\nabla_G`, and, fortunately for us, it is available under the :attr:`pygsp.graphs.Graph.D` (for differential) attribute of any :mod:`pygsp.graphs` graph. 47 | 48 | The reason for measuring local differences comes from prior knowledge: we assume label signals don't vary too much locally. 
The precise measure of such variation is what distinguishes the two regularization approaches we'll use. 49 | 50 | The first one, shown below, is called graph total variation (TV) regularization. The quadratic fidelity term is multiplied by a regularization constant :math:`\gamma` and its goal is to force the solution to stay close to the observed labels :math:`b`. The :math:`\ell_1` norm of the action of the graph gradient is what's called the graph TV. We will see that it promotes piecewise-constant solutions. 51 | 52 | .. math:: \operatorname*{arg\,min}_x \|\nabla_G x\|_1 + \gamma \|Mx-b\|_2^2 53 | 54 | The second approach, called graph Tikhonov regularization, is to use a smooth (differentiable) quadratic regularizer. A consequence of this choice is that the solution will tend to have smoother transitions. The quadratic fidelity term is still the same. 55 | 56 | .. math:: \operatorname*{arg\,min}_x \|\nabla_G x\|_2^2 + \gamma \|Mx-b\|_2^2 57 | 58 | Results and code 59 | ---------------- 60 | 61 | For solving the optimization problems we've assembled, you will need a numerical solver package. This part is implemented in this tutorial with the `pyunlocbox `_, which is based on proximal splitting algorithms. Check also its `documentation `_ for more information about the parameters used here. 62 | 63 | We start with the graph TV regularization. We will use the :class:`pyunlocbox.solvers.mlfbf` solver from :mod:`pyunlocbox`. It is a primal-dual solver, which means for our problem that the regularization term will be written in terms of the dual variable :math:`u = \nabla_G x`, and the graph gradient :math:`\nabla_G` will be passed to the solver as the primal-dual map. The value of :math:`3.0` for the regularization parameter :math:`\gamma` was chosen on the basis of the visual appeal of the returned solution. 64 | 65 | .. 
plot:: 66 | :context: close-figs 67 | 68 | >>> import pyunlocbox 69 | >>> 70 | >>> # Set the functions in the problem 71 | >>> gamma = 3.0 72 | >>> d = pyunlocbox.functions.dummy() 73 | >>> r = pyunlocbox.functions.norm_l1() 74 | >>> f = pyunlocbox.functions.norm_l2(w=M, y=subsampled_noisy_label_signal, 75 | ... lambda_=gamma) 76 | >>> 77 | >>> # Define the solver 78 | >>> G.compute_differential_operator() 79 | >>> L = G.D.T.toarray() 80 | >>> step = 0.999 / (1 + np.linalg.norm(L)) 81 | >>> solver = pyunlocbox.solvers.mlfbf(L=L, step=step) 82 | >>> 83 | >>> # Solve the problem 84 | >>> x0 = subsampled_noisy_label_signal.copy() 85 | >>> prob1 = pyunlocbox.solvers.solve([d, r, f], solver=solver, 86 | ... x0=x0, rtol=0, maxit=1000) 87 | Solution found after 1000 iterations: 88 | objective function f(sol) = 2.213139e+02 89 | stopping criterion: MAXIT 90 | >>> 91 | >>> fig, ax = G.plot(prob1['sol']) 92 | 93 | This figure shows the label signal recovered by graph total variation regularization. We can confirm here that this sort of regularization does indeed promote piecewise-constant solutions. 94 | 95 | .. plot:: 96 | :context: close-figs 97 | 98 | >>> # Set the functions in the problem 99 | >>> r = pyunlocbox.functions.norm_l2(A=L, tight=False) 100 | >>> 101 | >>> # Define the solver 102 | >>> step = 0.999 / np.linalg.norm(np.dot(L.T, L) + gamma * np.diag(M), 2) 103 | >>> solver = pyunlocbox.solvers.gradient_descent(step=step) 104 | >>> 105 | >>> # Solve the problem 106 | >>> x0 = subsampled_noisy_label_signal.copy() 107 | >>> prob2 = pyunlocbox.solvers.solve([r, f], solver=solver, 108 | ... x0=x0, rtol=0, maxit=1000) 109 | Solution found after 1000 iterations: 110 | objective function f(sol) = 6.422673e+01 111 | stopping criterion: MAXIT 112 | >>> 113 | >>> fig, ax = G.plot(prob2['sol']) 114 | 115 | This last figure shows the label signal recovered by Tikhonov regularization. 
As expected, the recovered label signal has smoother transitions than the one obtained by graph TV regularization. 116 | -------------------------------------------------------------------------------- /doc/tutorials/pyramid.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | Graph multiresolution: Kron pyramid 3 | =================================== 4 | 5 | In this demonstration file, we show how to reduce a graph using the PyGSP. Then we apply the pyramid to simple signal. 6 | To start open a python shell (IPython is recommended here) and import the required packages. You would probably also import numpy as you will need it to create matrices and arrays. 7 | 8 | >>> import numpy as np 9 | >>> from pygsp import graphs, reduction 10 | 11 | For this demo we will be using a sensor graph with 400 nodes. 12 | 13 | >>> G = graphs.Sensor(400, distributed=True) 14 | >>> G.compute_fourier_basis() 15 | 16 | The function graph_multiresolution computes the graph pyramid for you: 17 | 18 | >>> levels = 5 19 | >>> Gs = reduction.graph_multiresolution(G, levels, sparsify=False) 20 | 21 | Next, we will compute the fourier basis of our different graph layers: 22 | 23 | >>> for gr in Gs: 24 | ... gr.compute_fourier_basis() 25 | 26 | Those that were already computed are returning with an error, meaning that nothing happened. 27 | Let's now create two signals and a filter, resp f, f2 and g: 28 | 29 | >>> f = np.ones((G.N)) 30 | >>> f[np.arange(G.N//2)] = -1 31 | >>> f = f + 10*Gs[0].U[:, 7] 32 | 33 | >>> f2 = np.ones((G.N, 2)) 34 | >>> f2[np.arange(G.N//2)] = -1 35 | 36 | >>> g = [lambda x: 5./(5 + x)] 37 | 38 | We will run the analysis of the two signals on the pyramid and obtain a coarse approximation for each layer, with decreasing number of nodes. 39 | Additionally, we will also get prediction errors at each node at every layer. 
40 | 41 | >>> ca, pe = reduction.pyramid_analysis(Gs, f, h_filters=g, method='exact') 42 | >>> ca2, pe2 = reduction.pyramid_analysis(Gs, f2, h_filters=g, method='exact') 43 | 44 | Given the pyramid, the coarsest approximation and the prediction errors, we will now reconstruct the original signal on the full graph. 45 | 46 | >>> f_pred, _ = reduction.pyramid_synthesis(Gs, ca[levels], pe, method='exact') 47 | >>> f_pred2, _ = reduction.pyramid_synthesis(Gs, ca2[levels], pe2, method='exact') 48 | 49 | Here are the final errors for each signal after reconstruction. 50 | 51 | >>> err = np.linalg.norm(f_pred-f)/np.linalg.norm(f) 52 | >>> err2 = np.linalg.norm(f_pred2-f2)/np.linalg.norm(f2) 53 | >>> assert (err < 1e-10) & (err2 < 1e-10) 54 | -------------------------------------------------------------------------------- /doc/tutorials/wavelet.rst: -------------------------------------------------------------------------------- 1 | ======================================= 2 | Introduction to spectral graph wavelets 3 | ======================================= 4 | 5 | This tutorial will show you how to easily construct a wavelet_ frame, a kind of 6 | filter bank, and apply it to a signal. This tutorial will walk you into 7 | computing the wavelet coefficients of a graph, visualizing filters in the 8 | vertex domain, and using the wavelets to estimate the curvature of a 3D shape. 9 | 10 | .. _wavelet: https://en.wikipedia.org/wiki/Wavelet 11 | 12 | As usual, we first have to import some packages. 13 | 14 | .. plot:: 15 | :context: reset 16 | 17 | >>> import numpy as np 18 | >>> import matplotlib.pyplot as plt 19 | >>> from pygsp import graphs, filters, plotting, utils 20 | 21 | Then we can load a graph. The graph we'll use is a nearest-neighbor graph of a 22 | point cloud of the Stanford bunny. It will allow us to get interesting visual 23 | results using wavelets. 24 | 25 | .. plot:: 26 | :context: close-figs 27 | 28 | >>> G = graphs.Bunny() 29 | 30 | .. 
note:: 31 | At this stage we could compute the Fourier basis using 32 | :meth:`pygsp.graphs.Graph.compute_fourier_basis`, but this would take some 33 | time, and can be avoided with a Chebychev polynomials approximation to 34 | graph filtering. See the documentation of the 35 | :meth:`pygsp.filters.Filter.filter` filtering function and 36 | :cite:`hammond2011wavelets` for details on how it is down. 37 | 38 | Simple filtering: heat diffusion 39 | -------------------------------- 40 | 41 | Before tackling wavelets, let's observe the effect of a single filter on a 42 | graph signal. We first design a few heat kernel filters, each with a different 43 | scale. 44 | 45 | .. plot:: 46 | :context: close-figs 47 | 48 | >>> taus = [10, 25, 50] 49 | >>> g = filters.Heat(G, taus) 50 | 51 | Let's create a signal as a Kronecker delta located on one vertex, e.g. the 52 | vertex 20. That signal is our heat source. 53 | 54 | .. plot:: 55 | :context: close-figs 56 | 57 | >>> s = np.zeros(G.N) 58 | >>> DELTA = 20 59 | >>> s[DELTA] = 1 60 | 61 | We can now simulate heat diffusion by filtering our signal `s` with each of our 62 | heat kernels. 63 | 64 | .. plot:: 65 | :context: close-figs 66 | 67 | >>> s = g.filter(s, method='chebyshev') 68 | 69 | And finally plot the filtered signal showing heat diffusion at different 70 | scales. 71 | 72 | .. plot:: 73 | :context: close-figs 74 | 75 | >>> fig = plt.figure(figsize=(10, 3)) 76 | >>> for i in range(g.Nf): 77 | ... ax = fig.add_subplot(1, g.Nf, i+1, projection='3d') 78 | ... title = r'Heat diffusion, $\tau={}$'.format(taus[i]) 79 | ... _ = G.plot(s[:, i], colorbar=False, title=title, ax=ax) 80 | ... ax.set_axis_off() 81 | >>> fig.tight_layout() 82 | 83 | .. note:: 84 | The :meth:`pygsp.filters.Filter.localize` method can be used to visualize a 85 | filter in the vertex domain instead of doing it manually. 
Visualizing wavelet atoms
Doing so gives us a 140 | 3-dimensional signal 141 | :math:`[g_i(L)x, g_i(L)y, g_i(L)z], \ i \in [0, \ldots, N_f]` 142 | which describes variation along the 3 coordinates. 143 | 144 | .. plot:: 145 | :context: close-figs 146 | 147 | >>> s = G.coords 148 | >>> s = g.filter(s) 149 | 150 | The curvature is then estimated by taking the :math:`\ell_1` or :math:`\ell_2` 151 | norm across the 3D position. 152 | 153 | .. plot:: 154 | :context: close-figs 155 | 156 | >>> s = np.linalg.norm(s, ord=2, axis=1) 157 | 158 | Let's finally plot the result to observe that we indeed have a measure of the 159 | curvature at different scales. 160 | 161 | .. plot:: 162 | :context: close-figs 163 | 164 | >>> fig = plt.figure(figsize=(10, 7)) 165 | >>> for i in range(4): 166 | ... ax = fig.add_subplot(2, 2, i+1, projection='3d') 167 | ... title = 'Curvature estimation (scale {})'.format(i+1) 168 | ... _ = G.plot(s[:, i], title=title, ax=ax) 169 | ... ax.set_axis_off() 170 | >>> fig.tight_layout() 171 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | ======== 2 | Examples 3 | ======== 4 | -------------------------------------------------------------------------------- /examples/eigenvalue_concentration.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Concentration of the eigenvalues 3 | ================================ 4 | 5 | The eigenvalues of the graph Laplacian concentrates to the same value as the 6 | graph becomes full. 
7 | """ 8 | 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | import pygsp as pg 12 | 13 | n_neighbors = [1, 2, 5, 8] 14 | fig, axes = plt.subplots(3, len(n_neighbors), figsize=(15, 8)) 15 | 16 | for k, ax in zip(n_neighbors, axes.T): 17 | 18 | graph = pg.graphs.Ring(17, k=k) 19 | graph.compute_fourier_basis() 20 | graph.plot(graph.U[:, 1], ax=ax[0]) 21 | ax[0].axis('equal') 22 | ax[1].spy(graph.W) 23 | ax[2].plot(graph.e, '.') 24 | ax[2].set_title('k={}'.format(k)) 25 | #graph.set_coordinates('line1D') 26 | #graph.plot(graph.U[:, :4], ax=ax[3], title='') 27 | 28 | # Check that the DFT matrix is an eigenbasis of the Laplacian. 29 | U = np.fft.fft(np.identity(graph.n_vertices)) 30 | LambdaM = (graph.L.todense().dot(U)) / (U + 1e-15) 31 | # Eigenvalues should be real. 32 | assert np.all(np.abs(np.imag(LambdaM)) < 1e-10) 33 | LambdaM = np.real(LambdaM) 34 | # Check that the eigenvectors are really eigenvectors of the laplacian. 35 | Lambda = np.mean(LambdaM, axis=0) 36 | assert np.all(np.abs(LambdaM - Lambda) < 1e-10) 37 | 38 | fig.tight_layout() 39 | -------------------------------------------------------------------------------- /examples/eigenvector_localization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Localization of Fourier modes 3 | ============================= 4 | 5 | The Fourier modes (the eigenvectors of the graph Laplacian) can be localized in 6 | the spacial domain. As a consequence, graph signals can be localized in both 7 | space and frequency (which is impossible for Euclidean domains or manifolds, by 8 | the Heisenberg's uncertainty principle). 9 | 10 | This example demonstrates that the more isolated a node is, the more a Fourier 11 | mode will be localized on it. 
12 | 13 | The mutual coherence between the basis of Kronecker deltas and the basis formed 14 | by the eigenvectors of the Laplacian, :attr:`pygsp.graphs.Graph.coherence`, is 15 | a measure of the localization of the Fourier modes. The larger the value, the 16 | more localized the eigenvectors can be. 17 | 18 | See `Global and Local Uncertainty Principles for Signals on Graphs 19 | `_ for details. 20 | """ 21 | 22 | import numpy as np 23 | import matplotlib as mpl 24 | from matplotlib import pyplot as plt 25 | import pygsp as pg 26 | 27 | fig, axes = plt.subplots(2, 2, figsize=(8, 8)) 28 | 29 | for w, ax in zip([10, 1, 0.1, 0.01], axes.flatten()): 30 | 31 | adjacency = [ 32 | [0, w, 0, 0], 33 | [w, 0, 1, 0], 34 | [0, 1, 0, 1], 35 | [0, 0, 1, 0], 36 | ] 37 | graph = pg.graphs.Graph(adjacency) 38 | graph.compute_fourier_basis() 39 | 40 | # Plot eigenvectors. 41 | ax.plot(graph.U) 42 | ax.set_ylim(-1, 1) 43 | ax.set_yticks([-1, 0, 1]) 44 | ax.legend([f'$u_{i}(v)$, $\lambda_{i}={graph.e[i]:.1f}$' for i in 45 | range(graph.n_vertices)], loc='upper right') 46 | 47 | ax.text(0, -0.9, f'coherence = {graph.coherence:.2f}' 48 | f'$\in [{1/np.sqrt(graph.n_vertices)}, 1]$') 49 | 50 | # Plot vertices. 51 | ax.set_xticks(range(graph.n_vertices)) 52 | ax.set_xticklabels([f'$v_{i}$' for i in range(graph.n_vertices)]) 53 | 54 | # Plot graph. 55 | x, y = np.arange(0, graph.n_vertices), -1.20*np.ones(graph.n_vertices) 56 | line = mpl.lines.Line2D(x, y, lw=3, color='k', marker='.', markersize=20) 57 | line.set_clip_on(False) 58 | ax.add_line(line) 59 | 60 | # Plot edge weights. 
61 | for i in range(graph.n_vertices - 1): 62 | j = i+1 63 | ax.text(i+0.5, -1.15, f'$w_{{{i}{j}}} = {adjacency[i][j]}$', 64 | horizontalalignment='center') 65 | 66 | fig.tight_layout() 67 | -------------------------------------------------------------------------------- /examples/filtering.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Filtering a signal 3 | ================== 4 | 5 | A graph signal is filtered by transforming it to the spectral domain (via the 6 | Fourier transform), performing a point-wise multiplication (motivated by the 7 | convolution theorem), and transforming it back to the vertex domain (via the 8 | inverse graph Fourier transform). 9 | 10 | .. note:: 11 | 12 | In practice, filtering is implemented in the vertex domain to avoid the 13 | computationally expensive graph Fourier transform. To do so, filters are 14 | implemented as polynomials of the eigenvalues / Laplacian. Hence, filtering 15 | a signal reduces to its multiplications with sparse matrices (the graph 16 | Laplacian). 
17 | 18 | """ 19 | 20 | import numpy as np 21 | from matplotlib import pyplot as plt 22 | import pygsp as pg 23 | 24 | G = pg.graphs.Sensor(seed=42) 25 | G.compute_fourier_basis() 26 | 27 | #g = pg.filters.Rectangular(G, band_max=0.2) 28 | g = pg.filters.Expwin(G, band_max=0.5) 29 | 30 | fig, axes = plt.subplots(1, 3, figsize=(12, 4)) 31 | fig.subplots_adjust(hspace=0.5) 32 | 33 | x = np.random.default_rng(1).normal(size=G.N) 34 | #x = np.random.default_rng(42).uniform(-1, 1, size=G.N) 35 | x = 3 * x / np.linalg.norm(x) 36 | y = g.filter(x) 37 | x_hat = G.gft(x).squeeze() 38 | y_hat = G.gft(y).squeeze() 39 | 40 | limits = [x.min(), x.max()] 41 | 42 | G.plot(x, limits=limits, ax=axes[0], title='input signal $x$ in the vertex domain') 43 | axes[0].text(0, -0.1, '$x^T L x = {:.2f}$'.format(G.dirichlet_energy(x))) 44 | axes[0].set_axis_off() 45 | 46 | g.plot(ax=axes[1], alpha=1) 47 | line_filt = axes[1].lines[-2] 48 | line_in, = axes[1].plot(G.e, np.abs(x_hat), '.-') 49 | line_out, = axes[1].plot(G.e, np.abs(y_hat), '.-') 50 | #axes[1].set_xticks(range(0, 16, 4)) 51 | axes[1].set_xlabel(r'graph frequency $\lambda$') 52 | axes[1].set_ylabel(r'frequency content $\hat{x}(\lambda)$') 53 | axes[1].set_title(r'signals in the spectral domain') 54 | axes[1].legend(['input signal $\hat{x}$']) 55 | labels = [ 56 | r'input signal $\hat{x}$', 57 | 'kernel $g$', 58 | r'filtered signal $\hat{y}$', 59 | ] 60 | axes[1].legend([line_in, line_filt, line_out], labels, loc='upper right') 61 | 62 | G.plot(y, limits=limits, ax=axes[2], title='filtered signal $y$ in the vertex domain') 63 | axes[2].text(0, -0.1, '$y^T L y = {:.2f}$'.format(G.dirichlet_energy(y))) 64 | axes[2].set_axis_off() 65 | 66 | fig.tight_layout() 67 | -------------------------------------------------------------------------------- /examples/fourier_basis.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fourier basis 3 | ============= 4 | 5 | The eigenvectors of the graph 
Laplacian form the Fourier basis. 6 | The eigenvalues are a measure of variation of their corresponding eigenvector. 7 | The lower the eigenvalue, the smoother the eigenvector. They are hence a 8 | measure of "frequency". 9 | 10 | In classical signal processing, Fourier modes are completely delocalized, like 11 | on the grid graph. For general graphs however, Fourier modes might be 12 | localized. See :attr:`pygsp.graphs.Graph.coherence`. 13 | """ 14 | 15 | import numpy as np 16 | from matplotlib import pyplot as plt 17 | import pygsp as pg 18 | 19 | n_eigenvectors = 7 20 | 21 | fig, axes = plt.subplots(2, 7, figsize=(15, 4)) 22 | 23 | def plot_eigenvectors(G, axes): 24 | G.compute_fourier_basis(n_eigenvectors) 25 | limits = [f(G.U) for f in (np.min, np.max)] 26 | for i, ax in enumerate(axes): 27 | G.plot(G.U[:, i], limits=limits, colorbar=False, vertex_size=50, ax=ax) 28 | energy = abs(G.dirichlet_energy(G.U[:, i])) 29 | ax.set_title(r'$u_{0}^\top L u_{0} = {1:.2f}$'.format(i+1, energy)) 30 | ax.set_axis_off() 31 | 32 | G = pg.graphs.Grid2d(10, 10) 33 | plot_eigenvectors(G, axes[0]) 34 | fig.subplots_adjust(hspace=0.5, right=0.8) 35 | cax = fig.add_axes([0.82, 0.60, 0.01, 0.26]) 36 | fig.colorbar(axes[0, -1].collections[1], cax=cax, ticks=[-0.2, 0, 0.2]) 37 | 38 | G = pg.graphs.Sensor(seed=42) 39 | plot_eigenvectors(G, axes[1]) 40 | fig.subplots_adjust(hspace=0.5, right=0.8) 41 | cax = fig.add_axes([0.82, 0.16, 0.01, 0.26]) 42 | _ = fig.colorbar(axes[1, -1].collections[1], cax=cax, ticks=[-0.4, 0, 0.4]) 43 | -------------------------------------------------------------------------------- /examples/fourier_transform.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fourier transform 3 | ================= 4 | 5 | The graph Fourier transform :meth:`pygsp.graphs.Graph.gft` transforms a 6 | signal from the vertex domain to the spectral domain. 
The smoother the signal 7 | (see :meth:`pygsp.graphs.Graph.dirichlet_energy`), the lower in the frequencies 8 | its energy is concentrated. 9 | """ 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | import pygsp as pg 14 | 15 | G = pg.graphs.Sensor(seed=42) 16 | G.compute_fourier_basis() 17 | 18 | scales = [10, 3, 0] 19 | limit = 0.44 20 | 21 | fig, axes = plt.subplots(2, len(scales), figsize=(12, 4)) 22 | fig.subplots_adjust(hspace=0.5) 23 | 24 | x0 = np.random.default_rng(1).normal(size=G.N) 25 | for i, scale in enumerate(scales): 26 | g = pg.filters.Heat(G, scale) 27 | x = g.filter(x0).squeeze() 28 | x /= np.linalg.norm(x) 29 | x_hat = G.gft(x).squeeze() 30 | 31 | assert np.all((-limit < x) & (x < limit)) 32 | G.plot(x, limits=[-limit, limit], ax=axes[0, i]) 33 | axes[0, i].set_axis_off() 34 | axes[0, i].set_title('$x^T L x = {:.2f}$'.format(G.dirichlet_energy(x))) 35 | 36 | axes[1, i].plot(G.e, np.abs(x_hat), '.-') 37 | axes[1, i].set_xticks(range(0, 16, 4)) 38 | axes[1, i].set_xlabel(r'graph frequency $\lambda$') 39 | axes[1, i].set_ylim(-0.05, 0.95) 40 | 41 | axes[1, 0].set_ylabel(r'frequency content $\hat{x}(\lambda)$') 42 | 43 | # axes[0, 0].set_title(r'$x$: signal in the vertex domain') 44 | # axes[1, 0].set_title(r'$\hat{x}$: signal in the spectral domain') 45 | 46 | fig.tight_layout() 47 | -------------------------------------------------------------------------------- /examples/heat_diffusion.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Heat diffusion 3 | ============== 4 | 5 | Solve the heat equation by filtering the initial conditions with the heat 6 | kernel :class:`pygsp.filters.Heat`. 
7 | """ 8 | 9 | from os import path 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | import pygsp as pg 14 | 15 | n_side = 13 16 | G = pg.graphs.Grid2d(n_side) 17 | G.compute_fourier_basis() 18 | 19 | sources = [ 20 | (n_side//4 * n_side) + (n_side//4), 21 | (n_side*3//4 * n_side) + (n_side*3//4), 22 | ] 23 | x = np.zeros(G.n_vertices) 24 | x[sources] = 5 25 | 26 | times = [0, 5, 10, 20] 27 | 28 | fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) 29 | for i, t in enumerate(times): 30 | g = pg.filters.Heat(G, scale=t) 31 | title = r'$\hat{{f}}({0}) = g_{{1,{0}}} \odot \hat{{f}}(0)$'.format(t) 32 | g.plot(alpha=1, ax=axes[0, i], title=title) 33 | axes[0, i].set_xlabel(r'$\lambda$') 34 | # axes[0, i].set_ylabel(r'$g(\lambda)$') 35 | if i > 0: 36 | axes[0, i].set_ylabel('') 37 | y = g.filter(x) 38 | line, = axes[0, i].plot(G.e, G.gft(y)) 39 | labels = [r'$\hat{{f}}({})$'.format(t), r'$g_{{1,{}}}$'.format(t)] 40 | axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc='lower right') 41 | G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=r'$f({})$'.format(t)) 42 | axes[1, i].set_aspect('equal', 'box') 43 | axes[1, i].set_axis_off() 44 | 45 | fig.tight_layout() 46 | -------------------------------------------------------------------------------- /examples/kernel_localization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Kernel localization 3 | =================== 4 | 5 | In classical signal processing, a filter can be translated in the vertex 6 | domain. We cannot do that on graphs. Instead, we can 7 | :meth:`~pygsp.filters.Filter.localize` a filter kernel. Note how on classic 8 | structures (like the ring), the localized kernel is the same everywhere, while 9 | it changes when localized on irregular graphs. 
10 | """ 11 | 12 | import numpy as np 13 | from matplotlib import pyplot as plt 14 | import pygsp as pg 15 | 16 | fig, axes = plt.subplots(2, 4, figsize=(10, 4)) 17 | 18 | graphs = [ 19 | pg.graphs.Ring(40), 20 | pg.graphs.Sensor(64, seed=42), 21 | ] 22 | 23 | locations = [0, 10, 20] 24 | 25 | for graph, axs in zip(graphs, axes): 26 | graph.compute_fourier_basis() 27 | g = pg.filters.Heat(graph) 28 | g.plot(ax=axs[0], title='heat kernel') 29 | axs[0].set_xlabel(r'eigenvalues $\lambda$') 30 | axs[0].set_ylabel(r'$g(\lambda) = \exp \left( \frac{{-{}\lambda}}{{\lambda_{{max}}}} \right)$'.format(g.scale[0])) 31 | maximum = 0 32 | for loc in locations: 33 | x = g.localize(loc) 34 | maximum = np.maximum(maximum, x.max()) 35 | for loc, ax in zip(locations, axs[1:]): 36 | graph.plot(g.localize(loc), limits=[0, maximum], highlight=loc, ax=ax, 37 | title=r'$g(L) \delta_{{{}}}$'.format(loc)) 38 | ax.set_axis_off() 39 | 40 | fig.tight_layout() 41 | -------------------------------------------------------------------------------- /examples/playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Playing with the PyGSP\n", 8 | "" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "%matplotlib inline\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from pygsp import graphs, filters\n", 22 | "\n", 23 | "plt.rcParams['figure.figsize'] = (17, 5)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## 1 Example\n", 31 | "\n", 32 | "The following demonstrates how to instantiate a graph and a filter, the two main objects of the package." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "G = graphs.Logo()\n", 42 | "G.estimate_lmax()\n", 43 | "g = filters.Heat(G, scale=100)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "Let's now create a graph signal: a set of three Kronecker deltas for that example. We can now look at one step of heat diffusion by filtering the deltas with the above defined filter. Note how the diffusion follows the local structure!" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "DELTAS = [20, 30, 1090]\n", 60 | "s = np.zeros(G.N)\n", 61 | "s[DELTAS] = 1\n", 62 | "s = g.filter(s)\n", 63 | "G.plot(s, highlight=DELTAS, backend='matplotlib')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## 2 Tutorials and examples\n", 71 | "\n", 72 | "Try our [tutorials](https://pygsp.readthedocs.io/en/stable/tutorials/index.html) or [examples](https://pygsp.readthedocs.io/en/stable/examples/index.html)." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Your code here." 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## 3 Playground\n", 89 | "\n", 90 | "Try something of your own!\n", 91 | "The [API reference](https://pygsp.readthedocs.io/en/stable/reference/index.html) is your friend." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Your code here."
101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "If you miss a package, you can install it with:" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "%pip install numpy" 117 | ] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3", 123 | "language": "python", 124 | "name": "python3" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 4 129 | } 130 | -------------------------------------------------------------------------------- /examples/random_walk.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Random walks 3 | ============ 4 | 5 | Probability of a random walker to be on any given vertex after a given number 6 | of steps starting from a given distribution. 7 | """ 8 | 9 | # sphinx_gallery_thumbnail_number = 2 10 | 11 | import numpy as np 12 | from scipy import sparse 13 | from matplotlib import pyplot as plt 14 | import pygsp as pg 15 | 16 | N = 7 17 | steps = [0, 1, 2, 3] 18 | 19 | graph = pg.graphs.Grid2d(N) 20 | delta = np.zeros(graph.N) 21 | delta[N//2*N + N//2] = 1 22 | 23 | probability = sparse.diags(graph.dw**(-1)).dot(graph.W) 24 | 25 | fig, axes = plt.subplots(1, len(steps), figsize=(12, 3)) 26 | for step, ax in zip(steps, axes): 27 | state = (probability**step).__rmatmul__(delta) ## = delta @ probability**step 28 | graph.plot(state, ax=ax, title=r'$\delta P^{}$'.format(step)) 29 | ax.set_axis_off() 30 | 31 | fig.tight_layout() 32 | 33 | ############################################################################### 34 | # Stationary distribution. 
35 | 36 | graphs = [ 37 | pg.graphs.Ring(10), 38 | pg.graphs.Grid2d(5), 39 | pg.graphs.Comet(8, 4), 40 | pg.graphs.BarabasiAlbert(20, seed=42), 41 | ] 42 | 43 | fig, axes = plt.subplots(1, len(graphs), figsize=(12, 3)) 44 | 45 | for graph, ax in zip(graphs, axes): 46 | 47 | if not hasattr(graph, 'coords'): 48 | graph.set_coordinates(seed=10) 49 | 50 | P = sparse.diags(graph.dw**(-1)).dot(graph.W) 51 | 52 | # e, u = np.linalg.eig(P.T.toarray()) 53 | # np.testing.assert_allclose(np.linalg.inv(u.T) @ np.diag(e) @ u.T, 54 | # P.toarray(), atol=1e-10) 55 | # np.testing.assert_allclose(np.abs(e[0]), 1) 56 | # stationary = np.abs(u.T[0]) 57 | 58 | e, u = sparse.linalg.eigs(P.T, k=1, which='LR') 59 | np.testing.assert_allclose(e, 1) 60 | stationary = np.abs(u).squeeze() 61 | assert np.all(stationary < 0.71) 62 | 63 | colorbar = False if type(graph) is pg.graphs.Ring else True 64 | graph.plot(stationary, colorbar=colorbar, ax=ax, title='$xP = x$') 65 | ax.set_axis_off() 66 | 67 | fig.tight_layout() 68 | -------------------------------------------------------------------------------- /examples/wave_propagation.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Wave propagation 3 | ================ 4 | 5 | Solve the wave equation by filtering the initial conditions with the wave 6 | kernel :class:`pygsp.filters.Wave`. 
7 | """ 8 | 9 | from os import path 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | import pygsp as pg 14 | 15 | n_side = 13 16 | G = pg.graphs.Grid2d(n_side) 17 | G.compute_fourier_basis() 18 | 19 | sources = [ 20 | (n_side//4 * n_side) + (n_side//4), 21 | (n_side*3//4 * n_side) + (n_side*3//4), 22 | ] 23 | x = np.zeros(G.n_vertices) 24 | x[sources] = 5 25 | 26 | times = [0, 5, 10, 20] 27 | 28 | fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) 29 | for i, t in enumerate(times): 30 | g = pg.filters.Wave(G, time=t, speed=1) 31 | title = r'$\hat{{f}}({0}) = g_{{1,{0}}} \odot \hat{{f}}(0)$'.format(t) 32 | g.plot(alpha=1, ax=axes[0, i], title=title) 33 | axes[0, i].set_xlabel(r'$\lambda$') 34 | # axes[0, i].set_ylabel(r'$g(\lambda)$') 35 | if i > 0: 36 | axes[0, i].set_ylabel('') 37 | y = g.filter(x) 38 | line, = axes[0, i].plot(G.e, G.gft(y)) 39 | labels = [r'$\hat{{f}}({})$'.format(t), r'$g_{{1,{}}}$'.format(t)] 40 | axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc='lower right') 41 | G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=r'$f({})$'.format(t)) 42 | axes[1, i].set_aspect('equal', 'box') 43 | axes[1, i].set_axis_off() 44 | 45 | fig.tight_layout() 46 | -------------------------------------------------------------------------------- /postBuild: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Tell https://mybinder.org to simply install the package. 3 | pip install .[dev] 4 | -------------------------------------------------------------------------------- /pygsp/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | r""" 4 | The :mod:`pygsp` package is mainly organized around the following two modules: 5 | 6 | * :mod:`.graphs` to create and manipulate various kinds of graphs, 7 | * :mod:`.filters` to create and manipulate various graph filters. 
8 | 9 | Moreover, the following modules provide additional functionality: 10 | 11 | * :mod:`.plotting` to plot, 12 | * :mod:`.reduction` to reduce a graph while keeping its structure, 13 | * :mod:`.features` to compute features on graphs, 14 | * :mod:`.learning` to solve learning problems, 15 | * :mod:`.optimization` to help solving convex optimization problems, 16 | * :mod:`.utils` for various utilities. 17 | 18 | """ 19 | 20 | from . import graphs # noqa: F401 21 | from . import filters # noqa: F401 22 | from . import plotting # noqa: F401 23 | from . import reduction # noqa: F401 24 | from . import features # noqa: F401 25 | from . import learning # noqa: F401 26 | from . import optimization # noqa: F401 27 | from . import utils # noqa: F401 28 | 29 | # Users only call the plot methods from the objects. 30 | # It's thus more convenient for them to have the doc there. 31 | # But it's more convenient for developers to have the doc alongside the code. 32 | try: 33 | filters.Filter.plot.__doc__ = plotting._plot_filter.__doc__ 34 | graphs.Graph.plot.__doc__ = plotting._plot_graph.__doc__ 35 | graphs.Graph.plot_spectrogram.__doc__ = plotting._plot_spectrogram.__doc__ 36 | except AttributeError: 37 | # For Python 2.7. 38 | filters.Filter.plot.__func__.__doc__ = plotting._plot_filter.__doc__ 39 | graphs.Graph.plot.__func__.__doc__ = plotting._plot_graph.__doc__ 40 | graphs.Graph.plot_spectrogram.__func__.__doc__ = plotting._plot_spectrogram.__doc__ 41 | 42 | __version__ = '0.5.1' 43 | __release_date__ = '2017-12-15' 44 | 45 | 46 | def test(): # pragma: no cover 47 | """Run the test suite.""" 48 | import unittest 49 | # Lazy as it might be slow and require additional dependencies. 
50 | from pygsp.tests import suite 51 | unittest.TextTestRunner(verbosity=2).run(suite) 52 | -------------------------------------------------------------------------------- /pygsp/data/pointclouds/airfoil.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/airfoil.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/bunny.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/bunny.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/david500.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/david500.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/david64.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/david64.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/logogsp.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/logogsp.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/minnesota.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/minnesota.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/two_moons.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/pointclouds/two_moons.mat -------------------------------------------------------------------------------- /pygsp/data/readme_example_filter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/readme_example_filter.png -------------------------------------------------------------------------------- /pygsp/data/readme_example_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/data/readme_example_graph.png -------------------------------------------------------------------------------- /pygsp/features.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | r""" 4 | The :mod:`pygsp.features` module implements different feature extraction 5 | techniques based on :mod:`pygsp.graphs` and :mod:`pygsp.filters`. 6 | """ 7 | 8 | import numpy as np 9 | 10 | from pygsp import filters, utils 11 | 12 | 13 | def compute_avg_adj_deg(G): 14 | r""" 15 | Compute the average adjacency degree for each node. 16 | 17 | The average adjacency degree is the average of the degrees of a node and 18 | its neighbors. 19 | 20 | Parameters 21 | ---------- 22 | G: Graph 23 | Graph on which the statistic is extracted 24 | """ 25 | return np.sum(np.dot(G.A, G.A), axis=1) / (np.sum(G.A, axis=1) + 1.) 
26 | 27 | 28 | @utils.filterbank_handler 29 | def compute_tig(g, **kwargs): 30 | r""" 31 | Compute the Tig for a given filter or filter bank. 32 | 33 | .. math:: T_ig(n) = g(L)_{i, n} 34 | 35 | Parameters 36 | ---------- 37 | g: Filter 38 | One of :mod:`pygsp.filters`. 39 | kwargs: dict 40 | Additional parameters to be passed to the 41 | :func:`pygsp.filters.Filter.filter` method. 42 | """ 43 | return g.compute_frame() 44 | 45 | 46 | @utils.filterbank_handler 47 | def compute_norm_tig(g, **kwargs): 48 | r""" 49 | Compute the :math:`\ell_2` norm of the Tig. 50 | See :func:`compute_tig`. 51 | 52 | Parameters 53 | ---------- 54 | g: Filter 55 | The filter or filter bank. 56 | kwargs: dict 57 | Additional parameters to be passed to the 58 | :func:`pygsp.filters.Filter.filter` method. 59 | """ 60 | tig = compute_tig(g, **kwargs) 61 | return np.linalg.norm(tig, axis=1, ord=2) 62 | 63 | 64 | def compute_spectrogram(G, atom=None, M=100, **kwargs): 65 | r""" 66 | Compute the norm of the Tig for all nodes with a kernel shifted along the 67 | spectral axis. 68 | 69 | Parameters 70 | ---------- 71 | G : Graph 72 | Graph on which to compute the spectrogram. 73 | atom : func 74 | Kernel to use in the spectrogram (default = exp(-M*(x/lmax)²)). 75 | M : int (optional) 76 | Number of samples on the spectral scale. (default = 100) 77 | kwargs: dict 78 | Additional parameters to be passed to the 79 | :func:`pygsp.filters.Filter.filter` method. 
80 | """ 81 | 82 | if not atom: 83 | def atom(x): 84 | return np.exp(-M * (x / G.lmax)**2) 85 | 86 | scale = np.linspace(0, G.lmax, M) 87 | spectr = np.empty((G.N, M)) 88 | 89 | for shift_idx in range(M): 90 | shift_filter = filters.Filter(G, lambda x: atom(x - scale[shift_idx])) 91 | tig = compute_norm_tig(shift_filter, **kwargs).squeeze()**2 92 | spectr[:, shift_idx] = tig 93 | 94 | G.spectr = spectr 95 | return spectr 96 | -------------------------------------------------------------------------------- /pygsp/filters/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | r""" 4 | The :mod:`pygsp.filters` module implements methods used for filtering and 5 | defines commonly used filters that can be applied to :mod:`pygsp.graphs`. A 6 | filter is associated to a graph and is defined with one or several functions. 7 | We define by filter bank a list of filters, usually centered around different 8 | frequencies, applied to a single graph. 9 | 10 | Interface 11 | --------- 12 | 13 | The :class:`Filter` base class implements a common interface to all filters: 14 | 15 | .. autosummary:: 16 | 17 | Filter.evaluate 18 | Filter.filter 19 | Filter.analyze 20 | Filter.synthesize 21 | Filter.complement 22 | Filter.inverse 23 | Filter.compute_frame 24 | Filter.estimate_frame_bounds 25 | Filter.plot 26 | Filter.localize 27 | 28 | Filters 29 | ------- 30 | 31 | Then, derived classes implement various common graph filters. 32 | 33 | **Filters that solve differential equations** 34 | 35 | The following filters solve partial differential equations (PDEs) on graphs, 36 | which model processes such as heat diffusion or wave propagation. 37 | 38 | .. autosummary:: 39 | 40 | Heat 41 | Wave 42 | 43 | **Low-pass filters** 44 | 45 | .. autosummary:: 46 | 47 | Heat 48 | 49 | **Band-pass filters** 50 | 51 | These filters can be configured to be low-pass, high-pass, or band-pass. 52 | 53 | .. 
autosummary:: 54 | 55 | Expwin 56 | Rectangular 57 | 58 | **Filter banks of two filters: a low-pass and a high-pass** 59 | 60 | .. autosummary:: 61 | 62 | Regular 63 | Held 64 | Simoncelli 65 | Papadakis 66 | 67 | **Filter banks composed of dilated or translated filters** 68 | 69 | .. autosummary:: 70 | 71 | Abspline 72 | HalfCosine 73 | Itersine 74 | MexicanHat 75 | Meyer 76 | SimpleTight 77 | 78 | **Filter banks for vertex-frequency analyzes** 79 | 80 | Those filter banks are composed of shifted versions of a mother filter, one per 81 | graph frequency (Laplacian eigenvalue). They can analyze frequency content 82 | locally, as a windowed graph Fourier transform. 83 | 84 | .. autosummary:: 85 | 86 | Gabor 87 | Modulation 88 | 89 | Approximations 90 | -------------- 91 | 92 | Moreover, two approximation methods are provided for fast filtering. The 93 | computational complexity of filtering with those approximations is linear with 94 | the number of edges. The complexity of the exact solution, which is to use the 95 | Fourier basis, is quadratic with the number of nodes (without taking into 96 | account the cost of the necessary eigendecomposition of the graph Laplacian). 97 | 98 | **Chebyshev polynomials** 99 | 100 | .. autosummary:: 101 | 102 | compute_cheby_coeff 103 | compute_jackson_cheby_coeff 104 | cheby_op 105 | cheby_rect 106 | 107 | **Lanczos algorithm** 108 | 109 | .. 
autosummary:: 110 | 111 | lanczos 112 | lanczos_op 113 | 114 | """ 115 | 116 | from .filter import Filter # noqa: F401 117 | from .abspline import Abspline # noqa: F401 118 | from .expwin import Expwin # noqa: F401 119 | from .gabor import Gabor # noqa: F401 120 | from .halfcosine import HalfCosine # noqa: F401 121 | from .heat import Heat # noqa: F401 122 | from .held import Held # noqa: F401 123 | from .itersine import Itersine # noqa: F401 124 | from .mexicanhat import MexicanHat # noqa: F401 125 | from .meyer import Meyer # noqa: F401 126 | from .modulation import Modulation # noqa: F401 127 | from .papadakis import Papadakis # noqa: F401 128 | from .rectangular import Rectangular # noqa: F401 129 | from .regular import Regular # noqa: F401 130 | from .simoncelli import Simoncelli # noqa: F401 131 | from .simpletight import SimpleTight # noqa: F401 132 | from .wave import Wave # noqa: F401 133 | 134 | from .approximations import compute_cheby_coeff # noqa: F401 135 | from .approximations import compute_jackson_cheby_coeff # noqa: F401 136 | from .approximations import cheby_op # noqa: F401 137 | from .approximations import cheby_rect # noqa: F401 138 | from .approximations import lanczos # noqa: F401 139 | from .approximations import lanczos_op # noqa: F401 140 | -------------------------------------------------------------------------------- /pygsp/filters/abspline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | from scipy import optimize 7 | 8 | from pygsp import utils 9 | from . import Filter # prevent circular import in Python < 3.5 10 | 11 | 12 | class Abspline(Filter): 13 | r"""Design an A B cubic spline wavelet filter bank. 
14 | 15 | Parameters 16 | ---------- 17 | G : graph 18 | Nf : int 19 | Number of filters from 0 to lmax (default = 6) 20 | lpfactor : float 21 | Low-pass factor lmin=lmax/lpfactor will be used to determine scales, 22 | the scaling function will be created to fill the lowpass gap. 23 | (default = 20) 24 | scales : ndarray 25 | Vector of scales to be used. 26 | By default, initialized with :func:`pygsp.utils.compute_log_scales`. 27 | 28 | Examples 29 | -------- 30 | 31 | Filter bank's representation in Fourier and time (ring graph) domains. 32 | 33 | >>> import matplotlib.pyplot as plt 34 | >>> G = graphs.Ring(N=20) 35 | >>> G.estimate_lmax() 36 | >>> G.set_coordinates('line1D') 37 | >>> g = filters.Abspline(G) 38 | >>> s = g.localize(G.N // 2) 39 | >>> fig, axes = plt.subplots(1, 2) 40 | >>> _ = g.plot(ax=axes[0]) 41 | >>> _ = G.plot(s, ax=axes[1]) 42 | 43 | """ 44 | 45 | def __init__(self, G, Nf=6, lpfactor=20, scales=None): 46 | 47 | def kernel_abspline3(x, alpha, beta, t1, t2): 48 | M = np.array([[1, t1, t1**2, t1**3], 49 | [1, t2, t2**2, t2**3], 50 | [0, 1, 2*t1, 3*t1**2], 51 | [0, 1, 2*t2, 3*t2**2]]) 52 | v = np.array([1, 1, t1**(-alpha) * alpha * t1**(alpha - 1), 53 | -beta*t2**(- beta - 1) * t2**beta]) 54 | a = np.linalg.solve(M, v) 55 | 56 | r1 = x <= t1 57 | r2 = (x >= t1)*(x < t2) 58 | r3 = (x >= t2) 59 | 60 | if isinstance(x, np.float64): 61 | 62 | if r1: 63 | r = x[r1]**alpha * t1**(-alpha) 64 | if r2: 65 | r = a[0] + a[1] * x + a[2] * x**2 + a[3] * x**3 66 | if r3: 67 | r = x[r3]**(-beta) * t2**beta 68 | 69 | else: 70 | r = np.zeros(x.shape) 71 | 72 | x2 = x[r2] 73 | 74 | r[r1] = x[r1]**alpha * t1**(-alpha) 75 | r[r2] = a[0] + a[1] * x2 + a[2] * x2**2 + a[3] * x2**3 76 | r[r3] = x[r3]**(-beta) * t2 ** beta 77 | 78 | return r 79 | 80 | self.lpfactor = lpfactor 81 | 82 | lmin = G.lmax / lpfactor 83 | 84 | if scales is None: 85 | scales = utils.compute_log_scales(lmin, G.lmax, Nf - 1) 86 | self.scales = scales 87 | 88 | gb = lambda x: kernel_abspline3(x, 
2, 2, 1, 2) 89 | gl = lambda x: np.exp(-np.power(x, 4)) 90 | 91 | lminfac = .4 * lmin 92 | 93 | g = [lambda x: 1.2 * np.exp(-1) * gl(x / lminfac)] 94 | for i in range(0, Nf - 1): 95 | g.append(lambda x, i=i: gb(self.scales[i] * x)) 96 | 97 | f = lambda x: -gb(x) 98 | xstar = optimize.minimize_scalar(f, bounds=(1, 2), 99 | method='bounded') 100 | gamma_l = -f(xstar.x) 101 | lminfac = .6 * lmin 102 | g[0] = lambda x: gamma_l * gl(x / lminfac) 103 | 104 | super(Abspline, self).__init__(G, g) 105 | 106 | def _get_extra_repr(self): 107 | return dict(lpfactor='{:.2f}'.format(self.lpfactor)) 108 | -------------------------------------------------------------------------------- /pygsp/filters/expwin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | from . import Filter # prevent circular import in Python < 3.5 8 | 9 | 10 | class Expwin(Filter): 11 | r"""Design an exponential window filter. 12 | 13 | The window has the shape of a (half) triangle with the corners smoothed by 14 | an exponential function. 15 | 16 | Parameters 17 | ---------- 18 | G : graph 19 | band_min : float 20 | Minimum relative band. The filter evaluates at 0.5 at this frequency. 21 | Zero corresponds to the smallest eigenvalue (which is itself equal to 22 | zero), one corresponds to the largest eigenvalue. 23 | If None, the filter is high-pass. 24 | band_max : float 25 | Maximum relative band. The filter evaluates at 0.5 at this frequency. 26 | If None, the filter is low-pass. 27 | slope : float 28 | The slope at cut-off. 29 | 30 | Examples 31 | -------- 32 | 33 | Filter bank's representation in Fourier and time (ring graph) domains. 
34 | 35 | >>> import matplotlib.pyplot as plt 36 | >>> G = graphs.Ring(N=20) 37 | >>> G.estimate_lmax() 38 | >>> G.set_coordinates('line1D') 39 | >>> g = filters.Expwin(G, band_min=0.1, band_max=0.7, slope=5) 40 | >>> s = g.localize(G.N // 2) 41 | >>> fig, axes = plt.subplots(1, 2) 42 | >>> _ = g.plot(ax=axes[0]) 43 | >>> _ = G.plot(s, ax=axes[1]) 44 | 45 | """ 46 | 47 | def __init__(self, G, band_min=None, band_max=0.2, slope=1): 48 | 49 | self.band_min = band_min 50 | self.band_max = band_max 51 | self.slope = slope 52 | 53 | def exp(x): 54 | """Exponential function with canary to avoid division by zero and 55 | overflow.""" 56 | y = np.where(x <= 0, -1, x) 57 | y = np.exp(-slope / y) 58 | return np.where(x <= 0, 0, y) 59 | 60 | def h(x): 61 | y = exp(x) 62 | z = exp(1 - x) 63 | return y / (y + z) 64 | 65 | def kernel_lowpass(x): 66 | return h(0.5 - x/G.lmax + band_max) 67 | def kernel_highpass(x): 68 | return h(0.5 + x/G.lmax - band_min) 69 | 70 | if (band_min is None) and (band_max is None): 71 | kernel = lambda x: np.ones_like(x) 72 | elif band_min is None: 73 | kernel = kernel_lowpass 74 | elif band_max is None: 75 | kernel = kernel_highpass 76 | else: 77 | kernel = lambda x: kernel_lowpass(x) * kernel_highpass(x) 78 | 79 | super(Expwin, self).__init__(G, kernel) 80 | 81 | def _get_extra_repr(self): 82 | attrs = dict() 83 | if self.band_min is not None: 84 | attrs.update(band_min='{:.2f}'.format(self.band_min)) 85 | if self.band_max is not None: 86 | attrs.update(band_max='{:.2f}'.format(self.band_max)) 87 | attrs.update(slope='{:.0f}'.format(self.slope)) 88 | return attrs 89 | -------------------------------------------------------------------------------- /pygsp/filters/gabor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from pygsp import utils 4 | from . 
import Filter # prevent circular import in Python < 3.5 5 | 6 | 7 | class Gabor(Filter): 8 | r"""Design a filter bank with a kernel centered at each frequency. 9 | 10 | Design a filter bank from translated versions of a mother filter. 11 | The mother filter is translated to each eigenvalue of the Laplacian. 12 | That is equivalent to convolutions with deltas placed at those eigenvalues. 13 | 14 | In classical image processing, a Gabor filter is a sinusoidal wave 15 | multiplied by a Gaussian function (here, the kernel). It analyzes whether 16 | there are any specific frequency content in the image in specific 17 | directions in a localized region around the point of analysis. This 18 | implementation for graph signals allows arbitrary (but isotropic) kernels. 19 | 20 | This filter bank can be used to compute the frequency content of a signal 21 | at each vertex. After filtering, one obtains a vertex-frequency 22 | representation :math:`Sf(i,k)` of a signal :math:`f` as 23 | 24 | .. math:: Sf(i, k) = \langle g_{i,k}, f \rangle, 25 | 26 | where :math:`g_{i,k}` is the mother kernel centered on eigenvalue 27 | :math:`\lambda_k` and localized on vertex :math:`v_i`. 28 | 29 | While :math:`g_{i,k}` should ideally be localized in both the spectral and 30 | vertex domains, that is impossible for some graphs due to the localization 31 | of some eigenvectors. See :attr:`pygsp.graphs.Graph.coherence`. 32 | 33 | Parameters 34 | ---------- 35 | graph : :class:`pygsp.graphs.Graph` 36 | kernel : :class:`pygsp.filters.Filter` 37 | Kernel function to be centered at each graph frequency (eigenvalue of 38 | the graph Laplacian). 39 | 40 | See Also 41 | -------- 42 | Modulation : Another way to translate a filter in the spectral domain. 43 | 44 | Notes 45 | ----- 46 | The eigenvalues of the graph Laplacian (i.e., the Fourier basis) are needed 47 | to center the kernels. 48 | 49 | Examples 50 | -------- 51 | 52 | Filter bank's representation in Fourier and time (path graph) domains. 
53 | 54 | >>> import matplotlib.pyplot as plt 55 | >>> G = graphs.Path(N=7) 56 | >>> G.compute_fourier_basis() 57 | >>> G.set_coordinates('line1D') 58 | >>> 59 | >>> g1 = filters.Expwin(G, band_min=None, band_max=0, slope=3) 60 | >>> g2 = filters.Rectangular(G, band_min=-0.05, band_max=0.05) 61 | >>> g3 = filters.Heat(G, scale=10) 62 | >>> 63 | >>> fig, axes = plt.subplots(3, 2, figsize=(10, 10)) 64 | >>> for g, ax in zip([g1, g2, g3], axes): 65 | ... g = filters.Gabor(G, g) 66 | ... s = g.localize(G.N // 2, method='exact') 67 | ... _ = g.plot(ax=ax[0], sum=False) 68 | ... _ = G.plot(s, ax=ax[1]) 69 | >>> fig.tight_layout() 70 | 71 | """ 72 | 73 | def __init__(self, graph, kernel): 74 | 75 | if kernel.n_filters != 1: 76 | raise ValueError('A kernel must be one filter. The passed ' 77 | 'filter bank {} has {}.'.format( 78 | kernel, kernel.n_filters)) 79 | if kernel.G is not graph: 80 | raise ValueError('The graph passed to this filter bank must ' 81 | 'be the one used to build the mother kernel.') 82 | 83 | kernels = [] 84 | for i in range(graph.n_vertices): 85 | kernels.append(lambda x, i=i: kernel.evaluate(x - graph.e[i])) 86 | 87 | super(Gabor, self).__init__(graph, kernels) 88 | 89 | def filter(self, s, method='exact', order=None): 90 | """TODO: indirection will be removed when poly filtering is merged.""" 91 | return super(Gabor, self).filter(s, method='exact') 92 | -------------------------------------------------------------------------------- /pygsp/filters/halfcosine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | from . import Filter # prevent circular import in Python < 3.5 8 | 9 | 10 | class HalfCosine(Filter): 11 | r"""Design an half cosine filter bank (tight frame). 
12 | 13 | Parameters 14 | ---------- 15 | G : graph 16 | Nf : int 17 | Number of filters from 0 to lmax (default = 6) 18 | 19 | Examples 20 | -------- 21 | 22 | Filter bank's representation in Fourier and time (ring graph) domains. 23 | 24 | >>> import matplotlib.pyplot as plt 25 | >>> G = graphs.Ring(N=20) 26 | >>> G.estimate_lmax() 27 | >>> G.set_coordinates('line1D') 28 | >>> g = filters.HalfCosine(G) 29 | >>> s = g.localize(G.N // 2) 30 | >>> fig, axes = plt.subplots(1, 2) 31 | >>> _ = g.plot(ax=axes[0]) 32 | >>> _ = G.plot(s, ax=axes[1]) 33 | 34 | """ 35 | 36 | def __init__(self, G, Nf=6): 37 | 38 | if Nf <= 2: 39 | raise ValueError('The number of filters must be greater than 2.') 40 | 41 | dila_fact = G.lmax * 3 / (Nf - 2) 42 | 43 | def kernel(x): 44 | y = np.cos(2 * np.pi * (x / dila_fact - .5)) 45 | y = np.multiply((.5 + .5*y), (x >= 0)) 46 | return np.multiply(y, (x <= dila_fact)) 47 | 48 | kernels = [] 49 | 50 | for i in range(Nf): 51 | 52 | def kernel_centered(x, i=i): 53 | return kernel(x - dila_fact/3 * (i - 2)) 54 | 55 | kernels.append(kernel_centered) 56 | 57 | super(HalfCosine, self).__init__(G, kernels) 58 | -------------------------------------------------------------------------------- /pygsp/filters/heat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | from . import Filter # prevent circular import in Python < 3.5 8 | 9 | 10 | class Heat(Filter): 11 | r"""Design a filter bank of heat kernels. 12 | 13 | The (low-pass) heat kernel is defined in the spectral domain as 14 | 15 | .. math:: g_\tau(\lambda) = \exp(-\tau \lambda), 16 | 17 | where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the 18 | graph Laplacian, and :math:`\tau` is a parameter that captures both time 19 | and thermal diffusivity. 20 | 21 | The heat kernel is the fundamental solution to the heat equation 22 | 23 | .. 
math:: - \tau L f(t) = \partial_t f(t), 24 | 25 | where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` is the heat 26 | distribution over the graph at time :math:`t`. Given the initial condition 27 | :math:`f(0)`, the solution of the heat equation is expressed as 28 | 29 | .. math:: f(t) = e^{-\tau t L} f(0) 30 | = U e^{-\tau t \Lambda} U^\top f(0) 31 | = g_{\tau t}(L) f(0). 32 | 33 | The above is, by definition, the convolution of the signal :math:`f(0)` 34 | with the kernel :math:`g_{\tau t}(\lambda) = \exp(-\tau t \lambda)`. 35 | Hence, applying this filter to a signal simulates heat diffusion. 36 | 37 | Since the kernel is applied to the graph eigenvalues :math:`\lambda`, which 38 | can be interpreted as squared frequencies, it can also be considered as a 39 | generalization of the Gaussian kernel on graphs. 40 | 41 | Parameters 42 | ---------- 43 | G : graph 44 | scale : float or iterable 45 | Scaling parameter. When solving heat diffusion, it encompasses both 46 | time and thermal diffusivity. 47 | If iterable, creates a filter bank with one filter per value. 48 | normalize : bool 49 | Whether to normalize the kernel to have unit L2 norm. 50 | The normalization needs the eigenvalues of the graph Laplacian. 51 | 52 | Examples 53 | -------- 54 | 55 | Filter bank's representation in Fourier and time (ring graph) domains. 56 | 57 | >>> import matplotlib.pyplot as plt 58 | >>> G = graphs.Ring(N=20) 59 | >>> G.estimate_lmax() 60 | >>> G.set_coordinates('line1D') 61 | >>> g = filters.Heat(G, scale=[5, 10, 100]) 62 | >>> s = g.localize(G.N // 2) 63 | >>> fig, axes = plt.subplots(1, 2) 64 | >>> _ = g.plot(ax=axes[0]) 65 | >>> _ = G.plot(s, ax=axes[1]) 66 | 67 | Heat diffusion from two sources on a grid. 68 | 69 | >>> import matplotlib.pyplot as plt 70 | >>> n_side = 11 71 | >>> graph = graphs.Grid2d(n_side) 72 | >>> graph.estimate_lmax() 73 | >>> sources = [ 74 | ... (n_side//4 * n_side) + (n_side//4), 75 | ... (n_side*3//4 * n_side) + (n_side*3//4), 76 | ... 
] 77 | >>> delta = np.zeros(graph.n_vertices) 78 | >>> delta[sources] = 5 79 | >>> steps = np.array([1, 5]) 80 | >>> diffusivity = 10 81 | >>> g = filters.Heat(graph, scale=diffusivity*steps) 82 | >>> diffused = g.filter(delta) 83 | >>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4)) 84 | >>> _ = fig.suptitle('Heat diffusion', fontsize=16) 85 | >>> for i, ax in enumerate(axes): 86 | ... _ = graph.plot(diffused[:, i], highlight=sources, 87 | ... title='step {}'.format(steps[i]), ax=ax) 88 | ... ax.set_aspect('equal', 'box') 89 | ... ax.set_axis_off() 90 | 91 | Normalized heat kernel. 92 | 93 | >>> G = graphs.Logo() 94 | >>> G.compute_fourier_basis() 95 | >>> g = filters.Heat(G, scale=5) 96 | >>> y = g.evaluate(G.e) 97 | >>> print('norm: {:.2f}'.format(np.linalg.norm(y[0]))) 98 | norm: 9.76 99 | >>> g = filters.Heat(G, scale=5, normalize=True) 100 | >>> y = g.evaluate(G.e) 101 | >>> print('norm: {:.2f}'.format(np.linalg.norm(y[0]))) 102 | norm: 1.00 103 | 104 | """ 105 | 106 | def __init__(self, G, scale=10, normalize=False): 107 | 108 | try: 109 | iter(scale) 110 | except TypeError: 111 | scale = [scale] 112 | 113 | self.scale = scale 114 | self.normalize = normalize 115 | 116 | def kernel(x, scale): 117 | return np.minimum(np.exp(-scale * x / G.lmax), 1) 118 | 119 | kernels = [] 120 | for s in scale: 121 | norm = np.linalg.norm(kernel(G.e, s)) if normalize else 1 122 | kernels.append(lambda x, s=s, norm=norm: kernel(x, s) / norm) 123 | 124 | super(Heat, self).__init__(G, kernels) 125 | 126 | def _get_extra_repr(self): 127 | scale = '[' + ', '.join('{:.2f}'.format(s) for s in self.scale) + ']' 128 | return dict(scale=scale, normalize=self.normalize) 129 | -------------------------------------------------------------------------------- /pygsp/filters/held.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | from . 
class Held(Filter):
    r"""Design 2 filters with the Held construction (tight frame).

    This function creates a Parseval filter bank of :math:`2` filters.
    The low-pass filter is defined by the function

    .. math:: f_{l}=\begin{cases}
            1 & \mbox{if }x\leq a\\
            \sin\left(2\pi\mu\left(\frac{x}{8a}\right)\right) & \mbox{if }a<x\leq 2a\\
            0 & \mbox{if }x>2a
        \end{cases}

    with

    .. math:: \mu(x) = -1+24x-144x^2+256x^3

    The high pass filter is adapted to obtain a tight frame.

    Parameters
    ----------
    G : graph
    a : float
        See equation above for this parameter.
        The spectrum is scaled between 0 and 2 (default = 2/3).

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Held(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, a=2./3):

        self.a = a

        def kernel(x, a):
            # Start from zeros so that inputs outside the three regions
            # below (e.g. x < 0) evaluate to 0 instead of reading
            # uninitialized memory (np.empty leaves garbage there).
            y = np.zeros(np.shape(x))
            l1 = a
            l2 = 2 * a

            r1ind = (x >= 0) * (x < l1)
            r2ind = (x >= l1) * (x < l2)

            def mu(x):
                # Transition polynomial of the Held construction.
                return -1 + 24*x - 144*x**2 + 256*x**3

            y[r1ind] = 1
            y[r2ind] = np.sin(2 * np.pi * mu(x[r2ind] / 8 / a))
            # x >= l2 stays at the zero initialization.

            return y

        # Low-pass kernel on the spectrum rescaled to [0, 2], then build
        # the high-pass complement so the bank forms a tight (Parseval) frame.
        held = Filter(G, lambda x: kernel(x*2/G.lmax, a))
        complement = held.complement(frame_bound=1)
        kernels = held._kernels + complement._kernels

        super(Held, self).__init__(G, kernels)

    def _get_extra_repr(self):
        return dict(a='{:.2f}'.format(self.a))
class Itersine(Filter):
    r"""Design an itersine filter bank (tight frame).

    Create an itersine half overlap filter bank of ``Nf`` filters,
    going from 0 to lambda_max.

    Parameters
    ----------
    G : graph
    Nf : int (optional)
        Number of filters from 0 to lmax. (default = 6)
    overlap : int (optional)
        (default = 2)

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Itersine(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, Nf=6, overlap=2):

        self.overlap = overlap
        # Centers of the bands, spread uniformly over the spectrum.
        self.mu = np.linspace(0, G.lmax, num=Nf)

        # Width of each translated copy of the mother kernel.
        width = G.lmax / (Nf - overlap + 1) * overlap
        # Gain that makes the overlapping bank a tight frame.
        gain = np.sqrt(2 / overlap)

        def mother(x):
            # sin(pi/2 * cos(pi x)^2), supported on [-1/2, 1/2].
            support = (x >= -0.5) * (x <= 0.5)
            return np.sin(0.5 * np.pi * np.cos(x * np.pi)**2) * support

        def translate(k):
            # k-th band: shift the mother kernel along the spectrum.
            shift = (k - overlap / 2) / overlap
            return lambda x: gain * mother(x / width - shift)

        kernels = [translate(k) for k in range(1, Nf + 1)]

        super(Itersine, self).__init__(G, kernels)

    def _get_extra_repr(self):
        return dict(overlap='{:.2f}'.format(self.overlap))
class MexicanHat(Filter):
    r"""Design a filter bank of Mexican hat wavelets.

    The Mexican hat wavelet is the second order derivative of a Gaussian.
    Since we express the filter in the Fourier domain, we find:

    .. math:: \hat{g}_b(x) = x * \exp(-x)

    for the band-pass filter. Note that in our convention the eigenvalues of
    the Laplacian are equivalent to the square of graph frequencies,
    i.e. :math:`x = \lambda^2`.

    The low-pass filter is given by

    .. math:: \hat{g}_l(x) = \exp(-x^4).

    Parameters
    ----------
    G : graph
    Nf : int
        Number of filters to cover the interval [0, lmax].
    lpfactor : float
        Low-pass factor. lmin=lmax/lpfactor will be used to determine scales.
        The scaling function will be created to fill the low-pass gap.
    scales : array_like
        Scales to be used.
        By default, initialized with :func:`pygsp.utils.compute_log_scales`.
    normalize : bool
        Whether to normalize the wavelet by the factor ``sqrt(scales)``.

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.MexicanHat(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, Nf=6, lpfactor=20, scales=None, normalize=False):

        self.lpfactor = lpfactor
        self.normalize = normalize

        # Smallest frequency the wavelets must reach.
        lmin = G.lmax / lpfactor

        if scales is None:
            scales = utils.compute_log_scales(lmin, G.lmax, Nf-1)
        self.scales = scales

        if len(scales) != Nf - 1:
            raise ValueError('len(scales) should be Nf-1.')

        def band_pass(x):
            return x * np.exp(-x)

        def low_pass(x):
            return np.exp(-x**4)

        # Scaling function, designed to fill the low-frequency gap.
        kernels = [lambda x: 1.2 * np.exp(-1) * low_pass(x / 0.4 / lmin)]

        def wavelet(i):
            def kernel(x):
                gain = np.sqrt(scales[i]) if normalize else 1
                return gain * band_pass(scales[i] * x)
            return kernel

        kernels.extend(wavelet(i) for i in range(Nf - 1))

        super(MexicanHat, self).__init__(G, kernels)

    def _get_extra_repr(self):
        return {'lpfactor': '{:.2f}'.format(self.lpfactor),
                'normalize': self.normalize}
class Meyer(Filter):
    r"""Design a filter bank of Meyer wavelets (tight frame).

    Parameters
    ----------
    G : graph
    Nf : int
        Number of filters from 0 to lmax (default = 6).
    scales : ndarray
        Vector of scales to be used (default: log scale).

    References
    ----------
    Use of this kernel for SGWT proposed by Nora Leonardi and Dimitri Van De
    Ville in :cite:`leonardi2011wavelet`.

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Meyer(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, Nf=6, scales=None):

        if scales is None:
            # Dyadic scales, largest wavelet placed at 3*lmax/4.
            scales = (4./(3 * G.lmax)) * np.power(2., np.arange(Nf-2, -1, -1))
        self.scales = scales

        if len(scales) != Nf - 1:
            raise ValueError('len(scales) should be Nf-1.')

        def kernel(x, kernel_type):
            r"""
            Evaluates Meyer function and scaling function

            * meyer wavelet kernel: supported on [2/3,8/3]
            * meyer scaling function kernel: supported on [0,4/3]
            """
            x = np.asanyarray(x)

            # Region boundaries: l2 = 2*l1 and l3 = 4*l1.
            l1, l2, l3 = 2/3., 4/3., 8/3.

            def v(x):
                # Transition polynomial of the Meyer construction.
                return x**4 * (35 - 84*x + 70*x**2 - 20*x**3)

            region1 = (x < l1)
            region2 = (x >= l1) * (x < l2)
            region3 = (x >= l2) * (x < l3)

            # Zero initialization makes the kernel implicitly vanish for
            # every x outside the three regions above.
            r = np.zeros(x.shape)
            if kernel_type == 'scaling_function':
                r[region1] = 1
                r[region2] = np.cos((np.pi/2) * v(np.abs(x[region2])/l1 - 1))
            elif kernel_type == 'wavelet':
                r[region2] = np.sin((np.pi/2) * v(np.abs(x[region2])/l1 - 1))
                r[region3] = np.cos((np.pi/2) * v(np.abs(x[region3])/l2 - 1))
            else:
                raise ValueError('Unknown kernel type {}'.format(kernel_type))

            return r

        kernels = [lambda x: kernel(scales[0] * x, 'scaling_function')]
        kernels.extend(lambda x, i=i: kernel(scales[i] * x, 'wavelet')
                       for i in range(Nf - 1))

        super(Meyer, self).__init__(G, kernels)
class Modulation(Filter):
    r"""Design a filter bank with a kernel centered at each frequency.

    Design a filter bank from translated versions of a mother filter.
    The mother filter is translated to each eigenvalue of the Laplacian via
    modulation. A signal is modulated by multiplying it with an eigenvector.
    Similarly to localization, it is an element-wise multiplication of a kernel
    with the columns of :attr:`pygsp.graphs.Graph.U`, i.e., the eigenvectors,
    in the vertex domain.

    This filter bank can be used to compute the frequency content of a signal
    at each vertex. After filtering, one obtains a vertex-frequency
    representation :math:`Sf(i,k)` of a signal :math:`f` as

    .. math:: Sf(i, k) = \langle g_{i,k}, f \rangle,

    where :math:`g_{i,k}` is the mother kernel modulated in the spectral domain
    by the eigenvector :math:`u_k`, and localized on vertex :math:`v_i`.

    While :math:`g_{i,k}` should ideally be localized in both the spectral and
    vertex domains, that is impossible for some graphs due to the localization
    of some eigenvectors. See :attr:`pygsp.graphs.Graph.coherence`.

    As modulation and localization don't commute, one can define the frame as
    :math:`g_{i,k} = T_i M_k g` (modulation first) or :math:`g_{i,k} = M_k T_i
    g` (localization first). Localization first usually gives better results.
    When localizing first, the obtained vertex-frequency representation is a
    generalization to graphs of the windowed graph Fourier transform. Indeed,

    .. math:: Sf(i, k) = \langle f^\text{win}_i, u_k \rangle

    is the graph Fourier transform of the windowed signal :math:`f^\text{win}`.
    The signal :math:`f` is windowed in the vertex domain by a point-wise
    multiplication with the localized kernel :math:`T_i g`.

    When localizing first, the spectral representation of the filter bank is
    different for every localization. As such, we always evaluate the filter in
    the spectral domain with modulation first. Moreover, the filter bank is
    only defined at the eigenvalues (as modulation is done with discrete
    eigenvectors). Evaluating it elsewhere returns NaNs.

    Parameters
    ----------
    graph : :class:`pygsp.graphs.Graph`
    kernel : :class:`pygsp.filters.Filter`
        Kernel function to be modulated.
    modulation_first : bool
        First modulate then localize the kernel if True, first localize then
        modulate if False. The two operators do not commute. This setting only
        applies to :meth:`filter`. :meth:`evaluate` only performs modulation,
        as the filter would otherwise have a different spectrum depending on
        where it is localized.

    See Also
    --------
    Gabor : Another way to translate a filter in the spectral domain.

    Notes
    -----
    The eigenvalues of the graph Laplacian (i.e., the Fourier basis) are needed
    to modulate the kernels.

    References
    ----------
    See :cite:`shuman2016vertexfrequency` for details on this vertex-frequency
    representation of graph signals.

    Examples
    --------

    Vertex-frequency representations.
    Modulating first doesn't produce sufficiently localized filters.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Path(90)
    >>> G.compute_fourier_basis()
    >>>
    >>> # Design the filter banks.
    >>> g = filters.Heat(G, 500)
    >>> g1 = filters.Modulation(G, g, modulation_first=False)
    >>> g2 = filters.Modulation(G, g, modulation_first=True)
    >>> _ = g1.plot(sum=False, labels=False)
    >>>
    >>> # Signal.
    >>> s = np.empty(G.N)
    >>> s[:30] = G.U[:30, 10]
    >>> s[30:60] = G.U[30:60, 60]
    >>> s[60:] = G.U[60:, 30]
    >>> G.set_coordinates('line1D')
    >>> _ = G.plot(s)
    >>>
    >>> # Filter with both filter banks.
    >>> s1 = g1.filter(s)
    >>> s2 = g2.filter(s)
    >>>
    >>> # Visualize the vertex-frequency representation of the signal.
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].imshow(np.abs(s1.T)**2)
    >>> _ = axes[1].imshow(np.abs(s2.T)**2)
    >>> _ = axes[0].set_title('localization then modulation')
    >>> _ = axes[1].set_title('modulation then localization')
    >>> ticks = [0, G.N//2, G.N-1]
    >>> labels = ['{:.1f}'.format(e) for e in G.e[ticks]]
    >>> _ = axes[0].set_yticks(ticks)
    >>> _ = axes[1].set_yticks([])
    >>> _ = axes[0].set_yticklabels(labels)
    >>> _ = axes[0].set_ylabel('graph frequency')
    >>> _ = axes[0].set_xlabel('node')
    >>> _ = axes[1].set_xlabel('node')
    >>> _ = axes[0].set_xticks(ticks)
    >>> _ = axes[1].set_xticks(ticks)
    >>> fig.tight_layout()
    >>>
    >>> # Reconstruction.
    >>> s = g2.filter(s2)
    >>> _ = G.plot(s)

    """

    def __init__(self, graph, kernel, modulation_first=False):

        self.G = graph
        self._kernels = kernel
        self._modulation_first = modulation_first

        # Only a single mother kernel can be modulated.
        if kernel.n_filters != 1:
            raise ValueError('A kernel must be one filter. The passed '
                             'filter bank {} has {}.'.format(
                                 kernel, kernel.n_filters))
        # Modulation uses this graph's eigenvectors: both must refer to
        # the very same graph object.
        if kernel.G is not graph:
            raise ValueError('The graph passed to this filter bank must '
                             'be the one used to build the mother kernel.')

        # One modulated kernel per eigenvector, i.e., per vertex.
        self.n_features_in, self.n_features_out = (1, graph.n_vertices)
        self.n_filters = self.n_features_in * self.n_features_out
        self.Nf = self.n_filters  # TODO: kept for backward compatibility only.

    def evaluate(self, x):
        """TODO: will become _evaluate once polynomial filtering is merged."""

        if not hasattr(self, '_coefficients'):
            # Lazily compute (and cache) the spectral coefficients of every
            # modulated kernel, with modulation applied first.
            # Graph Fourier transform -> modulation -> inverse GFT.
            c = self.G.igft(self._kernels.evaluate(self.G.e).squeeze())
            c = np.sqrt(self.G.n_vertices) * self.G.U * c[:, np.newaxis]
            self._coefficients = self.G.gft(c)

        shape = x.shape
        x = x.flatten()
        # The bank is only defined at the eigenvalues: NaN everywhere else.
        y = np.full((self.n_features_out, x.size), np.nan)
        for i in range(len(x)):
            # Exact float comparison: x[i] must equal an eigenvalue to match.
            query = self._coefficients[x[i] == self.G.e]
            if len(query) != 0:
                y[:, i] = query[0]
        return y.reshape((self.n_features_out,) + shape)

    def filter(self, s, method='exact', order=None):
        """TODO: indirection will be removed when poly filtering is merged.
        TODO: with _filter and shape handled in Filter.filter, synthesis will work.
        """
        if self._modulation_first:
            return super(Modulation, self).filter(s, method='exact')
        else:
            # Localization first.
            # The dot product with each modulated kernel is equivalent to the
            # GFT, as for the localization and the IGFT.
            y = np.empty((self.G.n_vertices, self.G.n_vertices))
            for i in range(self.G.n_vertices):
                x = s * self._kernels.localize(i)
                y[i] = np.sqrt(self.G.n_vertices) * self.G.gft(x)
            return y
class Papadakis(Filter):
    r"""Design 2 filters with the Papadakis construction (tight frame).

    This function creates a Parseval filter bank of 2 filters.
    The low-pass filter is defined by the function

    .. math:: f_{l}=\begin{cases}
            1 & \mbox{if }x\leq a\\
            \sqrt{\frac{1-\sin\left(\frac{3\pi}{2a}x\right)}{2}} & \mbox{if }a<x\leq \frac{5a}{3}\\
            0 & \mbox{if }x>\frac{5a}{3}
        \end{cases}

    The high pass filter is adapted to obtain a tight frame.

    Parameters
    ----------
    G : graph
    a : float
        See above equation for this parameter.
        The spectrum is scaled between 0 and 2 (default = 3/4).

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Papadakis(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, a=0.75):

        self.a = a

        def kernel(x, a):
            # Start from zeros so inputs outside the regions below
            # (e.g. x < 0) evaluate to 0 instead of reading uninitialized
            # memory (np.empty leaves garbage there).
            y = np.zeros(np.shape(x))
            l1 = a
            l2 = a * 5 / 3

            r1ind = (x >= 0) * (x < l1)
            r2ind = (x >= l1) * (x < l2)

            y[r1ind] = 1
            # Continuous transition: equals 1 at x=a and 0 at x=5a/3.
            y[r2ind] = np.sqrt((1 - np.sin(3*np.pi/(2*a) * x[r2ind]))/2)
            # x >= l2 stays at the zero initialization.

            return y

        # Low-pass kernel on the spectrum rescaled to [0, 2], then build
        # the high-pass complement so the bank forms a tight (Parseval) frame.
        papadakis = Filter(G, lambda x: kernel(x*2/G.lmax, a))
        complement = papadakis.complement(frame_bound=1)
        kernels = papadakis._kernels + complement._kernels

        super(Papadakis, self).__init__(G, kernels)

    def _get_extra_repr(self):
        return dict(a='{:.2f}'.format(self.a))
class Rectangular(Filter):
    r"""Design a rectangular filter.

    The filter evaluates at one in the interval [band_min, band_max] and zero
    everywhere else.

    The rectangular kernel is defined as

    .. math:: g(\lambda) = \begin{cases}
        0 & \text{if } \lambda < \text{band}_\text{min}, \\
        1 & \text{if band}_\text{min} \leq \lambda \leq \text{band}_\text{max}, \\
        0 & \text{if } \lambda > \text{band}_\text{max}, \\
        \end{cases}

    where :math:`\lambda \in [0, 1]` corresponds to the normalized graph
    eigenvalues.

    Parameters
    ----------
    G : graph
    band_min : float
        Minimum relative band. The filter evaluates at 1 at this frequency.
        Zero corresponds to the smallest eigenvalue (which is itself equal to
        zero), one corresponds to the largest eigenvalue.
        If None, the filter has no lower bound (which corresponds to
        :math:`\text{band}_\text{min} = -\infty`) and is low-pass.
    band_max : float
        Maximum relative band. The filter evaluates at 1 at this frequency.
        If None, the filter has no upper bound (which corresponds to
        :math:`\text{band}_\text{max} = \infty`) and is high-pass.

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Rectangular(G, band_min=0.1, band_max=0.5)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, band_min=None, band_max=0.2):

        self.band_min = band_min
        self.band_max = band_max

        def kernel_lowpass(x):
            # Pass everything below band_max (relative to lmax).
            x = x / G.lmax
            return x <= band_max

        def kernel_highpass(x):
            # Pass everything above band_min (relative to lmax).
            x = x / G.lmax
            return x >= band_min

        # Select the kernel according to which bounds are set:
        # no bounds -> all-pass, only upper -> low-pass, only lower ->
        # high-pass, both -> band-pass.
        if (band_min is None) and (band_max is None):
            kernel = lambda x: np.ones_like(x)
        elif band_min is None:
            kernel = kernel_lowpass
        elif band_max is None:
            kernel = kernel_highpass
        else:
            kernel = lambda x: kernel_lowpass(x) * kernel_highpass(x)

        super(Rectangular, self).__init__(G, kernel)

    def _get_extra_repr(self):
        attrs = dict()
        if self.band_min is not None:
            attrs.update(band_min='{:.2f}'.format(self.band_min))
        if self.band_max is not None:
            attrs.update(band_max='{:.2f}'.format(self.band_max))
        return attrs
class Regular(Filter):
    r"""Design 2 filters with the regular construction (tight frame).

    This function creates a Parseval filter bank of 2 filters.
    The low-pass filter is defined by a function :math:`f_l(x)`
    between :math:`0` and :math:`2`. For :math:`d = 0`.

    .. math:: f_{l}= \sin\left( \frac{\pi}{4} x \right)

    For :math:`d = 1`

    .. math:: f_{l}= \sin\left( \frac{\pi}{4} \left( 1+ \sin\left(\frac{\pi}{2}(x-1)\right) \right) \right)

    For :math:`d = 2`

    .. math:: f_{l}= \sin\left( \frac{\pi}{4} \left( 1+ \sin\left(\frac{\pi}{2} \sin\left(\frac{\pi}{2}(x-1)\right)\right) \right) \right)

    And so forth for other degrees :math:`d`.

    The high pass filter is adapted to obtain a tight frame.

    Parameters
    ----------
    G : graph
    degree : float
        Degree (default = 3). See above equations.

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Regular(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, degree=3):

        self.degree = degree

        def kernel(x, degree):
            # degree 0 is the base case; higher degrees nest additional
            # sine smoothing steps around the transition.
            if degree == 0:
                return np.sin(np.pi / 4 * x)
            else:
                output = np.sin(np.pi * (x - 1) / 2)
                for _ in range(2, degree):
                    output = np.sin(np.pi * output / 2)
                return np.sin(np.pi / 4 * (1 + output))

        # Low-pass kernel on the spectrum rescaled to [0, 2], then build
        # the high-pass complement so the bank forms a tight (Parseval) frame.
        regular = Filter(G, lambda x: kernel(x*2/G.lmax, degree))
        complement = regular.complement(frame_bound=1)
        kernels = regular._kernels + complement._kernels

        super(Regular, self).__init__(G, kernels)

    def _get_extra_repr(self):
        # Added for consistency with the other two-filter constructions
        # (Held, Papadakis, Simoncelli), which all report their parameter.
        return dict(degree=self.degree)
class Simoncelli(Filter):
    r"""Design 2 filters with the Simoncelli construction (tight frame).

    This function creates a Parseval filter bank of 2 filters.
    The low-pass filter is defined by the function

    .. math:: f_{l}=\begin{cases}
            1 & \mbox{if }x\leq a\\
            \cos\left(\frac{\pi}{2}\frac{\log\left(\frac{x}{a}\right)}{\log(2)}\right) & \mbox{if }a<x\leq 2a\\
            0 & \mbox{if }x>2a
        \end{cases}

    The high pass filter is adapted to obtain a tight frame.

    Parameters
    ----------
    G : graph
    a : float
        See above equation for this parameter.
        The spectrum is scaled between 0 and 2 (default = 2/3).

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Simoncelli(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, a=2/3):

        self.a = a

        def kernel(x, a):
            # Start from zeros so inputs outside the regions below
            # (e.g. x < 0) evaluate to 0 instead of reading uninitialized
            # memory (np.empty leaves garbage there).
            y = np.zeros(np.shape(x))
            l1 = a
            l2 = 2 * a

            r1ind = (x >= 0) * (x < l1)
            r2ind = (x >= l1) * (x < l2)

            y[r1ind] = 1
            # Log-cosine transition: equals 1 at x=a and 0 at x=2a.
            y[r2ind] = np.cos(np.pi/2 * np.log(x[r2ind]/a) / np.log(2))
            # x >= l2 stays at the zero initialization.

            return y

        # Low-pass kernel on the spectrum rescaled to [0, 2], then build
        # the high-pass complement so the bank forms a tight (Parseval) frame.
        simoncelli = Filter(G, lambda x: kernel(x*2/G.lmax, a))
        complement = simoncelli.complement(frame_bound=1)
        kernels = simoncelli._kernels + complement._kernels

        super(Simoncelli, self).__init__(G, kernels)

    def _get_extra_repr(self):
        return dict(a='{:.2f}'.format(self.a))
class SimpleTight(Filter):
    r"""Design a simple tight frame filter bank (tight frame).

    These filters have been designed to be a simple tight frame wavelet filter
    bank. The kernel is similar to Meyer, but simpler. The function is
    essentially :math:`\sin^2(x)` in ascending part and :math:`\cos^2` in
    descending part.

    Parameters
    ----------
    G : graph
    Nf : int
        Number of filters to cover the interval [0, lmax].
    scales : array_like
        Scales to be used. Defaults to log scale.

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.SimpleTight(G)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    """

    def __init__(self, G, Nf=6, scales=None):

        def kernel(x, kerneltype):
            r"""
            Evaluates 'simple' tight-frame kernel.

            * simple tf wavelet kernel: supported on [1/4, 1]
            * simple tf scaling function kernel: supported on [0, 1/2]

            Parameters
            ----------
            x : ndarray
                Array of independent variable values
            kerneltype : str
                Can be either 'sf' or 'wavelet'

            Returns
            -------
            r : ndarray

            """

            l1 = 0.25
            l2 = 0.5
            l3 = 1.

            def h(x):
                return np.sin(np.pi*x/2.)**2

            r1ind = (x < l1)
            r2ind = (x >= l1) * (x < l2)
            r3ind = (x >= l2) * (x < l3)

            # Zero initialization makes the kernel implicitly vanish
            # outside its support.
            r = np.zeros(x.shape)
            if kerneltype == 'sf':
                r[r1ind] = 1.
                r[r2ind] = np.sqrt(1 - h(4*x[r2ind] - 1)**2)
            elif kerneltype == 'wavelet':
                r[r2ind] = h(4*(x[r2ind] - 1/4.))
                r[r3ind] = np.sqrt(1 - h(2*x[r3ind] - 1)**2)
            else:
                raise TypeError('Unknown kernel type', kerneltype)

            return r

        # Test for None explicitly: `if not scales` raises ValueError for
        # numpy arrays (ambiguous truth value) and would also wrongly
        # discard an empty user-supplied list. Consistent with Meyer and
        # MexicanHat, which both use `scales is None`.
        if scales is None:
            scales = (1./(2.*G.lmax) * np.power(2, np.arange(Nf-2, -1, -1)))
        self.scales = scales

        if len(scales) != Nf - 1:
            raise ValueError('len(scales) should be Nf-1.')

        kernels = [lambda x: kernel(scales[0] * x, 'sf')]

        for i in range(Nf - 1):
            kernels.append(lambda x, i=i: kernel(scales[i] * x, 'wavelet'))

        super(SimpleTight, self).__init__(G, kernels)
class Wave(Filter):
    r"""Design a filter bank of wave kernels.

    The wave kernel is defined in the spectral domain as

    .. math:: g_{\tau, t}(\lambda) = \cos \left( t
        \arccos \left( 1 - \frac{\tau^2}{2} \lambda \right) \right),

    where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the
    graph Laplacian, :math:`t` is time, and :math:`\tau` is the propagation
    speed.

    The wave kernel is the fundamental solution to the wave equation

    .. math:: - \tau^2 L f(t) = \partial_{tt} f(t),

    where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` models, for example,
    the mechanical displacement of a wave on a graph. Given the initial
    condition :math:`f(0)` and assuming a vanishing initial velocity, i.e., the
    first derivative in time of the initial distribution equals zero, the
    solution of the wave equation is expressed as

    .. math:: f(t) = U g_{\tau, t}(\Lambda) U^\top f(0)
        = g_{\tau, t}(L) f(0).

    The above is, by definition, the convolution of the signal :math:`f(0)`
    with the kernel :math:`g_{\tau, t}`.
    Hence, applying this filter to a signal simulates wave propagation.

    Parameters
    ----------
    G : graph
    time : float or iterable
        Time step.
        If iterable, creates a filter bank with one filter per value.
    speed : float or iterable
        Propagation speed, bounded by 0 (included) and 2 (excluded).
        If iterable, creates a filter bank with one filter per value.

    References
    ----------
    :cite:`grassi2016timevertex`, :cite:`grassi2018timevertex`

    Examples
    --------

    Filter bank's representation in Fourier and time (ring graph) domains.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=20)
    >>> G.estimate_lmax()
    >>> G.set_coordinates('line1D')
    >>> g = filters.Wave(G, time=[5, 15], speed=1)
    >>> s = g.localize(G.N // 2)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = g.plot(ax=axes[0])
    >>> _ = G.plot(s, ax=axes[1])

    Wave propagation from two sources on a grid.

    >>> import matplotlib.pyplot as plt
    >>> n_side = 11
    >>> graph = graphs.Grid2d(n_side)
    >>> graph.estimate_lmax()
    >>> sources = [
    ...     (n_side//4 * n_side) + (n_side//4),
    ...     (n_side*3//4 * n_side) + (n_side*3//4),
    ... ]
    >>> delta = np.zeros(graph.n_vertices)
    >>> delta[sources] = 5
    >>> steps = np.array([5, 10])
    >>> g = filters.Wave(graph, time=steps, speed=1)
    >>> propagated = g.filter(delta)
    >>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4))
    >>> _ = fig.suptitle('Wave propagation', fontsize=16)
    >>> for i, ax in enumerate(axes):
    ...     _ = graph.plot(propagated[:, i], highlight=sources,
    ...                    title='step {}'.format(steps[i]), ax=ax)
    ...     ax.set_aspect('equal', 'box')
    ...     ax.set_axis_off()

    """

    def __init__(self, G, time=10, speed=1):

        # Accept scalars or iterables for both parameters.
        try:
            iter(time)
        except TypeError:
            time = [time]
        try:
            iter(speed)
        except TypeError:
            speed = [speed]

        self.time = time
        self.speed = speed

        # Broadcast the length-1 parameter against the other one.
        if len(time) != len(speed):
            if len(speed) == 1:
                speed = speed * len(time)
            elif len(time) == 1:
                time = time * len(speed)
            else:
                raise ValueError('If both parameters are iterable, '
                                 'they should have the same length.')

        # Enforce the full documented bound: speed must be in [0, 2).
        # Previously only the upper bound was checked, silently accepting
        # negative speeds.
        speeds = np.asanyarray(speed)
        if np.any(speeds < 0) or np.any(speeds >= 2):
            raise ValueError('The wave propagation speed should be in [0, 2[')

        def kernel(x, time, speed):
            return np.cos(time * np.arccos(1 - speed**2 * x / G.lmax / 2))

        kernels = [partial(kernel, time=t, speed=s)
                   for t, s in zip(time, speed)]

        super(Wave, self).__init__(G, kernels)

    def _get_extra_repr(self):
        time = '[' + ', '.join('{:.2f}'.format(t) for t in self.time) + ']'
        speed = '[' + ', '.join('{:.2f}'.format(s) for s in self.speed) + ']'
        return dict(time=time, speed=speed)
A graph 5 | object is either constructed from an adjacency matrix, or by instantiating one 6 | of the built-in graph models. 7 | 8 | Interface 9 | ========= 10 | 11 | The :class:`Graph` base class allows the construction of a graph object from any 12 | adjacency matrix and provides a common interface to that object. Derived 13 | classes then allow the instantiation of various standard graph models. 14 | 15 | Attributes 16 | ---------- 17 | 18 | **Matrix operators** 19 | 20 | .. autosummary:: 21 | 22 | Graph.W 23 | Graph.L 24 | Graph.U 25 | Graph.D 26 | 27 | **Vectors** 28 | 29 | .. autosummary:: 30 | 31 | Graph.d 32 | Graph.dw 33 | Graph.e 34 | 35 | **Scalars** 36 | 37 | .. autosummary:: 38 | 39 | Graph.lmax 40 | Graph.coherence 41 | 42 | Attributes computation 43 | ---------------------- 44 | 45 | .. autosummary:: 46 | 47 | Graph.compute_laplacian 48 | Graph.estimate_lmax 49 | Graph.compute_fourier_basis 50 | Graph.compute_differential_operator 51 | 52 | Differential operators 53 | ---------------------- 54 | 55 | .. autosummary:: 56 | 57 | Graph.grad 58 | Graph.div 59 | Graph.dirichlet_energy 60 | 61 | Transforms 62 | ---------- 63 | 64 | .. autosummary:: 65 | 66 | Graph.gft 67 | Graph.igft 68 | 69 | Vertex-frequency transforms are implemented as filter banks and are found in 70 | :mod:`pygsp.filters` (such as :class:`~pygsp.filters.Gabor` and 71 | :class:`~pygsp.filters.Modulation`). 72 | 73 | Checks 74 | ------ 75 | 76 | .. autosummary:: 77 | 78 | Graph.is_weighted 79 | Graph.is_connected 80 | Graph.is_directed 81 | Graph.has_loops 82 | 83 | Plotting 84 | -------- 85 | 86 | .. autosummary:: 87 | 88 | Graph.plot 89 | Graph.plot_spectrogram 90 | 91 | Import and export (I/O) 92 | ----------------------- 93 | 94 | We provide import and export facilities to two well-known Python packages for 95 | network analysis: NetworkX_ and graph-tool_.
96 | Those packages and the PyGSP are fundamentally different in their goals (graph 97 | analysis versus graph signal analysis) and graph representations (if in the 98 | PyGSP everything is an ndarray, in NetworkX everything is a dictionary). 99 | Those tools are complementary and good interoperability is necessary to exploit 100 | the strengths of each tool. 101 | We ourselves leverage NetworkX and graph-tool to save and load graphs. 102 | 103 | Note: to tie a signal with the graph, such that they are exported together, 104 | attach it first with :meth:`Graph.set_signal`. 105 | 106 | .. _NetworkX: https://networkx.org 107 | .. _graph-tool: https://graph-tool.skewed.de 108 | 109 | .. autosummary:: 110 | 111 | Graph.load 112 | Graph.save 113 | Graph.from_networkx 114 | Graph.to_networkx 115 | Graph.from_graphtool 116 | Graph.to_graphtool 117 | 118 | Others 119 | ------ 120 | 121 | .. autosummary:: 122 | 123 | Graph.get_edge_list 124 | Graph.set_signal 125 | Graph.set_coordinates 126 | Graph.subgraph 127 | Graph.extract_components 128 | 129 | Graph models 130 | ============ 131 | 132 | In addition to the below graphs, useful resources are the random graph 133 | generators from NetworkX (see `NetworkX's documentation`_) and graph-tool (see 134 | :mod:`graph_tool.generation`), as well as graph-tool's assortment of standard 135 | networks (see :mod:`graph_tool.collection`). 136 | Any graph created by NetworkX or graph-tool can be imported in the PyGSP with 137 | :meth:`Graph.from_networkx` and :meth:`Graph.from_graphtool`. 138 | 139 | .. _NetworkX's documentation: https://networkx.org/documentation/stable/reference/generators.html 140 | 141 | Graphs built from other graphs 142 | ------------------------------ 143 | 144 | .. autosummary:: 145 | 146 | LineGraph 147 | 148 | Generated graphs 149 | ---------------- 150 | 151 | .. 
autosummary:: 152 | 153 | Airfoil 154 | BarabasiAlbert 155 | Comet 156 | Community 157 | DavidSensorNet 158 | ErdosRenyi 159 | FullConnected 160 | Grid2d 161 | Logo 162 | LowStretchTree 163 | Minnesota 164 | Path 165 | RandomRegular 166 | RandomRing 167 | Ring 168 | Star 169 | StochasticBlockModel 170 | SwissRoll 171 | Torus 172 | 173 | Nearest-neighbors graphs constructed from point clouds 174 | ------------------------------------------------------ 175 | 176 | .. autosummary:: 177 | 178 | NNGraph 179 | Bunny 180 | Cube 181 | ImgPatches 182 | Grid2dImgPatches 183 | Sensor 184 | Sphere 185 | TwoMoons 186 | 187 | """ 188 | 189 | from .graph import Graph # noqa: F401 190 | from .airfoil import Airfoil # noqa: F401 191 | from .barabasialbert import BarabasiAlbert # noqa: F401 192 | from .comet import Comet # noqa: F401 193 | from .community import Community # noqa: F401 194 | from .davidsensornet import DavidSensorNet # noqa: F401 195 | from .erdosrenyi import ErdosRenyi # noqa: F401 196 | from .fullconnected import FullConnected # noqa: F401 197 | from .grid2d import Grid2d # noqa: F401 198 | from .linegraph import LineGraph # noqa: F401 199 | from .logo import Logo # noqa: F401 200 | from .lowstretchtree import LowStretchTree # noqa: F401 201 | from .minnesota import Minnesota # noqa: F401 202 | from .path import Path # noqa: F401 203 | from .randomregular import RandomRegular # noqa: F401 204 | from .randomring import RandomRing # noqa: F401 205 | from .ring import Ring # noqa: F401 206 | from .star import Star # noqa: F401 207 | from .stochasticblockmodel import StochasticBlockModel # noqa: F401 208 | from .swissroll import SwissRoll # noqa: F401 209 | from .torus import Torus # noqa: F401 210 | 211 | from .nngraphs.nngraph import NNGraph # noqa: F401 212 | from .nngraphs.bunny import Bunny # noqa: F401 213 | from .nngraphs.cube import Cube # noqa: F401 214 | from .nngraphs.imgpatches import ImgPatches # noqa: F401 215 | from .nngraphs.grid2dimgpatches import 
Grid2dImgPatches # noqa: F401 216 | from .nngraphs.sensor import Sensor # noqa: F401 217 | from .nngraphs.sphere import Sphere # noqa: F401 218 | from .nngraphs.twomoons import TwoMoons # noqa: F401 219 | -------------------------------------------------------------------------------- /pygsp/graphs/airfoil.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from pygsp import utils 7 | from . import Graph # prevent circular import in Python < 3.5 8 | 9 | 10 | class Airfoil(Graph): 11 | r"""Airfoil graph. 12 | 13 | Examples 14 | -------- 15 | >>> import matplotlib.pyplot as plt 16 | >>> G = graphs.Airfoil() 17 | >>> fig, axes = plt.subplots(1, 2) 18 | >>> _ = axes[0].spy(G.W, markersize=0.5) 19 | >>> _ = G.plot(edges=True, ax=axes[1]) 20 | 21 | """ 22 | 23 | def __init__(self, **kwargs): 24 | 25 | data = utils.loadmat('pointclouds/airfoil') 26 | coords = np.concatenate((data['x'], data['y']), axis=1) 27 | 28 | i_inds = np.reshape(data['i_inds'] - 1, 12289) 29 | j_inds = np.reshape(data['j_inds'] - 1, 12289) 30 | A = sparse.coo_matrix((np.ones(12289), (i_inds, j_inds)), shape=(4253, 4253)) 31 | W = (A + A.T) / 2. 32 | 33 | plotting = {"vertex_size": 30, 34 | "limits": np.array([-1e-4, 1.01*data['x'].max(), 35 | -1e-4, 1.01*data['y'].max()])} 36 | 37 | super(Airfoil, self).__init__(W, coords=coords, plotting=plotting, 38 | **kwargs) 39 | -------------------------------------------------------------------------------- /pygsp/graphs/barabasialbert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from . import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class BarabasiAlbert(Graph): 10 | r"""Barabasi-Albert preferential attachment. 
11 | 12 | The Barabasi-Albert graph is constructed by connecting nodes in two steps. 13 | First, m0 nodes are created. Then, nodes are added one by one. 14 | 15 | By lack of clarity, we take the liberty to create it as follows: 16 | 17 | 1. the m0 initial nodes are disconnected, 18 | 2. each node is connected to m of the older nodes with a probability 19 | distribution depending of the node-degrees of the other nodes, 20 | :math:`p_n(i) = \frac{1 + k_i}{\sum_j{1 + k_j}}`. 21 | 22 | Parameters 23 | ---------- 24 | N : int 25 | Number of nodes (default is 1000) 26 | m0 : int 27 | Number of initial nodes (default is 1) 28 | m : int 29 | Number of connections at each step (default is 1) 30 | m can never be larger than m0. 31 | seed : int 32 | Seed for the random number generator (for reproducible graphs). 33 | 34 | Examples 35 | -------- 36 | >>> import matplotlib.pyplot as plt 37 | >>> G = graphs.BarabasiAlbert(N=150, seed=42) 38 | >>> G.set_coordinates(kind='spring', seed=42) 39 | >>> fig, axes = plt.subplots(1, 2) 40 | >>> _ = axes[0].spy(G.W, markersize=2) 41 | >>> _ = G.plot(ax=axes[1]) 42 | 43 | """ 44 | def __init__(self, N=1000, m0=1, m=1, seed=None, **kwargs): 45 | 46 | if m > m0: 47 | raise ValueError('Parameter m cannot be above parameter m0.') 48 | 49 | self.m0 = m0 50 | self.m = m 51 | self.seed = seed 52 | 53 | W = sparse.lil_matrix((N, N)) 54 | rng = np.random.default_rng(seed) 55 | 56 | for i in range(m0, N): 57 | distr = W.sum(axis=1) 58 | distr += np.concatenate((np.ones((i, 1)), np.zeros((N-i, 1)))) 59 | 60 | connections = rng.choice( 61 | N, size=m, replace=False, p=np.ravel(distr / distr.sum())) 62 | for elem in connections: 63 | W[elem, i] = 1 64 | W[i, elem] = 1 65 | 66 | super(BarabasiAlbert, self).__init__(W, **kwargs) 67 | 68 | def _get_extra_repr(self): 69 | return dict(m0=self.m0, m=self.m, seed=self.seed) 70 | -------------------------------------------------------------------------------- /pygsp/graphs/comet.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from . import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class Comet(Graph): 10 | r"""Comet graph. 11 | 12 | The comet is a path graph with a star of degree `k` at one end. 13 | Equivalently, the comet is a star made of `k` branches, where a branch of 14 | length `N-k` acts as the tail. 15 | The central vertex has degree `N-1`, the others have degree 1. 16 | 17 | Parameters 18 | ---------- 19 | N : int 20 | Number of vertices. 21 | k : int 22 | Degree of central vertex. 23 | 24 | See Also 25 | -------- 26 | Path : Comet without star 27 | Star : Comet without tail (path) 28 | 29 | Examples 30 | -------- 31 | >>> import matplotlib.pyplot as plt 32 | >>> G = graphs.Comet(15, 10) 33 | >>> fig, axes = plt.subplots(1, 2) 34 | >>> _ = axes[0].spy(G.W) 35 | >>> _ = G.plot(ax=axes[1]) 36 | 37 | """ 38 | 39 | def __init__(self, N=32, k=12, **kwargs): 40 | 41 | if k > N-1: 42 | raise ValueError('The degree of the central vertex k={} must be ' 43 | 'smaller than the number of vertices N={}.' 
44 | ''.format(k, N)) 45 | 46 | self.k = k 47 | 48 | sources = np.concatenate(( 49 | np.zeros(k), np.arange(k)+1, # star 50 | np.arange(k, N-1), np.arange(k+1, N) # tail (path) 51 | )) 52 | targets = np.concatenate(( 53 | np.arange(k)+1, np.zeros(k), # star 54 | np.arange(k+1, N), np.arange(k, N-1) # tail (path) 55 | )) 56 | n_edges = N - 1 57 | weights = np.ones(2*n_edges) 58 | W = sparse.csr_matrix((weights, (sources, targets)), shape=(N, N)) 59 | 60 | indices = np.arange(k) + 1 61 | coords = np.zeros((N, 2)) 62 | coords[1:k+1, 0] = np.cos(indices*2*np.pi/k) 63 | coords[1:k+1, 1] = np.sin(indices*2*np.pi/k) 64 | coords[k+1:, 0] = np.arange(1, N-k) + 1 65 | 66 | super(Comet, self).__init__(W, coords=coords, **kwargs) 67 | 68 | def _get_extra_repr(self): 69 | return dict(k=self.k) 70 | -------------------------------------------------------------------------------- /pygsp/graphs/davidsensornet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | from pygsp import utils 6 | from . import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class DavidSensorNet(Graph): 10 | r"""Sensor network. 11 | 12 | Parameters 13 | ---------- 14 | N : int 15 | Number of vertices (default = 64). Values of 64 and 500 yield 16 | pre-computed and saved graphs. Other values yield randomly generated 17 | graphs. 18 | seed : int 19 | Seed for the random number generator (for reproducible graphs). 
20 | 21 | Examples 22 | -------- 23 | >>> import matplotlib.pyplot as plt 24 | >>> G = graphs.DavidSensorNet() 25 | >>> fig, axes = plt.subplots(1, 2) 26 | >>> _ = axes[0].spy(G.W, markersize=2) 27 | >>> _ = G.plot(ax=axes[1]) 28 | 29 | """ 30 | 31 | def __init__(self, N=64, seed=None, **kwargs): 32 | 33 | self.seed = seed 34 | 35 | if N == 64: 36 | data = utils.loadmat('pointclouds/david64') 37 | assert data['N'][0, 0] == N 38 | W = data['W'] 39 | coords = data['coords'] 40 | 41 | elif N == 500: 42 | data = utils.loadmat('pointclouds/david500') 43 | assert data['N'][0, 0] == N 44 | W = data['W'] 45 | coords = data['coords'] 46 | 47 | else: 48 | coords = np.random.default_rng(seed).uniform(size=(N, 2)) 49 | 50 | target_dist_cutoff = -0.125 * N / 436.075 + 0.2183 51 | T = 0.6 52 | s = np.sqrt(-target_dist_cutoff**2/(2*np.log(T))) 53 | d = utils.distanz(coords.T) 54 | W = np.exp(-np.power(d, 2)/(2.*s**2)) 55 | W[W < T] = 0 56 | W[np.diag_indices(N)] = 0 57 | 58 | plotting = {"limits": [0, 1, 0, 1]} 59 | 60 | super(DavidSensorNet, self).__init__(W, coords=coords, 61 | plotting=plotting, **kwargs) 62 | 63 | def _get_extra_repr(self): 64 | return dict(seed=self.seed) 65 | -------------------------------------------------------------------------------- /pygsp/graphs/erdosrenyi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # prevent circular import in Python < 3.5 4 | from .stochasticblockmodel import StochasticBlockModel 5 | 6 | 7 | class ErdosRenyi(StochasticBlockModel): 8 | r"""Erdos Renyi graph. 9 | 10 | The Erdos Renyi graph is constructed by randomly connecting nodes. Each 11 | edge is included in the graph with probability p, independently from any 12 | other edge. All edge weights are equal to 1. 13 | 14 | Parameters 15 | ---------- 16 | N : int 17 | Number of nodes (default is 100). 18 | p : float 19 | Probability to connect a node with another one. 
20 | directed : bool 21 | Allow directed edges if True (default is False). 22 | self_loops : bool 23 | Allow self loops if True (default is False). 24 | connected : bool 25 | Force the graph to be connected (default is False). 26 | n_try : int 27 | Maximum number of trials to get a connected graph (default is 10). 28 | seed : int 29 | Seed for the random number generator (for reproducible graphs). 30 | 31 | Examples 32 | -------- 33 | >>> import matplotlib.pyplot as plt 34 | >>> G = graphs.ErdosRenyi(N=64, seed=42) 35 | >>> G.set_coordinates(kind='spring', seed=42) 36 | >>> fig, axes = plt.subplots(1, 2) 37 | >>> _ = axes[0].spy(G.W, markersize=2) 38 | >>> _ = G.plot(ax=axes[1]) 39 | 40 | """ 41 | 42 | def __init__(self, N=100, p=0.1, directed=False, self_loops=False, 43 | connected=False, n_try=10, seed=None, **kwargs): 44 | 45 | super(ErdosRenyi, self).__init__(N=N, k=1, p=p, 46 | directed=directed, 47 | self_loops=self_loops, 48 | connected=connected, 49 | n_try=n_try, 50 | seed=seed, 51 | **kwargs) 52 | -------------------------------------------------------------------------------- /pygsp/graphs/fullconnected.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | from . import Graph # prevent circular import in Python < 3.5 6 | 7 | 8 | class FullConnected(Graph): 9 | r"""Fully connected graph. 10 | 11 | All weights are set to 1. There are no self-connections.
12 | 13 | Parameters 14 | ---------- 15 | N : int 16 | Number of vertices (default = 10) 17 | 18 | Examples 19 | -------- 20 | >>> import matplotlib.pyplot as plt 21 | >>> G = graphs.FullConnected(N=20) 22 | >>> G.set_coordinates(kind='spring', seed=42) 23 | >>> fig, axes = plt.subplots(1, 2) 24 | >>> _ = axes[0].spy(G.W, markersize=5) 25 | >>> _ = G.plot(ax=axes[1]) 26 | 27 | """ 28 | 29 | def __init__(self, N=10, **kwargs): 30 | 31 | W = np.ones((N, N)) - np.identity(N) 32 | plotting = {'limits': np.array([-1, 1, -1, 1])} 33 | 34 | super(FullConnected, self).__init__(W, plotting=plotting, **kwargs) 35 | -------------------------------------------------------------------------------- /pygsp/graphs/grid2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from pygsp import utils 7 | from . import Graph # prevent circular import in Python < 3.5 8 | 9 | 10 | class Grid2d(Graph): 11 | r"""2-dimensional grid graph. 12 | 13 | On the 2D grid, the graph Fourier transform (GFT) is the Kronecker product 14 | between the GFT of two :class:`~pygsp.graphs.Path` graphs. 15 | 16 | Parameters 17 | ---------- 18 | N1 : int 19 | Number of vertices along the first dimension. 20 | N2 : int 21 | Number of vertices along the second dimension. Default is ``N1``. 22 | diagonal : float 23 | Value of the diagonal edges.
Default is ``0.0`` 24 | 25 | See Also 26 | -------- 27 | Path : 1D line with even boundary conditions 28 | Torus : Kronecker product of two ring graphs 29 | Grid2dImgPatches 30 | 31 | Examples 32 | -------- 33 | >>> import matplotlib.pyplot as plt 34 | >>> G = graphs.Grid2d(N1=5, N2=4) 35 | >>> fig, axes = plt.subplots(1, 2) 36 | >>> _ = axes[0].spy(G.W) 37 | >>> _ = G.plot(ax=axes[1]) 38 | 39 | """ 40 | 41 | def __init__(self, N1=16, N2=None, diagonal=0.0, **kwargs): 42 | 43 | if N2 is None: 44 | N2 = N1 45 | 46 | self.N1 = N1 47 | self.N2 = N2 48 | 49 | N = N1 * N2 50 | 51 | # Filling up the weight matrix this way is faster than 52 | # looping through all the grid points: 53 | diag_1 = np.ones(N - 1) 54 | diag_1[(N2 - 1)::N2] = 0 55 | diag_2 = np.ones(N - N2) 56 | 57 | W = sparse.diags(diagonals=[diag_1, diag_2], 58 | offsets=[-1, -N2], 59 | shape=(N, N), 60 | format='csr', 61 | dtype='float') 62 | 63 | if min(N1, N2) > 1 and diagonal != 0.0: 64 | # Connecting nodes with their diagonal neighbours 65 | diag_3 = np.full(N - N2 - 1, diagonal) 66 | diag_4 = np.full(N - N2 + 1, diagonal) 67 | diag_3[N2 - 1::N2] = 0 68 | diag_4[0::N2] = 0 69 | D = sparse.diags(diagonals=[diag_3, diag_4], 70 | offsets=[-N2 - 1, -N2 + 1], 71 | shape=(N, N), 72 | format='csr', 73 | dtype='float') 74 | W += D 75 | 76 | W = utils.symmetrize(W, method='tril') 77 | 78 | x = np.kron(np.ones((N1, 1)), (np.arange(N2)/float(N2)).reshape(N2, 1)) 79 | y = np.kron(np.ones((N2, 1)), np.arange(N1)/float(N1)).reshape(N, 1) 80 | y = np.sort(y, axis=0)[::-1] 81 | coords = np.concatenate((x, y), axis=1) 82 | 83 | plotting = {"limits": np.array([-1. / N2, 1 + 1. / N2, 84 | 1. / N1, 1 + 1.
/ N1])} 85 | 86 | super(Grid2d, self).__init__(W, coords=coords, 87 | plotting=plotting, **kwargs) 88 | 89 | def _get_extra_repr(self): 90 | return dict(N1=self.N1, N2=self.N2) 91 | -------------------------------------------------------------------------------- /pygsp/graphs/linegraph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from pygsp import utils 7 | from . import Graph # prevent circular import in Python < 3.5 8 | 9 | 10 | logger = utils.build_logger(__name__) 11 | 12 | 13 | class LineGraph(Graph): 14 | r"""Build the line graph of a graph. 15 | 16 | Each vertex of the line graph represents an edge in the original graph. Two 17 | vertices are connected if the edges they represent share a vertex in the 18 | original graph. 19 | 20 | Parameters 21 | ---------- 22 | graph : :class:`Graph` 23 | 24 | Examples 25 | -------- 26 | >>> import matplotlib.pyplot as plt 27 | >>> graph = graphs.Sensor(5, k=2, seed=10) 28 | >>> line_graph = graphs.LineGraph(graph) 29 | >>> fig, ax = plt.subplots() 30 | >>> fig, ax = graph.plot('blue', edge_color='blue', indices=True, ax=ax) 31 | >>> fig, ax = line_graph.plot('red', edge_color='red', indices=True, ax=ax) 32 | >>> _ = ax.set_title('graph and its line graph') 33 | 34 | """ 35 | 36 | def __init__(self, graph, **kwargs): 37 | 38 | if graph.is_weighted(): 39 | logger.warning('Your graph is weighted, and is considered ' 40 | 'unweighted to build a binary line graph.') 41 | 42 | graph.compute_differential_operator() 43 | # incidence = np.abs(graph.D) # weighted? 
44 | incidence = (graph.D != 0) 45 | 46 | adjacency = incidence.T.dot(incidence).astype(int) 47 | adjacency -= sparse.identity(graph.n_edges, dtype=int) 48 | 49 | try: 50 | coords = incidence.T.dot(graph.coords) / 2 51 | except AttributeError: 52 | coords = None 53 | 54 | super(LineGraph, self).__init__(adjacency, coords=coords, 55 | plotting=graph.plotting, **kwargs) 56 | -------------------------------------------------------------------------------- /pygsp/graphs/logo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | from pygsp import utils 6 | from . import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class Logo(Graph): 10 | r"""GSP logo. 11 | 12 | Examples 13 | -------- 14 | >>> import matplotlib.pyplot as plt 15 | >>> G = graphs.Logo() 16 | >>> fig, axes = plt.subplots(1, 2) 17 | >>> _ = axes[0].spy(G.W, markersize=0.5) 18 | >>> _ = G.plot(ax=axes[1]) 19 | 20 | """ 21 | 22 | def __init__(self, **kwargs): 23 | 24 | data = utils.loadmat('pointclouds/logogsp') 25 | 26 | # Remove 1 because the index in python start at 0 and not at 1 27 | self.info = {"idx_g": data["idx_g"]-1, 28 | "idx_s": data["idx_s"]-1, 29 | "idx_p": data["idx_p"]-1} 30 | 31 | plotting = {"limits": np.array([0, 640, -400, 0])} 32 | 33 | super(Logo, self).__init__(data['W'], coords=data['coords'], 34 | plotting=plotting, **kwargs) 35 | -------------------------------------------------------------------------------- /pygsp/graphs/lowstretchtree.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from . import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class LowStretchTree(Graph): 10 | r"""Low stretch tree. 11 | 12 | Build the root of a low stretch tree on a grid of points. 
There are 13 | :math:`2^k` points on each side of the grid, and therefore :math:`2^{2k}` 14 | vertices in total. The edge weights are all equal to 1. 15 | 16 | Parameters 17 | ---------- 18 | k : int 19 | :math:`2^k` points on each side of the grid of vertices. 20 | 21 | Examples 22 | -------- 23 | >>> import matplotlib.pyplot as plt 24 | >>> G = graphs.LowStretchTree(k=2) 25 | >>> fig, axes = plt.subplots(1, 2) 26 | >>> _ = axes[0].spy(G.W) 27 | >>> _ = G.plot(ax=axes[1]) 28 | 29 | """ 30 | 31 | def __init__(self, k=6, **kwargs): 32 | 33 | self.k = k 34 | 35 | XCoords = np.array([1, 2, 1, 2], dtype=int) 36 | YCoords = np.array([1, 1, 2, 2], dtype=int) 37 | 38 | ii = np.array([0, 0, 1, 2, 2, 3], dtype=int) 39 | jj = np.array([1, 2, 1, 3, 0, 2], dtype=int) 40 | 41 | for p in range(1, k): 42 | ii = np.concatenate((ii, ii + 4**p, ii + 2*4**p, ii + 3*4**p, 43 | [4**p - 1], [4**p - 1], 44 | [4**p + (4**(p+1) + 2) // 3 - 1], 45 | [(5*4**p + 1) // 3 - 1], 46 | [4**p + (4**(p+1) + 2) // 3 - 1], [3*4**p])) 47 | jj = np.concatenate((jj, jj + 4**p, jj + 2*4**p, jj + 3*4**p, 48 | [(5*4**p + 1) // 3 - 1], 49 | [4**p + (4**(p+1) + 2) // 3 - 1], 50 | [3*4**p], [4**p - 1], [4**p - 1], 51 | [4**p + (4**(p+1) + 2) // 3 - 1])) 52 | 53 | YCoords = np.kron(np.ones((2), dtype=int), YCoords) 54 | YCoords = np.concatenate((YCoords, YCoords + 2**p)) 55 | 56 | XCoords = np.concatenate((XCoords, XCoords + 2**p)) 57 | XCoords = np.kron(np.ones((2), dtype=int), XCoords) 58 | 59 | W = sparse.csc_matrix((np.ones_like(ii), (ii, jj))) 60 | coords = np.concatenate((XCoords[:, np.newaxis], 61 | YCoords[:, np.newaxis]), 62 | axis=1) 63 | 64 | self.root = 4**(k - 1) 65 | 66 | plotting = {"edges_width": 1.25, 67 | "vertex_size": 75, 68 | "limits": np.array([0, 2**k + 1, 0, 2**k + 1])} 69 | 70 | super(LowStretchTree, self).__init__(W, 71 | coords=coords, 72 | plotting=plotting, 73 | **kwargs) 74 | 75 | def _get_extra_repr(self): 76 | return dict(k=self.k) 77 |
-------------------------------------------------------------------------------- /pygsp/graphs/minnesota.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | from pygsp import utils 7 | from . import Graph # prevent circular import in Python < 3.5 8 | 9 | 10 | class Minnesota(Graph): 11 | r"""Minnesota road network (from MatlabBGL). 12 | 13 | Parameters 14 | ---------- 15 | connected : bool 16 | If True, the adjacency matrix is adjusted so that all edge weights are 17 | equal to 1, and the graph is connected. Set to False to get the 18 | original disconnected graph. 19 | 20 | References 21 | ---------- 22 | See :cite:`gleich`. 23 | 24 | Examples 25 | -------- 26 | >>> import matplotlib.pyplot as plt 27 | >>> G = graphs.Minnesota() 28 | >>> fig, axes = plt.subplots(1, 2) 29 | >>> _ = axes[0].spy(G.W, markersize=0.5) 30 | >>> _ = G.plot(ax=axes[1]) 31 | 32 | """ 33 | 34 | def __init__(self, connected=True, **kwargs): 35 | 36 | self.connected = connected 37 | 38 | data = utils.loadmat('pointclouds/minnesota') 39 | self.labels = data['labels'] 40 | A = data['A'] 41 | 42 | plotting = {"limits": np.array([-98, -89, 43, 50]), 43 | "vertex_size": 40} 44 | 45 | if connected: 46 | 47 | # Missing edges needed to connect the graph. 48 | A = sparse.lil_matrix(A) 49 | A[348, 354] = 1 50 | A[354, 348] = 1 51 | A = sparse.csc_matrix(A) 52 | 53 | # Binarize: 8 entries are equal to 2 instead of 1. 
54 | A = (A > 0).astype(bool) 55 | 56 | super(Minnesota, self).__init__(A, coords=data['xy'], 57 | plotting=plotting, **kwargs) 58 | 59 | def _get_extra_repr(self): 60 | return dict(connected=self.connected) 61 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/643b1c448559da8c7dbaed7537a9fd819183c569/pygsp/graphs/nngraphs/__init__.py -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/bunny.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from pygsp import utils 4 | from pygsp.graphs import NNGraph # prevent circular import in Python < 3.5 5 | 6 | 7 | class Bunny(NNGraph): 8 | r"""Stanford bunny (NN-graph). 9 | 10 | References 11 | ---------- 12 | See :cite:`turk1994zippered`. 
13 | 14 | Examples 15 | -------- 16 | >>> import matplotlib.pyplot as plt 17 | >>> G = graphs.Bunny() 18 | >>> fig = plt.figure() 19 | >>> ax1 = fig.add_subplot(121) 20 | >>> ax2 = fig.add_subplot(122, projection='3d') 21 | >>> _ = ax1.spy(G.W, markersize=0.1) 22 | >>> _ = G.plot(ax=ax2) 23 | 24 | """ 25 | 26 | def __init__(self, **kwargs): 27 | 28 | data = utils.loadmat('pointclouds/bunny') 29 | 30 | plotting = { 31 | 'vertex_size': 10, 32 | 'elevation': -90, 33 | 'azimuth': 90, 34 | 'distance': 8, 35 | } 36 | 37 | super(Bunny, self).__init__(Xin=data['bunny'], 38 | epsilon=0.02, NNtype='radius', 39 | center=False, rescale=False, 40 | plotting=plotting, **kwargs) 41 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/cube.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | from pygsp.graphs import NNGraph # prevent circular import in Python < 3.5 6 | 7 | 8 | class Cube(NNGraph): 9 | r"""Hyper-cube (NN-graph). 10 | 11 | Parameters 12 | ---------- 13 | radius : float 14 | Edge length (default = 1) 15 | nb_pts : int 16 | Number of vertices (default = 300) 17 | nb_dim : int 18 | Dimension (default = 3) 19 | sampling : string 20 | Sampling scheme (default = 'random') 21 | (Can now only be 'random') 22 | seed : int 23 | Seed for the random number generator (for reproducible graphs).
24 | 25 | Examples 26 | -------- 27 | >>> import matplotlib.pyplot as plt 28 | >>> G = graphs.Cube(seed=42) 29 | >>> fig = plt.figure() 30 | >>> ax1 = fig.add_subplot(121) 31 | >>> ax2 = fig.add_subplot(122, projection='3d') 32 | >>> _ = ax1.spy(G.W, markersize=0.5) 33 | >>> _ = G.plot(ax=ax2) 34 | 35 | """ 36 | 37 | def __init__(self, 38 | radius=1, 39 | nb_pts=300, 40 | nb_dim=3, 41 | sampling='random', 42 | seed=None, 43 | **kwargs): 44 | 45 | self.radius = radius 46 | self.nb_pts = nb_pts 47 | self.nb_dim = nb_dim 48 | self.sampling = sampling 49 | self.seed = seed 50 | rs = np.random.RandomState(seed) 51 | 52 | if self.nb_dim > 3: 53 | raise NotImplementedError("Dimension > 3 not supported yet!") 54 | 55 | if self.sampling == "random": 56 | if self.nb_dim == 2: 57 | pts = rs.rand(self.nb_pts, self.nb_dim) 58 | 59 | elif self.nb_dim == 3: 60 | n = self.nb_pts // 6 61 | 62 | pts = np.zeros((n*6, 3)) 63 | pts[:n, 1:] = rs.rand(n, 2) 64 | pts[n:2*n, :] = np.concatenate((np.ones((n, 1)), 65 | rs.rand(n, 2)), 66 | axis=1) 67 | 68 | pts[2*n:3*n, :] = np.concatenate((rs.rand(n, 1), 69 | np.zeros((n, 1)), 70 | rs.rand(n, 1)), 71 | axis=1) 72 | pts[3*n:4*n, :] = np.concatenate((rs.rand(n, 1), 73 | np.ones((n, 1)), 74 | rs.rand(n, 1)), 75 | axis=1) 76 | 77 | pts[4*n:5*n, :2] = rs.rand(n, 2) 78 | pts[5*n:6*n, :] = np.concatenate((rs.rand(n, 2), 79 | np.ones((n, 1))), 80 | axis=1) 81 | 82 | else: 83 | raise ValueError("Unknown sampling !") 84 | 85 | plotting = { 86 | 'vertex_size': 80, 87 | 'elevation': 15, 88 | 'azimuth': 0, 89 | 'distance': 9, 90 | } 91 | 92 | super(Cube, self).__init__(Xin=pts, k=10, 93 | center=False, rescale=False, 94 | plotting=plotting, **kwargs) 95 | 96 | def _get_extra_repr(self): 97 | attrs = {'radius': '{:.2f}'.format(self.radius), 98 | 'nb_pts': self.nb_pts, 99 | 'nb_dim': self.nb_dim, 100 | 'sampling': self.sampling, 101 | 'seed': self.seed} 102 | attrs.update(super(Cube, self)._get_extra_repr()) 103 | return attrs 104 | 
-------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/grid2dimgpatches.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # prevent circular import in Python < 3.5 4 | from pygsp.graphs import Graph, Grid2d, ImgPatches 5 | 6 | 7 | class Grid2dImgPatches(Graph): 8 | r"""Union of a patch graph with a 2D grid graph. 9 | 10 | Parameters 11 | ---------- 12 | img : array 13 | Input image. 14 | aggregate: callable, optional 15 | Function to aggregate the weights ``Wp`` of the patch graph and the 16 | ``Wg`` of the grid graph. Default is ``lambda Wp, Wg: Wp + Wg``. 17 | kwargs : dict 18 | Parameters passed to :class:`ImgPatches`. 19 | 20 | See Also 21 | -------- 22 | ImgPatches 23 | Grid2d 24 | 25 | Examples 26 | -------- 27 | >>> import matplotlib.pyplot as plt 28 | >>> from skimage import data, img_as_float 29 | >>> img = img_as_float(data.camera()[::64, ::64]) 30 | >>> G = graphs.Grid2dImgPatches(img) 31 | >>> fig, axes = plt.subplots(1, 2) 32 | >>> _ = axes[0].spy(G.W, markersize=2) 33 | >>> _ = G.plot(ax=axes[1]) 34 | 35 | """ 36 | 37 | def __init__(self, img, aggregate=lambda Wp, Wg: Wp + Wg, **kwargs): 38 | 39 | self.Gg = Grid2d(img.shape[0], img.shape[1]) 40 | self.Gp = ImgPatches(img, **kwargs) 41 | 42 | W = aggregate(self.Gp.W, self.Gg.W) 43 | super(Grid2dImgPatches, self).__init__(W, 44 | coords=self.Gg.coords, 45 | plotting=self.Gg.plotting) 46 | 47 | def _get_extra_repr(self): 48 | attrs = self.Gg._get_extra_repr() 49 | attrs.update(self.Gp._get_extra_repr()) 50 | return attrs 51 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/imgpatches.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | from pygsp.graphs import NNGraph # prevent circular import in Python < 3.5 6 | 7 | 8 | class 
class ImgPatches(NNGraph):
    r"""NN-graph between patches of an image.

    Extract a feature vector in the form of a patch for every pixel of an
    image, then construct a nearest-neighbor graph between these feature
    vectors. The feature matrix, i.e. the patches, can be found in :attr:`Xin`.

    Parameters
    ----------
    img : array
        Input image.
    patch_shape : tuple, optional
        Dimensions of the patch window. Syntax: (height, width), or (height,),
        in which case width = height.
    kwargs : dict
        Parameters passed to :class:`NNGraph`.

    Raises
    ------
    ValueError
        If the image is not at least a 2D array.
    ImportError
        If scikit-image is not installed.

    See Also
    --------
    Grid2dImgPatches

    Notes
    -----
    The feature vector of a pixel `i` will consist of the stacking of the
    intensity values of all pixels in the patch centered at `i`, for all color
    channels. So, if the input image has `d` color channels, the dimension of
    the feature vector of each pixel is (patch_shape[0] * patch_shape[1] * d).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from skimage import data, img_as_float
    >>> img = img_as_float(data.camera()[::64, ::64])
    >>> G = graphs.ImgPatches(img, patch_shape=(3, 3))
    >>> print('{} nodes ({} x {} pixels)'.format(G.Xin.shape[0], *img.shape))
    64 nodes (8 x 8 pixels)
    >>> print('{} features per node'.format(G.Xin.shape[1]))
    9 features per node
    >>> G.set_coordinates(kind='spring', seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=2)
    >>> _ = G.plot(ax=axes[1])

    """

    def __init__(self, img, patch_shape=(3, 3), **kwargs):

        self.img = img
        self.patch_shape = patch_shape

        try:
            h, w, d = img.shape
        except ValueError:
            try:
                h, w = img.shape
                d = 0  # Grayscale: no channel axis.
            except ValueError:
                # BUG FIX: this was a print(), which left h, w, d unbound and
                # crashed further down with a confusing NameError.
                raise ValueError("Image should be at least a 2D array.")

        try:
            r, c = patch_shape
        except ValueError:
            # A 1-tuple (height,) means a square patch.
            r = patch_shape[0]
            c = r

        # Symmetric padding so that border pixels get full patches.
        pad_width = [(int((r - 0.5) / 2.), int((r + 0.5) / 2.)),
                     (int((c - 0.5) / 2.), int((c + 0.5) / 2.))]

        if d == 0:
            window_shape = (r, c)
            d = 1  # For the reshape in the return call
        else:
            pad_width += [(0, 0)]  # Do not pad the channel axis.
            window_shape = (r, c, d)

        # Pad the image.
        img = np.pad(img, pad_width=pad_width, mode='symmetric')

        # Extract patches as node features.
        # Alternative: sklearn.feature_extraction.image.extract_patches_2d.
        # sklearn has much less dependencies than skimage.
        try:
            # BUG FIX: `import skimage` does not guarantee that the `util`
            # subpackage is loaded; import the function explicitly.
            from skimage.util import view_as_windows
        except Exception as e:
            raise ImportError('Cannot import skimage, which is needed to '
                              'extract patches. Try to install it with '
                              'pip (or conda) install scikit-image. '
                              'Original exception: {}'.format(e))
        patches = view_as_windows(img, window_shape=window_shape)
        patches = patches.reshape((h * w, r * c * d))

        super(ImgPatches, self).__init__(patches, **kwargs)

    def _get_extra_repr(self):
        attrs = dict(patch_shape=self.patch_shape)
        attrs.update(super(ImgPatches, self)._get_extra_repr())
        return attrs
class Sensor(NNGraph):
    r"""Random sensor graph.

    The sensor graph is built by randomly picking ``N`` points on the [0, 1] x
    [0, 1] plane and connecting each to its ``k`` nearest neighbors.

    Parameters
    ----------
    N : int
        Number of nodes.
        Must be a perfect square if ``distributed=True``.
    k : int
        Minimum number of neighbors.
    distributed : bool
        Whether to distribute the vertices more evenly on the plane.
        If False, coordinates are taken uniformly at random in a [0, 1] square.
        If True, the vertices are arranged on a perturbed grid.
    seed : int
        Seed for the random number generator (for reproducible graphs).
    **kwargs :
        Additional keyword arguments for :class:`NNGraph`.

    Notes
    -----

    The definition of this graph changed in February 2019.
    See the `GitHub PR `_.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Sensor(N=64, seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=2)
    >>> _ = G.plot(ax=axes[1])

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Sensor(N=64, distributed=True, seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=2)
    >>> _ = G.plot(ax=axes[1])

    """

    def __init__(self, N=64, k=6, distributed=False, seed=None, **kwargs):

        self.distributed = distributed
        self.seed = seed

        plotting = {'limits': np.array([0, 1, 0, 1])}

        rng = np.random.default_rng(self.seed)

        if not distributed:
            # Fully random: uniform i.i.d. positions in the unit square.
            coords = rng.uniform(0, 1, (N, 2))
        else:
            # Perturbed grid: one vertex per cell of a side x side grid,
            # jittered uniformly within its cell.
            side = np.sqrt(N)
            if not side.is_integer():
                raise ValueError('The number of vertices must be a '
                                 'perfect square if they are to be '
                                 'distributed on a grid.')
            coords = np.mgrid[0:1:1/side, 0:1:1/side].reshape(2, -1).T
            coords = coords + rng.uniform(0, 1/side, (N, 2))

        super(Sensor, self).__init__(Xin=coords, k=k,
                                     rescale=False, center=False,
                                     plotting=plotting, **kwargs)

    def _get_extra_repr(self):
        return {'k': self.k,
                'distributed': self.distributed,
                'seed': self.seed}
class Sphere(NNGraph):
    r"""Spherical-shaped graph (NN-graph).

    Vertices are sampled on the unit hypersphere and connected to their 10
    nearest neighbors.

    Parameters
    ----------
    radius : float
        Radius of the sphere (default = 1)
    nb_pts : int
        Number of vertices (default = 300)
    nb_dim : int
        Dimension (default = 3)
    sampling : string
        Sampling strategy (default = 'random').
        Can now only be 'random' (normalized Gaussian, i.e. uniform on the
        sphere).
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Sphere(nb_pts=100, seed=42)
    >>> fig = plt.figure()
    >>> ax1 = fig.add_subplot(121)
    >>> ax2 = fig.add_subplot(122, projection='3d')
    >>> _ = ax1.spy(G.W, markersize=1.5)
    >>> _ = G.plot(ax=ax2)

    """

    def __init__(self,
                 radius=1,
                 nb_pts=300,
                 nb_dim=3,
                 sampling='random',
                 seed=None,
                 **kwargs):

        # NOTE(review): `radius` is stored for the repr only; the points are
        # not scaled by it — confirm intended.
        self.radius = radius
        self.nb_pts = nb_pts
        self.nb_dim = nb_dim
        self.sampling = sampling
        self.seed = seed

        if self.sampling == 'random':

            rs = np.random.RandomState(seed)
            pts = rs.normal(0, 1, (self.nb_pts, self.nb_dim))

            # Normalizing a Gaussian sample gives points uniform on the
            # sphere. Vectorized; numerically identical to the former
            # per-point loop.
            pts /= np.linalg.norm(pts, axis=1, keepdims=True)

        else:

            raise ValueError('Unknown sampling {}'.format(sampling))

        plotting = {
            'vertex_size': 80,
        }

        super(Sphere, self).__init__(Xin=pts, k=10,
                                     center=False, rescale=False,
                                     plotting=plotting, **kwargs)

    def _get_extra_repr(self):
        attrs = {'radius': '{:.2f}'.format(self.radius),
                 'nb_pts': self.nb_pts,
                 'nb_dim': self.nb_dim,
                 'sampling': self.sampling,
                 'seed': self.seed}
        attrs.update(super(Sphere, self)._get_extra_repr())
        return attrs
class TwoMoons(NNGraph):
    r"""Two Moons (NN-graph).

    Parameters
    ----------
    moontype : 'standard' or 'synthesized'
        You have the freedom to chose if you want to create a standard
        two_moons graph or a synthesized one (default is 'standard').
        'standard' : Create a two_moons graph from a based graph.
        'synthesized' : Create a synthesized two_moon
    sigmag : float
        Variance of the distance kernel (default = 0.05)
    dim : int
        The dimensionality of the points (default = 2).
        Only valid for moontype == 'standard'.
    N : int
        Number of vertices (default = 400)
        Only valid for moontype == 'synthesized'.
    sigmad : float
        Variance of the data (do not set it too high or you won't see anything)
        (default = 0.07)
        Only valid for moontype == 'synthesized'.
    distance : float
        Distance between the two moons (default = 0.5)
        Only valid for moontype == 'synthesized'.
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.TwoMoons()
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=0.5)
    >>> _ = G.plot(edges=True, ax=axes[1])

    """

    def _create_arc_moon(self, N, sigmad, distance, number, seed):
        """Sample N noisy points along one half-moon arc (1: lower, 2: upper).

        NOTE(review): both moons are generated from the same ``seed``, so for
        equal moon sizes the noise realizations coincide — confirm intended.
        """
        rng = np.random.default_rng(seed)
        phi = rng.uniform(size=(N, 1)) * np.pi  # angle along the arc
        r = 1
        rb = sigmad * rng.normal(size=(N, 1))  # radial noise amplitude
        ab = rng.uniform(size=(N, 1)) * 2 * np.pi  # noise direction
        b = rb * np.exp(1j * ab)
        bx = np.real(b)
        by = np.imag(b)

        if number == 1:
            moonx = np.cos(phi) * r + bx + 0.5
            moony = -np.sin(phi) * r + by - (distance - 1)/2.
        elif number == 2:
            moonx = np.cos(phi) * r + bx - 0.5
            moony = np.sin(phi) * r + by + (distance - 1)/2.
        else:
            # BUG FIX: previously fell through with moonx/moony unbound,
            # crashing below with a NameError.
            raise ValueError('number should be 1 or 2, got {}'.format(number))

        return np.concatenate((moonx, moony), axis=1)

    def __init__(self, moontype='standard', dim=2, sigmag=0.05,
                 N=400, sigmad=0.07, distance=0.5, seed=None, **kwargs):

        self.moontype = moontype
        self.dim = dim
        self.sigmag = sigmag
        self.sigmad = sigmad
        self.distance = distance
        self.seed = seed

        if moontype == 'standard':
            N1, N2 = 1000, 1000
            data = utils.loadmat('pointclouds/two_moons')
            Xin = data['features'][:dim].T

        elif moontype == 'synthesized':
            N1 = N // 2
            N2 = N - N1

            coords1 = self._create_arc_moon(N1, sigmad, distance, 1, seed)
            coords2 = self._create_arc_moon(N2, sigmad, distance, 2, seed)

            Xin = np.concatenate((coords1, coords2))

        else:
            raise ValueError('Unknown moontype {}'.format(moontype))

        # Binary class label per vertex (0: first moon, 1: second moon).
        self.labels = np.concatenate((np.zeros(N1), np.ones(N2)))

        plotting = {
            'vertex_size': 30,
        }

        super(TwoMoons, self).__init__(Xin=Xin, sigma=sigmag, k=5,
                                       center=False, rescale=False,
                                       plotting=plotting, **kwargs)

    def _get_extra_repr(self):
        attrs = {'moontype': self.moontype,
                 'dim': self.dim,
                 'sigmag': '{:.2f}'.format(self.sigmag),
                 'sigmad': '{:.2f}'.format(self.sigmad),
                 'distance': '{:.2f}'.format(self.distance),
                 'seed': self.seed}
        attrs.update(super(TwoMoons, self)._get_extra_repr())
        return attrs
class Path(Graph):
    r"""Path graph.

    A signal on the path graph is akin to a 1-dimensional signal in classical
    signal processing.

    On the path graph, the graph Fourier transform (GFT) is the classical
    discrete cosine transform (DCT_).
    As the type-II DCT, the GFT assumes even boundary conditions on both sides.

    .. _DCT: https://en.wikipedia.org/wiki/Discrete_cosine_transform

    Parameters
    ----------
    N : int
        Number of vertices.

    See Also
    --------
    Ring : 1D line with periodic boundary conditions
    Grid2d : Kronecker product of two path graphs
    Comet : Generalization with a star at one end

    References
    ----------
    :cite:`strang1999dct` shows that each DCT basis contains the eigenvectors
    of a symmetric "second difference" matrix.
    They get the eight types of DCTs by varying the boundary conditions.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    >>> for i, directed in enumerate([False, True]):
    ...     G = graphs.Path(N=10, directed=directed)
    ...     _ = axes[i, 0].spy(G.W)
    ...     _ = G.plot(ax=axes[i, 1])

    The GFT of the path graph is the classical DCT.

    >>> from matplotlib import pyplot as plt
    >>> n_eigenvectors = 4
    >>> graph = graphs.Path(30)
    >>> fig, axes = plt.subplots(1, 2)
    >>> graph.set_coordinates('line1D')
    >>> graph.compute_fourier_basis()
    >>> _ = graph.plot(graph.U[:, :n_eigenvectors], ax=axes[0])
    >>> _ = axes[0].legend(range(n_eigenvectors))
    >>> _ = axes[1].plot(graph.e, '.')

    """

    def __init__(self, N=16, directed=False, **kwargs):

        self.directed = directed

        # Edges i -> i+1; the undirected graph also gets the reverse ones.
        heads = np.arange(N - 1)
        tails = np.arange(1, N)
        if directed:
            sources, targets = heads, tails
        else:
            sources = np.concatenate((heads, tails))
            targets = np.concatenate((tails, heads))

        weights = np.ones_like(sources, dtype=float)
        W = sparse.csr_matrix((weights, (sources, targets)), shape=(N, N))

        plotting = {"limits": np.array([-1, N, -1, 1])}
        super(Path, self).__init__(W, plotting=plotting, **kwargs)

        self.set_coordinates('line2D')

    def _get_extra_repr(self):
        return dict(directed=self.directed)
class RandomRegular(Graph):
    r"""Random k-regular graph.

    The random regular graph has the property that every node is connected to
    k other nodes. That graph is simple (without loops or double edges),
    k-regular (each vertex is adjacent to k nodes), and undirected.

    Parameters
    ----------
    N : int
        Number of nodes (default is 64)
    k : int
        Number of connections, or degree, of each node (default is 6)
    max_iter : int
        Maximum number of iterations (default is 10)
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Notes
    -----
    The *pairing model* algorithm works as follows. First create n*d *half
    edges*. Then repeat as long as possible: pick a pair of half edges and if
    it's legal (doesn't create a loop nor a double edge) add it to the graph.

    References
    ----------
    See :cite:`kim2003randomregulargraphs`.
    This code has been adapted from matlab to python.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.RandomRegular(N=64, k=5, seed=42)
    >>> G.set_coordinates(kind='spring', seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=2)
    >>> _ = G.plot(ax=axes[1])

    """

    def __init__(self, N=64, k=6, max_iter=10, seed=None, **kwargs):

        self.k = k
        self.max_iter = max_iter
        self.seed = seed

        self.logger = utils.build_logger(__name__)

        rng = np.random.default_rng(seed)

        # N*k half-edges must pair up, so N*k has to be even.
        if (N * k) % 2 == 1:
            raise ValueError("input error: N*d must be even!")

        # a list of open half-edges
        U = np.kron(np.ones(k), np.arange(N)).astype(int)

        # the graphs adjacency matrix
        A = sparse.lil_matrix(np.zeros((N, N)))

        edgesTested = 0
        repetition = 1

        # NOTE(review): if max_iter is exhausted the loop exits with leftover
        # half-edges and the graph may not be k-regular; is_regular() below
        # only warns — confirm intended.
        while np.size(U) and repetition < max_iter:
            edgesTested += 1

            if edgesTested % 5000 == 0:
                self.logger.debug("createRandRegGraph() progress: edges= "
                                  "{}/{}.".format(edgesTested, N*k/2))

            # chose at random 2 half edges
            i1 = rng.integers(0, U.shape[0])
            i2 = rng.integers(0, U.shape[0])
            v1 = U[i1]
            v2 = U[i2]

            # check that there are no loops nor parallel edges
            if v1 == v2 or A[v1, v2] == 1:
                # restart process if needed
                if edgesTested == N*k:
                    repetition = repetition + 1
                    edgesTested = 0
                    # BUG FIX: the restart omitted .astype(int), turning the
                    # half-edge list into floats and breaking the integer
                    # indexing A[v1, v2] above.
                    U = np.kron(np.ones(k), np.arange(N)).astype(int)
                    A = sparse.lil_matrix(np.zeros((N, N)))
            else:
                # add edge to graph
                A[v1, v2] = 1
                A[v2, v1] = 1

                # remove used half-edges
                v = sorted([i1, i2])
                U = np.concatenate((U[:v[0]], U[v[0] + 1:v[1]], U[v[1] + 1:]))

        super(RandomRegular, self).__init__(A, **kwargs)

        self.is_regular()

    def is_regular(self):
        r"""
        Troubleshoot a given regular graph.

        Logs a warning if the adjacency is not symmetric, has parallel edges,
        is not d-regular, or contains self-loops.
        """
        warn = False
        msg = 'The given matrix'

        # check symmetry
        if np.abs(self.A - self.A.T).sum() > 0:
            warn = True
            msg = '{} is not symmetric,'.format(msg)

        # check parallel edged
        if self.A.max(axis=None) > 1:
            warn = True
            msg = '{} has parallel edges,'.format(msg)

        # check that d is d-regular
        if np.min(self.d) != np.max(self.d):
            warn = True
            msg = '{} is not d-regular,'.format(msg)

        # check that g doesn't contain any self-loop
        if self.A.diagonal().any():
            warn = True
            msg = '{} has self loop.'.format(msg)

        if warn:
            # Drop the trailing ',' or '.' before re-punctuating.
            self.logger.warning('{}.'.format(msg[:-1]))

    def _get_extra_repr(self):
        return dict(k=self.k, seed=self.seed)
class RandomRing(Graph):
    r"""Ring graph with randomly sampled vertices.

    Parameters
    ----------
    N : int
        Number of vertices.
    angles : array_like, optional
        The angular coordinate, in :math:`[0, 2\pi]`, of the vertices.
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.RandomRing(N=10, seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W)
    >>> _ = G.plot(ax=axes[1])
    >>> _ = axes[1].set_xlim(-1.1, 1.1)
    >>> _ = axes[1].set_ylim(-1.1, 1.1)

    """

    def __init__(self, N=64, angles=None, seed=None, **kwargs):

        self.seed = seed

        if angles is None:
            rng = np.random.default_rng(seed)
            angles = np.sort(rng.uniform(0, 2*np.pi, size=N), axis=0)
        else:
            # Need to be sorted to take the difference.
            # BUG FIX: np.sort returns a sorted copy; the previous in-place
            # angles.sort() silently reordered the caller's array.
            angles = np.sort(np.asanyarray(angles))
            N = len(angles)
            if np.any(angles < 0) or np.any(angles >= 2*np.pi):
                raise ValueError('Angles should be in [0, 2 pi]')
        self.angles = angles

        if N < 3:
            # Asymmetric graph needed for 2 as 2 distances connect them.
            raise ValueError('There should be at least 3 vertices.')

        # Each vertex is connected to its successor; the weight is the
        # angular gap between them.
        rows = range(0, N-1)
        cols = range(1, N)
        weights = np.diff(angles)

        # Close the loop.
        rows = np.concatenate((rows, [0]))
        cols = np.concatenate((cols, [N-1]))
        weights = np.concatenate((weights, [2*np.pi + angles[0] - angles[-1]]))

        W = sparse.coo_matrix((weights, (rows, cols)), shape=(N, N))
        W = utils.symmetrize(W, method='triu')

        # Width as the expected angle. All angles are equal to that value when
        # the ring is uniformly sampled.
        width = 2 * np.pi / N
        # BUG FIX: take the absolute deviation; without abs() any negative
        # deviation passed the sanity check regardless of magnitude.
        assert abs(W.data.mean() - width) < 1e-10
        # TODO: why this kernel ? It empirically produces eigenvectors closer
        # to the sines and cosines.
        W.data = width / W.data

        coords = np.stack([np.cos(angles), np.sin(angles)], axis=1)
        plotting = {'limits': np.array([-1, 1, -1, 1])}

        # TODO: save angle and 2D position as graph signals
        super(RandomRing, self).__init__(W, coords=coords, plotting=plotting,
                                         **kwargs)

    def _get_extra_repr(self):
        return dict(seed=self.seed)
class Ring(Graph):
    r"""K-regular ring graph.

    A signal on the ring graph is akin to a 1-dimensional periodic signal in
    classical signal processing.

    On the ring graph, the graph Fourier transform (GFT) is the classical
    discrete Fourier transform (DFT_).
    Actually, the Laplacian of the ring graph is a `circulant matrix`_, and any
    circulant matrix is diagonalized by the DFT.

    .. _DFT: https://en.wikipedia.org/wiki/Discrete_Fourier_transform
    .. _circulant matrix: https://en.wikipedia.org/wiki/Circulant_matrix

    Parameters
    ----------
    N : int
        Number of vertices.
    k : int
        Number of neighbors in each direction.

    See Also
    --------
    Path : 1D line with even boundary conditions
    Torus : Kronecker product of two ring graphs

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Ring(N=10)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W)
    >>> _ = G.plot(ax=axes[1])

    The GFT of the ring graph is the classical DFT.

    >>> from matplotlib import pyplot as plt
    >>> n_eigenvectors = 4
    >>> graph = graphs.Ring(30)
    >>> fig, axes = plt.subplots(1, 2)
    >>> graph.set_coordinates('line1D')
    >>> graph.compute_fourier_basis()
    >>> _ = graph.plot(graph.U[:, :n_eigenvectors], ax=axes[0])
    >>> _ = axes[0].legend(range(n_eigenvectors))
    >>> _ = axes[1].plot(graph.e, '.')

    """

    def __init__(self, N=64, k=1, **kwargs):

        self.k = k

        if N < 3:
            # Asymmetric graph needed for 2 as 2 distances connect them.
            raise ValueError('There should be at least 3 vertices.')

        if 2*k > N:
            raise ValueError('Too many neighbors requested.')

        # When 2k == N the distance-k neighbor is the same in both
        # directions (antipode), so there are k fewer edges per vertex pair.
        if 2*k == N:
            num_edges = N * (k - 1) + k
        else:
            num_edges = N * k

        i_inds = np.zeros((2 * num_edges))
        j_inds = np.zeros((2 * num_edges))

        # Connect each vertex to its neighbors at distances 1..k (both
        # directions), except the antipodal distance handled below.
        tmpN = np.arange(N, dtype=int)
        for i in range(min(k, (N - 1) // 2)):
            i_inds[2*i * N + tmpN] = tmpN
            j_inds[2*i * N + tmpN] = np.remainder(tmpN + i + 1, N)
            i_inds[(2*i + 1)*N + tmpN] = np.remainder(tmpN + i + 1, N)
            j_inds[(2*i + 1)*N + tmpN] = tmpN

        if 2*k == N:
            # BUG FIX: the second assignment targeted i_inds again (leaving
            # all targets at 0 and overwriting the sources), and used an
            # offset of k + 1 instead of k. The antipodal edges connect
            # vertex i to vertex i + k (mod N).
            i_inds[2*N*(k - 1) + tmpN] = tmpN
            j_inds[2*N*(k - 1) + tmpN] = np.remainder(tmpN + k, N)

        W = sparse.csc_matrix((np.ones((2*num_edges)), (i_inds, j_inds)),
                              shape=(N, N))

        plotting = {'limits': np.array([-1, 1, -1, 1])}

        super(Ring, self).__init__(W, plotting=plotting, **kwargs)

        self.set_coordinates('ring2D')

    def _get_extra_repr(self):
        return dict(k=self.k)
class Star(Comet):
    r"""Star graph.

    A star with a central vertex and `N-1` branches.
    The central vertex has degree `N-1`, the others have degree 1.

    Parameters
    ----------
    N : int
        Number of vertices.

    See Also
    --------
    Comet : Generalization with a longer branch as a tail

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> graph = graphs.Star(15)
    >>> graph
    Star(n_vertices=15, n_edges=14)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(graph.W)
    >>> _ = graph.plot(ax=axes[1])

    """

    def __init__(self, N=10, **kwargs):
        plotting = dict(limits=[-1.1, 1.1, -1.1, 1.1])
        # BUG FIX: pop() instead of get() — with get(), a user-supplied
        # `plotting` kwarg stayed in **kwargs and was forwarded a second
        # time, raising "got multiple values for keyword argument".
        plotting.update(kwargs.pop('plotting', {}))
        # A star is a comet whose tail has length zero: N vertices, N-1 spokes.
        super(Star, self).__init__(N, N-1, plotting=plotting, **kwargs)

    def _get_extra_repr(self):
        return dict()  # Suppress Comet repr.
class StochasticBlockModel(Graph):
    r"""Stochastic Block Model (SBM).

    The Stochastic Block Model graph is constructed by connecting nodes with a
    probability which depends on the cluster of the two nodes. One can define
    the clustering association of each node, denoted by vector z, but also the
    probability matrix M. All edge weights are equal to 1. By default, Mii >
    Mjk and nodes are uniformly clusterized.

    Parameters
    ----------
    N : int
        Number of nodes (default is 1024).
    k : float
        Number of classes (default is 5).
    z : array_like
        the vector of length N containing the association between nodes and
        classes (default is random uniform).
    M : array_like
        the k by k matrix containing the probability of connecting nodes based
        on their class belonging (default using p and q).
    p : float or array_like
        the diagonal value(s) for the matrix M. If scalar they all have the
        same value. Otherwise expect a length k vector (default is p = 0.7).
    q : float or array_like
        the off-diagonal value(s) for the matrix M. If scalar they all have the
        same value. Otherwise expect a k x k matrix, diagonal will be
        discarded (default is q = 0.3/k).
    directed : bool
        Allow directed edges if True (default is False).
    self_loops : bool
        Allow self loops if True (default is False).
    connected : bool
        Force the graph to be connected (default is False).
    n_try : int or None
        Maximum number of trials to get a connected graph. If None, it will try
        forever.
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.StochasticBlockModel(
    ...     100, k=3, p=[0.4, 0.6, 0.3], q=0.02, seed=42)
    >>> G.set_coordinates(kind='spring', seed=42)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=0.8)
    >>> _ = G.plot(ax=axes[1])

    """

    def __init__(self, N=1024, k=5, z=None, M=None, p=0.7, q=None,
                 directed=False, self_loops=False, connected=False,
                 n_try=10, seed=None, **kwargs):

        self.k = k
        self.directed = directed
        self.self_loops = self_loops
        self.connected = connected
        self.n_try = n_try
        self.seed = seed

        rng = np.random.default_rng(seed)

        if z is None:
            z = rng.integers(0, k, N)
            z.sort()  # Sort for nice spy plot of W, where blocks are apparent.
        self.z = z

        if M is None:

            self.p = p
            p = np.asanyarray(p)
            if p.size == 1:
                p = p * np.ones(k)
            if p.shape != (k,):
                raise ValueError('Optional parameter p is neither a scalar '
                                 'nor a vector of length k.')

            if q is None:
                q = 0.3 / k
            self.q = q
            q = np.asanyarray(q)
            if q.size == 1:
                q = q * np.ones((k, k))
            if q.shape != (k, k):
                raise ValueError('Optional parameter q is neither a scalar '
                                 'nor a matrix of size k x k.')

            # BUG FIX: copy before editing the diagonal — `M = q` aliased a
            # user-supplied q array, which was then silently mutated.
            M = np.array(q)
            M.flat[::k+1] = p  # edit the diagonal terms

        self.M = M

        if (M < 0).any() or (M > 1).any():
            raise ValueError('Probabilities should be in [0, 1].')

        # TODO: higher memory, lesser computation alternative.
        # Along the lines of np.random.uniform(size=(N, N)) < p.
        # Or similar to sparse.random(N, N, p, data_rvs=lambda n: np.ones(n)).

        while (n_try is None) or (n_try > 0):

            # Sweep the N x N grid column by column, sampling each candidate
            # edge with the probability given by the classes of its endpoints.
            nb_row, nb_col = 0, 0
            csr_data, csr_i, csr_j = [], [], []
            for _ in range(N**2):
                if nb_row != nb_col or self_loops:
                    if nb_row >= nb_col or directed:
                        if rng.uniform() < M[z[nb_row], z[nb_col]]:
                            csr_data.append(1)
                            csr_i.append(nb_row)
                            csr_j.append(nb_col)
                if nb_row < N-1:
                    nb_row += 1
                else:
                    nb_row = 0
                    nb_col += 1

            W = sparse.csr_matrix((csr_data, (csr_i, csr_j)), shape=(N, N))

            if not directed:
                W = utils.symmetrize(W, method='tril')

            if not connected:
                break
            if Graph(W).is_connected():
                break
            if n_try is not None:
                n_try -= 1
        if connected and n_try == 0:
            raise ValueError('The graph could not be connected after {} '
                             'trials. Increase the connection probability '
                             'or the number of trials.'.format(self.n_try))

        self.info = {'node_com': z, 'comm_sizes': np.bincount(z),
                     'world_rad': np.sqrt(N)}

        super(StochasticBlockModel, self).__init__(W, **kwargs)

    def _get_extra_repr(self):
        attrs = {'k': self.k}
        # BUG FIX: p and q are only set when M was not supplied; getattr
        # avoids an AttributeError in repr when the user passed M directly.
        p = getattr(self, 'p', None)
        q = getattr(self, 'q', None)
        if type(p) is float:
            attrs['p'] = '{:.2f}'.format(p)
        if type(q) is float:
            attrs['q'] = '{:.2f}'.format(q)
        attrs.update({'directed': self.directed,
                      'self_loops': self.self_loops,
                      'connected': self.connected,
                      'seed': self.seed})
        return attrs
class SwissRoll(Graph):
    r"""Sampled Swiss roll manifold.

    Parameters
    ----------
    N : int
        Number of vertices (default = 400)
    a : int
        (default = 1)
    b : int
        (default = 4)
    dim : int
        Embedding dimension, 2 or 3 (default = 3)
    thresh : float
        Weights smaller than thresh are set to zero (default = 1e-6)
    s : float
        sigma (default = sqrt(2./N))
    noise : bool
        Whether to add noise or not (default = False)
    srtype : str
        Swiss roll Type, possible arguments are 'uniform' or 'classic'
        (default = 'uniform')
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.SwissRoll(N=200, seed=42)
    >>> fig = plt.figure()
    >>> ax1 = fig.add_subplot(121)
    >>> ax2 = fig.add_subplot(122, projection='3d')
    >>> _ = ax1.spy(G.W, markersize=1)
    >>> _ = G.plot(ax=ax2)

    """

    def __init__(self, N=400, a=1, b=4, dim=3, thresh=1e-6, s=None,
                 noise=False, srtype='uniform', seed=None, **kwargs):

        if s is None:
            s = np.sqrt(2. / N)

        self.a = a
        self.b = b
        self.dim = dim
        self.thresh = thresh
        self.s = s
        self.noise = noise
        self.srtype = srtype
        self.seed = seed

        rng = np.random.default_rng(seed)
        y1 = rng.uniform(size=N)
        y2 = rng.uniform(size=N)

        # Parameter along the roll; 'uniform' corrects the density so the
        # points are uniform on the manifold.
        if srtype == 'uniform':
            tt = np.sqrt((b * b - a * a) * y1 + a * a)
        elif srtype == 'classic':
            tt = (b - a) * y1 + a
        else:
            # BUG FIX: an unknown srtype previously fell through and crashed
            # below with a NameError on tt.
            raise ValueError('Unknown srtype {}'.format(srtype))
        tt *= np.pi

        if dim == 2:
            x = np.array((tt * np.cos(tt), tt * np.sin(tt)))
        elif dim == 3:
            x = np.array((tt * np.cos(tt), 21 * y2, tt * np.sin(tt)))
        else:
            # BUG FIX: same fall-through for an unsupported dimension.
            raise ValueError('dim should be 2 or 3, got {}'.format(dim))

        if noise:
            x += rng.normal(size=x.shape)

        self.x = x

        # Gaussian kernel on pairwise distances, zero diagonal, thresholded.
        coords = utils.rescale_center(x)
        dist = utils.distanz(coords)
        W = np.exp(-np.power(dist, 2) / (2. * s**2))
        W -= np.diag(np.diag(W))
        W[W < thresh] = 0

        plotting = {
            'vertex_size': 60,
            'limits': np.array([-1, 1, -1, 1, -1, 1]),
            'elevation': 15,
            'azimuth': -90,
            'distance': 7,
        }

        super(SwissRoll, self).__init__(W, coords=coords.T,
                                        plotting=plotting,
                                        **kwargs)

    def _get_extra_repr(self):
        return {'a': self.a,
                'b': self.b,
                'dim': self.dim,
                'thresh': '{:.0e}'.format(self.thresh),
                's': '{:.2f}'.format(self.s),
                'noise': self.noise,
                'srtype': self.srtype,
                'seed': self.seed}
# NOTE(review): this chunk opens inside the docstring of Torus (header in the
# previous chunk).  The remaining docstring text (See Also: Ring, Grid2d;
# an Examples section plotting graphs.Torus(10)) is summarised here; the two
# functions below are methods of ``class Torus(Graph)``.

def __init__(self, Nv=16, Mv=None, **kwargs):
    """Build the torus graph (method of Torus).

    Parameters
    ----------
    Nv : int
        Number of vertices along the first dimension.
    Mv : int
        Number of vertices along the second dimension. Default is ``Nv``.
    """
    if Mv is None:
        Mv = Nv

    self.Nv = Nv
    self.Mv = Mv

    # Create weighted adjacency matrix.  ("adjancency" typo fixed)
    K = 2 * Nv  # both edge directions within each ring of the first dim
    J = 2 * Mv
    # BUG FIX: these arrays are used as sparse-matrix row/column indices and
    # must be integral; they were created with dtype=float, relying on an
    # implicit cast that recent scipy versions deprecate.
    i_inds = np.zeros((K*Mv + J*Nv), dtype=int)
    j_inds = np.zeros((K*Mv + J*Nv), dtype=int)

    tmpK = np.arange(K, dtype=int)
    tmpNv1 = np.arange(Nv - 1)
    tmpNv = np.arange(Nv)

    # Ring edges inside each row (first dimension, periodic).
    for i in range(Mv):
        i_inds[i*K + tmpK] = i*Nv + \
            np.concatenate((np.array([Nv - 1]), tmpNv1, tmpNv))

        j_inds[i*K + tmpK] = i*Nv + \
            np.concatenate((tmpNv, np.array([Nv - 1]), tmpNv1))

    tmp2Nv = np.arange(2*Nv, dtype=int)

    # Edges between consecutive rows (second dimension).
    for i in range(Mv - 1):
        i_inds[K*Mv + i*2*Nv + tmp2Nv] = \
            np.concatenate((i*Nv + tmpNv, (i + 1)*Nv + tmpNv))

        j_inds[K*Mv + i*2*Nv + tmp2Nv] = \
            np.concatenate(((i + 1)*Nv + tmpNv, i*Nv + tmpNv))

    # Wrap-around edges closing the second dimension.
    i_inds[K*Mv + (Mv - 1)*2*Nv + tmp2Nv] = \
        np.concatenate((tmpNv, (Mv - 1)*Nv + tmpNv))

    j_inds[K*Mv + (Mv - 1)*2*Nv + tmp2Nv] = \
        np.concatenate(((Mv - 1)*Nv + tmpNv, tmpNv))

    W = sparse.csc_matrix((np.ones((K*Mv + J*Nv)), (i_inds, j_inds)),
                          shape=(Mv*Nv, Mv*Nv))

    # 3D coordinates of the sampled torus (Fortran-order flattening keeps
    # vertex k at row k of ``coords``).
    T = 1.5 + np.sin(np.arange(Mv)*2*np.pi/Mv).reshape(1, Mv)
    U = np.cos(np.arange(Mv)*2*np.pi/Mv).reshape(1, Mv)
    xtmp = np.cos(np.arange(Nv).reshape(Nv, 1)*2*np.pi/Nv)*T
    ytmp = np.sin(np.arange(Nv).reshape(Nv, 1)*2*np.pi/Nv)*T
    ztmp = np.kron(np.ones((Nv, 1)), U)
    coords = np.concatenate((np.reshape(xtmp, (Mv*Nv, 1), order='F'),
                             np.reshape(ytmp, (Mv*Nv, 1), order='F'),
                             np.reshape(ztmp, (Mv*Nv, 1), order='F')),
                            axis=1)

    plotting = {
        'vertex_size': 60,
        'limits': np.array([-2.5, 2.5, -2.5, 2.5, -2.5, 2.5])
    }

    super(Torus, self).__init__(W, coords=coords,
                                plotting=plotting, **kwargs)


def _get_extra_repr(self):
    """Attributes shown in repr() (method of Torus)."""
    return dict(Nv=self.Nv, Mv=self.Mv)


# ---------------------------------------------------------------------------
# pygsp/optimization.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

r"""
The :mod:`pygsp.optimization` module provides tools to solve convex
optimization problems on graphs.
"""

from pygsp import utils


logger = utils.build_logger(__name__)


def _import_pyunlocbox():
    """Import pyunlocbox lazily, raising a helpful error if it is missing."""
    try:
        from pyunlocbox import functions, solvers
    except Exception as e:
        raise ImportError('Cannot import pyunlocbox, which is needed to solve '
                          'this optimization problem. Try to install it with '
                          'pip (or conda) install pyunlocbox. '
                          'Original exception: {}'.format(e))
    return functions, solvers


def prox_tv(x, gamma, G, A=None, At=None, nu=1, tol=10e-4, maxit=200,
            use_matrix=True):
    r"""
    Total Variation proximal operator for graphs.

    This function computes the TV proximal operator for graphs. The TV norm
    is the one norm of the gradient. The gradient is defined in the
    function :meth:`pygsp.graphs.Graph.grad`.
    This function requires the PyUNLocBoX to be executed.

    This function solves:

    :math:`sol = \min_{z} \frac{1}{2} \|x - z\|_2^2 + \gamma \|x\|_{TV}`

    Parameters
    ----------
    x: ndarray
        Input signal.  (The original docstring said ``int``, which was wrong.)
    gamma: float
        Regularization parameter.
    G: graph object
        Graphs structure
    A: lambda function
        Forward operator, this parameter allows to solve the following problem:
        :math:`sol = \min_{z} \frac{1}{2} \|x - z\|_2^2 + \gamma \| A x\|_{TV}`
        (default = Id)
    At: lambda function
        Adjoint operator. (default = Id)
    nu: float
        Bound on the norm of the operator (default = 1)
    tol: float
        Stops criterion for the loop. The algorithm will stop if :
        :math:`\frac{n(t) - n(t - 1)} {n(t)} < tol`
        where :math:`n(t) = f(x) + 0.5 \|x-y\|_2^2` is the objective function
        at iteration :math:`t` (default = :math:`10e-4`)
    maxit: int
        Maximum iteration. (default = 200)
    use_matrix: bool
        Whether to apply the differential-operator matrix ``G.Diff`` directly
        (default = True); otherwise ``G.grad``/``G.div`` are used.

    Returns
    -------
    sol: ndarray
        The evaluated proximal operator at ``x``.

    """
    if A is None:
        def A(x):
            return x
    if At is None:
        def At(x):
            return x

    tight = 0
    l1_nu = 2 * G.lmax * nu

    if use_matrix:
        def l1_a(x):
            return G.Diff * A(x)

        def l1_at(x):
            # BUG FIX: the adjoint of ``x -> Diff * A(x)`` is
            # ``y -> At(Diff.T * y)``.  The original referenced an undefined
            # name ``D`` (NameError at runtime) and applied the operators in
            # the wrong order.
            return At(G.Diff.T * x)
    else:
        def l1_a(x):
            return G.grad(A(x))

        def l1_at(x):
            return G.div(x)

    functions, _ = _import_pyunlocbox()
    # BUG FIX: the original passed an undefined ``verbose``, discarded the
    # computed ``l1_nu`` bound, and returned nothing.  Build the l1 norm and
    # evaluate its proximal operator.
    f = functions.norm_l1(A=l1_a, At=l1_at, tight=tight, nu=l1_nu)
    # NOTE(review): ``tol`` and ``maxit`` are kept in the signature for
    # backward compatibility; how they are forwarded to the iterative prox of
    # a non-tight frame depends on the installed pyunlocbox version -- confirm
    # against its documentation before relying on them.
    return f.prox(x, gamma)


# ---------------------------------------------------------------------------
# pygsp/tests/__init__.py (continues in the next chunk)
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Test suite of the PyGSP package, broken by modules.

"""

import unittest

from . import test_graphs
# NOTE(review): the remaining sibling imports (``from . import test_filters``
# and following) and the module-level ``suite`` are cut at the end of this
# chunk and continue in the next one.
import test_filters 12 | from . import test_utils 13 | from . import test_learning 14 | from . import test_docstrings 15 | from . import test_plotting 16 | 17 | 18 | suite = unittest.TestSuite([ 19 | test_graphs.suite, 20 | test_filters.suite, 21 | test_utils.suite, 22 | test_learning.suite, 23 | test_docstrings.suite, 24 | test_plotting.suite, # TODO: can SIGSEGV if not last 25 | ]) 26 | -------------------------------------------------------------------------------- /pygsp/tests/test_docstrings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Test suite for the docstrings of the pygsp package. 5 | 6 | """ 7 | 8 | import os 9 | import unittest 10 | import doctest 11 | 12 | 13 | def gen_recursive_file(root, ext): 14 | for root, _, filenames in os.walk(root): 15 | for name in filenames: 16 | if name.lower().endswith(ext): 17 | yield os.path.join(root, name) 18 | 19 | 20 | def test_docstrings(root, ext, setup=None): 21 | files = list(gen_recursive_file(root, ext)) 22 | return doctest.DocFileSuite(*files, setUp=setup, tearDown=teardown, 23 | module_relative=False) 24 | 25 | 26 | def setup(doctest): 27 | import numpy 28 | import pygsp 29 | doctest.globs = { 30 | 'graphs': pygsp.graphs, 31 | 'filters': pygsp.filters, 32 | 'utils': pygsp.utils, 33 | 'np': numpy, 34 | } 35 | 36 | 37 | def teardown(doctest): 38 | """Close matplotlib figures to avoid warning and save memory.""" 39 | import pygsp 40 | pygsp.plotting.close_all() 41 | 42 | 43 | # Docstrings from API reference. 44 | suite_reference = test_docstrings('pygsp', '.py', setup) 45 | 46 | # Docstrings from tutorials. No setup to not forget imports. 
# (pygsp/tests/test_docstrings.py, continued) Docstrings from the tutorials.
suite_tutorials = test_docstrings('.', '.rst')

suite = unittest.TestSuite([suite_reference, suite_tutorials])


# ---------------------------------------------------------------------------
# pygsp/tests/test_learning.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Test suite for the learning module of the pygsp package.

"""

import unittest

import numpy as np

from pygsp import graphs, filters, learning


class TestCase(unittest.TestCase):

    def test_regression_tikhonov_1(self):
        """Solve a trivial regression problem."""
        G = graphs.Ring(N=8)
        signal = np.array([0, np.nan, 4, np.nan, 4, np.nan, np.nan, np.nan])
        signal_bak = signal.copy()
        mask = np.array([True, False, True, False, True, False, False, False])
        truth = np.array([0, 2, 4, 4, 4, 3, 2, 1])
        recovery = learning.regression_tikhonov(G, signal, mask, tau=0)
        np.testing.assert_allclose(recovery, truth)

        # Test the numpy solution.
        G = graphs.Graph(G.W.toarray())
        recovery = learning.regression_tikhonov(G, signal, mask, tau=0)
        np.testing.assert_allclose(recovery, truth)
        # The input signal must not be modified in place.
        np.testing.assert_allclose(signal_bak, signal)

    def test_regression_tikhonov_2(self):
        """Solve a regression problem with a constraint."""
        G = graphs.Sensor(100)
        G.estimate_lmax()

        # Create a smooth signal.
        filt = filters.Filter(G, lambda x: 1 / (1 + 10*x))
        rng = np.random.default_rng(1)
        signal = filt.analyze(rng.normal(size=(G.n_vertices, 5)))

        # Make the input signal.
        mask = rng.uniform(0, 1, G.n_vertices) > 0.5
        measures = signal.copy()
        measures[~mask] = np.nan
        measures_bak = measures.copy()

        # Solve the problem.
        recovery0 = learning.regression_tikhonov(G, measures, mask, tau=0)
        np.testing.assert_allclose(measures_bak, measures)

        # Column-by-column recovery must agree with the batched call.
        recovery1 = np.zeros_like(recovery0)
        for i in range(recovery0.shape[1]):
            recovery1[:, i] = learning.regression_tikhonov(
                G, measures[:, i], mask, tau=0)
        np.testing.assert_allclose(measures_bak, measures)

        # The dense (numpy) variant must agree with the sparse one.
        G = graphs.Graph(G.W.toarray())
        recovery2 = learning.regression_tikhonov(G, measures, mask, tau=0)
        recovery3 = np.zeros_like(recovery0)
        for i in range(recovery0.shape[1]):
            recovery3[:, i] = learning.regression_tikhonov(
                G, measures[:, i], mask, tau=0)

        np.testing.assert_allclose(recovery1, recovery0)
        np.testing.assert_allclose(recovery2, recovery0)
        np.testing.assert_allclose(recovery3, recovery0)
        np.testing.assert_allclose(measures_bak, measures)

    def test_regression_tikhonov_3(self, tau=3.5):
        """Solve a relaxed regression problem."""
        G = graphs.Sensor(100)
        G.estimate_lmax()

        # Create a smooth signal.
        filt = filters.Filter(G, lambda x: 1 / (1 + 10*x))
        rng = np.random.default_rng(1)
        signal = filt.analyze(rng.normal(size=(G.n_vertices, 6)))

        # Make the input signal.
        mask = rng.uniform(0, 1, G.n_vertices) > 0.5
        measures = signal.copy()
        measures[~mask] = 18
        measures_bak = measures.copy()

        # Closed-form solution of the relaxed problem, used as ground truth.
        L = G.L.toarray()
        recovery = np.matmul(np.linalg.inv(np.diag(1*mask) + tau * L),
                             (mask * measures.T).T)

        # Solve the problem.
        recovery0 = learning.regression_tikhonov(G, measures, mask, tau=tau)
        np.testing.assert_allclose(measures_bak, measures)
        recovery1 = np.zeros_like(recovery0)
        for i in range(recovery0.shape[1]):
            recovery1[:, i] = learning.regression_tikhonov(
                G, measures[:, i], mask, tau)
        np.testing.assert_allclose(measures_bak, measures)

        # Dense (numpy) variant, batched and column-by-column.
        G = graphs.Graph(G.W.toarray())
        recovery2 = learning.regression_tikhonov(G, measures, mask, tau)
        recovery3 = np.zeros_like(recovery0)
        for i in range(recovery0.shape[1]):
            recovery3[:, i] = learning.regression_tikhonov(
                G, measures[:, i], mask, tau)

        np.testing.assert_allclose(recovery0, recovery, atol=1e-5)
        np.testing.assert_allclose(recovery1, recovery, atol=1e-5)
        np.testing.assert_allclose(recovery2, recovery, atol=1e-5)
        np.testing.assert_allclose(recovery3, recovery, atol=1e-5)
        np.testing.assert_allclose(measures_bak, measures)

    def test_classification_tikhonov(self):
        """Solve a classification problem."""
        G = graphs.Logo()
        signal = np.zeros([G.n_vertices], dtype=int)
        signal[G.info['idx_s']] = 1
        signal[G.info['idx_p']] = 2

        # Make the input signal.
        rng = np.random.default_rng(2)
        mask = rng.uniform(size=G.n_vertices) > 0.3

        measures = signal.copy()
        measures[~mask] = -1
        measures_bak = measures.copy()

        # Solve the classification problem.
        recovery = learning.classification_tikhonov(G, measures, mask, tau=0)
        recovery = np.argmax(recovery, axis=1)

        np.testing.assert_array_equal(recovery, signal)

        # Test the function with the simplex projection.
        recovery = learning.classification_tikhonov_simplex(
            G, measures, mask, tau=0.1)

        # Assert that the probabilities sums to 1
        np.testing.assert_allclose(np.sum(recovery, axis=1), 1)

        # Check the quality of the solution.
        recovery = np.argmax(recovery, axis=1)
        np.testing.assert_allclose(signal, recovery)
        np.testing.assert_allclose(measures_bak, measures)


suite = unittest.TestLoader().loadTestsFromTestCase(TestCase)


# ---------------------------------------------------------------------------
# pygsp/tests/test_utils.py (the first test method is cut at the end of this
# chunk and resumes in the next one)
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Test suite for the utils module of the pygsp package.

"""

import unittest

import numpy as np
from scipy import sparse

from pygsp import graphs, utils


class TestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        pass

    def test_symmetrize(self):
        W = sparse.random(100, 100, random_state=42)
        for method in ['average', 'maximum', 'fill', 'tril', 'triu']:
            # Test that the regular and sparse versions give the same result.
            # (continuation of TestCase.test_symmetrize from the previous
            # chunk; the ``for method`` loop body resumes here)
            W1 = utils.symmetrize(W, method=method)
            W2 = utils.symmetrize(W.toarray(), method=method)
            # Densified sparse result must equal the dense result.
            np.testing.assert_equal(W1.toarray(), W2)
        # An unknown method must be rejected.
        self.assertRaises(ValueError, utils.symmetrize, W, 'sum')


suite = unittest.TestLoader().loadTestsFromTestCase(TestCase)


# ---------------------------------------------------------------------------
# setup.py
# ---------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from setuptools import setup


setup(
    name='PyGSP',
    version='0.5.1',
    description='Graph Signal Processing in Python',
    long_description=open('README.rst').read(),
    long_description_content_type='text/x-rst',
    author='EPFL LTS2',
    url='https://github.com/epfl-lts2/pygsp',
    project_urls={
        'Documentation': 'https://pygsp.readthedocs.io',
        'Download': 'https://pypi.org/project/PyGSP',
        'Source Code': 'https://github.com/epfl-lts2/pygsp',
        'Bug Tracker': 'https://github.com/epfl-lts2/pygsp/issues',
        'Try It Online': 'https://mybinder.org/v2/gh/epfl-lts2/pygsp/master?urlpath=lab/tree/examples/playground.ipynb',
    },
    packages=[
        'pygsp',
        'pygsp.graphs',
        'pygsp.graphs.nngraphs',
        'pygsp.filters',
        'pygsp.tests',
    ],
    # Point-cloud .mat fixtures shipped with the package.
    package_data={'pygsp': ['data/pointclouds/*.mat']},
    test_suite='pygsp.tests.suite',
    python_requires='>=3.7',
    install_requires=[
        'numpy',
        'scipy',
    ],
    extras_require={
        # Optional dependencies for development. Some bring additional
        # functionalities, others are for testing, documentation, or packaging.
        'dev': [
            # Import and export.
            'networkx',
            # 'graph-tool', cannot be installed by pip
            # Construct patch graphs from images.
            'scikit-image',
            # Approximate nearest neighbors for kNN graphs.
            'pyflann3',
            # Convex optimization on graph.
            'pyunlocbox',
            # Plot graphs, signals, and filters.
            'matplotlib',
            # Interactive graph visualization.
            'pyqtgraph',
            'PyOpenGL',
            'PyQt5',
            # Run the tests.
            'flake8',
            'coverage',
            'coveralls',
            # Build the documentation.
            'sphinx',
            'numpydoc',
            'sphinxcontrib-bibtex',
            'sphinx-gallery',
            'memory_profiler',
            'sphinx-rtd-theme',
            'sphinx-copybutton',
            # Build and upload packages.
            'wheel',
            'twine',
        ],
    },
    license="BSD",
    keywords='graph signal processing',
    platforms='any',
    classifiers=[
        'Development Status :: 4 - Beta',
        'Topic :: Scientific/Engineering',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: BSD License',
        'Operating System :: OS Independent',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
    ],
)