├── pygsp ├── graphs │ ├── nngraphs │ │ ├── __init__.py │ │ ├── bunny.py │ │ ├── grid2dimgpatches.py │ │ ├── sphere.py │ │ ├── sensor.py │ │ ├── cube.py │ │ ├── imgpatches.py │ │ └── twomoons.py │ ├── fullconnected.py │ ├── logo.py │ ├── star.py │ ├── airfoil.py │ ├── minnesota.py │ ├── linegraph.py │ ├── erdosrenyi.py │ ├── davidsensornet.py │ ├── barabasialbert.py │ ├── comet.py │ ├── path.py │ ├── randomring.py │ ├── grid2d.py │ ├── lowstretchtree.py │ ├── ring.py │ ├── swissroll.py │ ├── torus.py │ ├── randomregular.py │ ├── __init__.py │ └── stochasticblockmodel.py ├── data │ ├── pointclouds │ │ ├── bunny.mat │ │ ├── airfoil.mat │ │ ├── david500.mat │ │ ├── david64.mat │ │ ├── logogsp.mat │ │ ├── minnesota.mat │ │ └── two_moons.mat │ ├── readme_example_filter.png │ └── readme_example_graph.png ├── tests │ ├── test_utils.py │ ├── test_learning.py │ └── test_docstrings.py ├── filters │ ├── halfcosine.py │ ├── itersine.py │ ├── simoncelli.py │ ├── papadakis.py │ ├── regular.py │ ├── held.py │ ├── mexicanhat.py │ ├── expwin.py │ ├── rectangular.py │ ├── meyer.py │ ├── simpletight.py │ ├── gabor.py │ ├── abspline.py │ ├── __init__.py │ ├── heat.py │ ├── wave.py │ └── modulation.py ├── __init__.py ├── features.py └── optimization.py ├── doc ├── changelog.rst ├── contributing.rst ├── reference │ ├── filters.rst │ ├── utils.rst │ ├── features.rst │ ├── learning.rst │ ├── plotting.rst │ ├── reduction.rst │ ├── optimization.rst │ ├── graphs.rst │ └── index.rst ├── references.rst ├── index.rst ├── tutorials │ ├── index.rst │ ├── pyramid.rst │ ├── wavelet.rst │ └── optimization.rst ├── conf.py └── references.bib ├── examples ├── README.txt ├── eigenvalue_concentration.py ├── heat_diffusion.py ├── wave_propagation.py ├── kernel_localization.py ├── fourier_transform.py ├── fourier_basis.py ├── random_walk.py ├── filtering.py ├── eigenvector_localization.py └── playground.ipynb ├── MANIFEST.in ├── postBuild ├── .readthedocs.yml ├── .gitignore ├── .zenodo.json ├── LICENSE.txt 
├── Makefile ├── .pre-commit-config.yaml ├── pyproject.toml ├── .github └── workflows │ └── ci.yml └── CONTRIBUTING.rst /pygsp/graphs/nngraphs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/changelog.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CHANGELOG.rst 2 | -------------------------------------------------------------------------------- /doc/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | ======== 2 | Examples 3 | ======== 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.rst 3 | include CHANGELOG.rst 4 | -------------------------------------------------------------------------------- /doc/reference/filters.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Filters 3 | ======= 4 | 5 | .. automodule:: pygsp.filters 6 | -------------------------------------------------------------------------------- /doc/reference/utils.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Utilities 3 | ========= 4 | 5 | .. automodule:: pygsp.utils 6 | -------------------------------------------------------------------------------- /doc/reference/features.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Features 3 | ======== 4 | 5 | .. 
automodule:: pygsp.features 6 | -------------------------------------------------------------------------------- /doc/reference/learning.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Learning 3 | ======== 4 | 5 | .. automodule:: pygsp.learning 6 | -------------------------------------------------------------------------------- /doc/reference/plotting.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Plotting 3 | ======== 4 | 5 | .. automodule:: pygsp.plotting 6 | -------------------------------------------------------------------------------- /doc/reference/reduction.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Reduction 3 | ========= 4 | 5 | .. automodule:: pygsp.reduction 6 | -------------------------------------------------------------------------------- /postBuild: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Tell https://mybinder.org to simply install the package. 
3 | pip install .[dev] 4 | -------------------------------------------------------------------------------- /pygsp/data/pointclouds/bunny.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/bunny.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/airfoil.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/airfoil.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/david500.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/david500.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/david64.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/david64.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/logogsp.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/logogsp.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/minnesota.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/minnesota.mat -------------------------------------------------------------------------------- /pygsp/data/pointclouds/two_moons.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/pointclouds/two_moons.mat -------------------------------------------------------------------------------- /pygsp/data/readme_example_filter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/readme_example_filter.png -------------------------------------------------------------------------------- /pygsp/data/readme_example_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-lts2/pygsp/HEAD/pygsp/data/readme_example_graph.png -------------------------------------------------------------------------------- /doc/reference/optimization.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Optimization 3 | ============ 4 | 5 | .. automodule:: pygsp.optimization 6 | -------------------------------------------------------------------------------- /doc/references.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | References 3 | ========== 4 | 5 | .. bibliography:: references.bib 6 | :cited: 7 | :style: alpha 8 | -------------------------------------------------------------------------------- /doc/reference/graphs.rst: -------------------------------------------------------------------------------- 1 | ====== 2 | Graphs 3 | ====== 4 | 5 | .. automodule:: pygsp.graphs 6 | :exclude-members: Graph 7 | 8 | .. autoclass:: Graph 9 | :inherited-members: 10 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. 
toctree:: 4 | :hidden: 5 | 6 | Home 7 | tutorials/index 8 | examples/index 9 | reference/index 10 | changelog 11 | contributing 12 | references 13 | -------------------------------------------------------------------------------- /doc/reference/index.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | API reference 3 | ============= 4 | 5 | .. automodule:: pygsp 6 | 7 | .. toctree:: 8 | :hidden: 9 | 10 | graphs 11 | filters 12 | plotting 13 | learning 14 | reduction 15 | features 16 | optimization 17 | utils 18 | -------------------------------------------------------------------------------- /doc/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Tutorials 3 | ========= 4 | 5 | The following are some tutorials which explain how to use the toolbox and show 6 | how to use it to solve some problems. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | intro 12 | wavelet 13 | optimization 14 | pyramid 15 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | formats: 9 | - htmlzip 10 | 11 | sphinx: 12 | builder: html 13 | configuration: doc/conf.py 14 | 15 | python: 16 | install: 17 | - method: pip 18 | path: . 
19 | extra_requirements: 20 | - dev 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | uv.lock 3 | __pycache__/ 4 | 5 | # Packages 6 | *.egg 7 | *.egg-info 8 | dist 9 | build 10 | eggs 11 | parts 12 | bin 13 | var 14 | sdist 15 | develop-eggs 16 | .installed.cfg 17 | 18 | # Installer logs 19 | pip-log.txt 20 | 21 | # Coverage reports 22 | .coverage 23 | htmlcov 24 | 25 | # Complexity 26 | output/*.html 27 | output/*/index.html 28 | 29 | # Sphinx documentation 30 | /doc/_build/ 31 | /doc/examples/ 32 | /doc/backrefs/ 33 | 34 | # Vim swap files 35 | .*.swp 36 | 37 | # Mac OS garbage 38 | .DS_Store 39 | 40 | # Jupyter notebook 41 | .ipynb_checkpoints/ 42 | 43 | # Visual Studio Code 44 | .vscode/ 45 | -------------------------------------------------------------------------------- /pygsp/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test suite for the utils module of the pygsp package. 3 | 4 | """ 5 | 6 | import numpy as np 7 | import pytest 8 | from scipy import sparse 9 | 10 | from pygsp import utils 11 | 12 | 13 | def test_symmetrize(): 14 | """Test matrix symmetrization methods.""" 15 | W = sparse.random(100, 100, random_state=42) 16 | for method in ["average", "maximum", "fill", "tril", "triu"]: 17 | # Test that the regular and sparse versions give the same result. 
18 | W1 = utils.symmetrize(W, method=method) 19 | W2 = utils.symmetrize(W.toarray(), method=method) 20 | np.testing.assert_equal(W1.toarray(), W2) 21 | 22 | # Test that invalid method raises ValueError 23 | with pytest.raises(ValueError): 24 | utils.symmetrize(W, "sum") 25 | -------------------------------------------------------------------------------- /pygsp/graphs/fullconnected.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .graph import Graph # prevent circular import in Python < 3.5 4 | 5 | 6 | class FullConnected(Graph): 7 | r"""Fully connected graph. 8 | 9 | All weights are set to 1. There is no self-connections. 10 | 11 | Parameters 12 | ---------- 13 | N : int 14 | Number of vertices (default = 10) 15 | 16 | Examples 17 | -------- 18 | >>> import matplotlib.pyplot as plt 19 | >>> G = graphs.FullConnected(N=20) 20 | >>> G.set_coordinates(kind='spring', seed=42) 21 | >>> fig, axes = plt.subplots(1, 2) 22 | >>> _ = axes[0].spy(G.W, markersize=5) 23 | >>> _ = G.plot(ax=axes[1]) 24 | 25 | """ 26 | 27 | def __init__(self, N=10, **kwargs): 28 | W = np.ones((N, N)) - np.identity(N) 29 | plotting = {"limits": np.array([-1, 1, -1, 1])} 30 | 31 | super().__init__(W, plotting=plotting, **kwargs) 32 | -------------------------------------------------------------------------------- /pygsp/graphs/logo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .graph import Graph # prevent circular import in Python < 3.5 6 | 7 | 8 | class Logo(Graph): 9 | r"""GSP logo. 
10 | 11 | Examples 12 | -------- 13 | >>> import matplotlib.pyplot as plt 14 | >>> G = graphs.Logo() 15 | >>> fig, axes = plt.subplots(1, 2) 16 | >>> _ = axes[0].spy(G.W, markersize=0.5) 17 | >>> _ = G.plot(ax=axes[1]) 18 | 19 | """ 20 | 21 | def __init__(self, **kwargs): 22 | data = utils.loadmat("pointclouds/logogsp") 23 | 24 | # Remove 1 because the index in python start at 0 and not at 1 25 | self.info = { 26 | "idx_g": data["idx_g"] - 1, 27 | "idx_s": data["idx_s"] - 1, 28 | "idx_p": data["idx_p"] - 1, 29 | } 30 | 31 | plotting = {"limits": np.array([0, 640, -400, 0])} 32 | 33 | super().__init__(data["W"], coords=data["coords"], plotting=plotting, **kwargs) 34 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "PyGSP: Graph Signal Processing in Python", 3 | "description": "The PyGSP facilitates a wide variety of operations on graphs, like computing their Fourier basis, filtering or interpolating signals, plotting graphs, signals, and filters.", 4 | "upload_type": "software", 5 | "license": "BSD-3-Clause", 6 | "access_right": "open", 7 | "creators": [ 8 | { 9 | "name": "Micha\u00ebl Defferrard", 10 | "affiliation": "EPFL", 11 | "orcid": "0000-0002-6028-9024" 12 | }, 13 | { 14 | "name": "Lionel Martin", 15 | "affiliation": "EPFL" 16 | }, 17 | { 18 | "name": "Rodrigo Pena", 19 | "affiliation": "EPFL" 20 | }, 21 | { 22 | "name": "Nathana\u00ebl Perraudin", 23 | "affiliation": "EPFL", 24 | "orcid": "0000-0001-8285-1308" 25 | } 26 | ], 27 | "related_identifiers": [ 28 | { 29 | "scheme": "url", 30 | "identifier": "https://github.com/epfl-lts2/pygsp", 31 | "relation": "isSupplementTo" 32 | }, 33 | { 34 | "scheme": "doi", 35 | "identifier": "10.5281/zenodo.1003157", 36 | "relation": "isPartOf" 37 | } 38 | ] 39 | } 40 | -------------------------------------------------------------------------------- /pygsp/graphs/star.py: 
-------------------------------------------------------------------------------- 1 | from . import Comet # prevent circular import in Python < 3.5 2 | 3 | 4 | class Star(Comet): 5 | r"""Star graph. 6 | 7 | A star with a central vertex and `N-1` branches. 8 | The central vertex has degree `N-1`, the others have degree 1. 9 | 10 | Parameters 11 | ---------- 12 | N : int 13 | Number of vertices. 14 | 15 | See Also 16 | -------- 17 | Comet : Generalization with a longer branch as a tail 18 | 19 | Examples 20 | -------- 21 | >>> import matplotlib.pyplot as plt 22 | >>> graph = graphs.Star(15) 23 | >>> graph 24 | Star(n_vertices=15, n_edges=14) 25 | >>> fig, axes = plt.subplots(1, 2) 26 | >>> _ = axes[0].spy(graph.W) 27 | >>> _ = graph.plot(ax=axes[1]) 28 | 29 | """ 30 | 31 | def __init__(self, N=10, **kwargs): 32 | plotting = dict(limits=[-1.1, 1.1, -1.1, 1.1]) 33 | plotting.update(kwargs.get("plotting", {})) 34 | super().__init__(N, N - 1, plotting=plotting, **kwargs) 35 | 36 | def _get_extra_repr(self): 37 | return dict() # Suppress Comet repr. 38 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/bunny.py: -------------------------------------------------------------------------------- 1 | from pygsp import utils 2 | 3 | from .nngraph import NNGraph # prevent circular import in Python < 3.5 4 | 5 | 6 | class Bunny(NNGraph): 7 | r"""Stanford bunny (NN-graph). 8 | 9 | References 10 | ---------- 11 | See :cite:`turk1994zippered`. 
12 | 13 | Examples 14 | -------- 15 | >>> import matplotlib.pyplot as plt 16 | >>> G = graphs.Bunny() 17 | >>> fig = plt.figure() 18 | >>> ax1 = fig.add_subplot(121) 19 | >>> ax2 = fig.add_subplot(122, projection='3d') 20 | >>> _ = ax1.spy(G.W, markersize=0.1) 21 | >>> _ = _ = G.plot(ax=ax2) 22 | 23 | """ 24 | 25 | def __init__(self, **kwargs): 26 | data = utils.loadmat("pointclouds/bunny") 27 | 28 | plotting = { 29 | "vertex_size": 10, 30 | "elevation": -90, 31 | "azimuth": 90, 32 | "distance": 8, 33 | } 34 | 35 | super().__init__( 36 | Xin=data["bunny"], 37 | epsilon=0.02, 38 | NNtype="radius", 39 | center=False, 40 | rescale=False, 41 | plotting=plotting, 42 | **kwargs, 43 | ) 44 | -------------------------------------------------------------------------------- /pygsp/graphs/airfoil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class Airfoil(Graph): 10 | r"""Airfoil graph. 
11 | 12 | Examples 13 | -------- 14 | >>> import matplotlib.pyplot as plt 15 | >>> G = graphs.Airfoil() 16 | >>> fig, axes = plt.subplots(1, 2) 17 | >>> _ = axes[0].spy(G.W, markersize=0.5) 18 | >>> _ = G.plot(edges=True, ax=axes[1]) 19 | 20 | """ 21 | 22 | def __init__(self, **kwargs): 23 | data = utils.loadmat("pointclouds/airfoil") 24 | coords = np.concatenate((data["x"], data["y"]), axis=1) 25 | 26 | i_inds = np.reshape(data["i_inds"] - 1, 12289) 27 | j_inds = np.reshape(data["j_inds"] - 1, 12289) 28 | A = sparse.coo_matrix((np.ones(12289), (i_inds, j_inds)), shape=(4253, 4253)) 29 | W = (A + A.T) / 2.0 30 | 31 | plotting = { 32 | "vertex_size": 30, 33 | "limits": np.array( 34 | [-1e-4, 1.01 * data["x"].max(), -1e-4, 1.01 * data["y"].max()] 35 | ), 36 | } 37 | 38 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 39 | -------------------------------------------------------------------------------- /examples/eigenvalue_concentration.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Concentration of the eigenvalues 3 | ================================ 4 | 5 | The eigenvalues of the graph Laplacian concentrates to the same value as the 6 | graph becomes full. 7 | """ 8 | 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | 12 | import pygsp as pg 13 | 14 | n_neighbors = [1, 2, 5, 8] 15 | fig, axes = plt.subplots(3, len(n_neighbors), figsize=(15, 8)) 16 | 17 | for k, ax in zip(n_neighbors, axes.T): 18 | graph = pg.graphs.Ring(17, k=k) 19 | graph.compute_fourier_basis() 20 | graph.plot(graph.U[:, 1], ax=ax[0]) 21 | ax[0].axis("equal") 22 | ax[1].spy(graph.W) 23 | ax[2].plot(graph.e, ".") 24 | ax[2].set_title(f"k={k}") 25 | # graph.set_coordinates('line1D') 26 | # graph.plot(graph.U[:, :4], ax=ax[3], title='') 27 | 28 | # Check that the DFT matrix is an eigenbasis of the Laplacian. 
29 | U = np.fft.fft(np.identity(graph.n_vertices)) 30 | LambdaM = (graph.L.todense().dot(U)) / (U + 1e-15) 31 | # Eigenvalues should be real. 32 | assert np.all(np.abs(np.imag(LambdaM)) < 1e-10) 33 | LambdaM = np.real(LambdaM) 34 | # Check that the eigenvectors are really eigenvectors of the laplacian. 35 | Lambda = np.mean(LambdaM, axis=0) 36 | assert np.all(np.abs(LambdaM - Lambda) < 1e-10) 37 | 38 | fig.tight_layout() 39 | -------------------------------------------------------------------------------- /examples/heat_diffusion.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Heat diffusion 3 | ============== 4 | 5 | Solve the heat equation by filtering the initial conditions with the heat 6 | kernel :class:`pygsp.filters.Heat`. 7 | """ 8 | 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | 12 | import pygsp as pg 13 | 14 | n_side = 13 15 | G = pg.graphs.Grid2d(n_side) 16 | G.compute_fourier_basis() 17 | 18 | sources = [ 19 | (n_side // 4 * n_side) + (n_side // 4), 20 | (n_side * 3 // 4 * n_side) + (n_side * 3 // 4), 21 | ] 22 | x = np.zeros(G.n_vertices) 23 | x[sources] = 5 24 | 25 | times = [0, 5, 10, 20] 26 | 27 | fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) 28 | for i, t in enumerate(times): 29 | g = pg.filters.Heat(G, scale=t) 30 | title = r"$\hat{{f}}({0}) = g_{{1,{0}}} \odot \hat{{f}}(0)$".format(t) 31 | g.plot(alpha=1, ax=axes[0, i], title=title) 32 | axes[0, i].set_xlabel(r"$\lambda$") 33 | # axes[0, i].set_ylabel(r'$g(\lambda)$') 34 | if i > 0: 35 | axes[0, i].set_ylabel("") 36 | y = g.filter(x) 37 | (line,) = axes[0, i].plot(G.e, G.gft(y)) 38 | labels = [rf"$\hat{{f}}({t})$", rf"$g_{{1,{t}}}$"] 39 | axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc="lower right") 40 | G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=rf"$f({t})$") 41 | axes[1, i].set_aspect("equal", "box") 42 | axes[1, i].set_axis_off() 43 | 44 | fig.tight_layout() 45 | 
-------------------------------------------------------------------------------- /examples/wave_propagation.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Wave propagation 3 | ================ 4 | 5 | Solve the wave equation by filtering the initial conditions with the wave 6 | kernel :class:`pygsp.filters.Wave`. 7 | """ 8 | 9 | from os import path 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | 14 | import pygsp as pg 15 | 16 | n_side = 13 17 | G = pg.graphs.Grid2d(n_side) 18 | G.compute_fourier_basis() 19 | 20 | sources = [ 21 | (n_side // 4 * n_side) + (n_side // 4), 22 | (n_side * 3 // 4 * n_side) + (n_side * 3 // 4), 23 | ] 24 | x = np.zeros(G.n_vertices) 25 | x[sources] = 5 26 | 27 | times = [0, 5, 10, 20] 28 | 29 | fig, axes = plt.subplots(2, len(times), figsize=(12, 5)) 30 | for i, t in enumerate(times): 31 | g = pg.filters.Wave(G, time=t, speed=1) 32 | title = r"$\hat{{f}}({0}) = g_{{1,{0}}} \odot \hat{{f}}(0)$".format(t) 33 | g.plot(alpha=1, ax=axes[0, i], title=title) 34 | axes[0, i].set_xlabel(r"$\lambda$") 35 | # axes[0, i].set_ylabel(r'$g(\lambda)$') 36 | if i > 0: 37 | axes[0, i].set_ylabel("") 38 | y = g.filter(x) 39 | (line,) = axes[0, i].plot(G.e, G.gft(y)) 40 | labels = [rf"$\hat{{f}}({t})$", rf"$g_{{1,{t}}}$"] 41 | axes[0, i].legend([line, axes[0, i].lines[-3]], labels, loc="lower right") 42 | G.plot(y, edges=False, highlight=sources, ax=axes[1, i], title=rf"$f({t})$") 43 | axes[1, i].set_aspect("equal", "box") 44 | axes[1, i].set_axis_off() 45 | 46 | fig.tight_layout() 47 | -------------------------------------------------------------------------------- /pygsp/filters/halfcosine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class HalfCosine(Filter): 7 | r"""Design an half cosine filter bank (tight frame). 
8 | 9 | Parameters 10 | ---------- 11 | G : graph 12 | Nf : int 13 | Number of filters from 0 to lmax (default = 6) 14 | 15 | Examples 16 | -------- 17 | 18 | Filter bank's representation in Fourier and time (ring graph) domains. 19 | 20 | >>> import matplotlib.pyplot as plt 21 | >>> G = graphs.Ring(N=20) 22 | >>> G.estimate_lmax() 23 | >>> G.set_coordinates('line1D') 24 | >>> g = filters.HalfCosine(G) 25 | >>> s = g.localize(G.N // 2) 26 | >>> fig, axes = plt.subplots(1, 2) 27 | >>> _ = g.plot(ax=axes[0]) 28 | >>> _ = G.plot(s, ax=axes[1]) 29 | 30 | """ 31 | 32 | def __init__(self, G, Nf=6): 33 | if Nf <= 2: 34 | raise ValueError("The number of filters must be greater than 2.") 35 | 36 | dila_fact = G.lmax * 3 / (Nf - 2) 37 | 38 | def kernel(x): 39 | y = np.cos(2 * np.pi * (x / dila_fact - 0.5)) 40 | y = np.multiply((0.5 + 0.5 * y), (x >= 0)) 41 | return np.multiply(y, (x <= dila_fact)) 42 | 43 | kernels = [] 44 | 45 | for i in range(Nf): 46 | 47 | def kernel_centered(x, i=i): 48 | return kernel(x - dila_fact / 3 * (i - 2)) 49 | 50 | kernels.append(kernel_centered) 51 | 52 | super().__init__(G, kernels) 53 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, EPFL LTS2 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 
13 | 14 | * Neither the name of PyGSP nor the names of its contributors may be used 15 | to endorse or promote products derived from this software without specific 16 | prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NB = $(sort $(wildcard examples/*.ipynb)) 2 | .PHONY: help clean install lint test doc dist release 3 | 4 | help: 5 | @echo "clean remove non-source files and clean source files" 6 | @echo "install install package in development mode with all dependencies" 7 | @echo "lint check style" 8 | @echo "test run tests and check coverage" 9 | @echo "doc generate HTML documentation and check links" 10 | @echo "dist package (source & wheel)" 11 | @echo "release package and upload to PyPI" 12 | 13 | clean: 14 | git clean -Xdf 15 | jupyter nbconvert --inplace --ClearOutputPreprocessor.enabled=True $(NB) 16 | 17 | install: 18 | uv sync --all-extras 19 | 20 | 21 | lint: 22 | uv run flake8 --doctests --exclude=doc,.venv,build --max-line-length=88 --extend-ignore=E203 23 | 24 | # Matplotlib doesn't print to screen. 
Also faster. 25 | export MPLBACKEND = agg 26 | # Virtual framebuffer nonetheless needed for the pyqtgraph backend. 27 | export DISPLAY = :99 28 | 29 | test: 30 | Xvfb $$DISPLAY -screen 0 800x600x24 & 31 | uv run coverage run --branch --source pygsp -m pytest 32 | uv run coverage report 33 | uv run coverage html 34 | killall Xvfb 35 | 36 | doc: 37 | uv run sphinx-build -b html -d doc/_build/doctrees doc doc/_build/html 38 | uv run sphinx-build -b linkcheck -d doc/_build/doctrees doc doc/_build/linkcheck 39 | 40 | dist: clean 41 | uv sync --all-extras 42 | uv build 43 | ls -lh dist/* 44 | uv run twine check dist/* 45 | @echo "The built packages are valid and can be uploaded successfully" 46 | 47 | release: dist 48 | uv publish 49 | -------------------------------------------------------------------------------- /examples/kernel_localization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Kernel localization 3 | =================== 4 | 5 | In classical signal processing, a filter can be translated in the vertex 6 | domain. We cannot do that on graphs. Instead, we can 7 | :meth:`~pygsp.filters.Filter.localize` a filter kernel. Note how on classic 8 | structures (like the ring), the localized kernel is the same everywhere, while 9 | it changes when localized on irregular graphs. 
10 | """ 11 | 12 | import numpy as np 13 | from matplotlib import pyplot as plt 14 | 15 | import pygsp as pg 16 | 17 | fig, axes = plt.subplots(2, 4, figsize=(10, 4)) 18 | 19 | graphs = [ 20 | pg.graphs.Ring(40), 21 | pg.graphs.Sensor(64, seed=42), 22 | ] 23 | 24 | locations = [0, 10, 20] 25 | 26 | for graph, axs in zip(graphs, axes): 27 | graph.compute_fourier_basis() 28 | g = pg.filters.Heat(graph) 29 | g.plot(ax=axs[0], title="heat kernel") 30 | axs[0].set_xlabel(r"eigenvalues $\lambda$") 31 | axs[0].set_ylabel( 32 | r"$g(\lambda) = \exp \left( \frac{{-{}\lambda}}{{\lambda_{{max}}}} \right)$".format( 33 | g.scale[0] 34 | ) 35 | ) 36 | maximum = 0 37 | for loc in locations: 38 | x = g.localize(loc) 39 | maximum = np.maximum(maximum, x.max()) 40 | for loc, ax in zip(locations, axs[1:]): 41 | graph.plot( 42 | g.localize(loc), 43 | limits=[0, maximum], 44 | highlight=loc, 45 | ax=ax, 46 | title=rf"$g(L) \delta_{{{loc}}}$", 47 | ) 48 | ax.set_axis_off() 49 | 50 | fig.tight_layout() 51 | -------------------------------------------------------------------------------- /examples/fourier_transform.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fourier transform 3 | ================= 4 | 5 | The graph Fourier transform :meth:`pygsp.graphs.Graph.gft` transforms a 6 | signal from the vertex domain to the spectral domain. The smoother the signal 7 | (see :meth:`pygsp.graphs.Graph.dirichlet_energy`), the lower in the frequencies 8 | its energy is concentrated. 
9 | """ 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | 14 | import pygsp as pg 15 | 16 | G = pg.graphs.Sensor(seed=42) 17 | G.compute_fourier_basis() 18 | 19 | scales = [10, 3, 0] 20 | limit = 0.44 21 | 22 | fig, axes = plt.subplots(2, len(scales), figsize=(12, 4)) 23 | fig.subplots_adjust(hspace=0.5) 24 | 25 | x0 = np.random.default_rng(1).normal(size=G.N) 26 | for i, scale in enumerate(scales): 27 | g = pg.filters.Heat(G, scale) 28 | x = g.filter(x0).squeeze() 29 | x /= np.linalg.norm(x) 30 | x_hat = G.gft(x).squeeze() 31 | 32 | assert np.all((-limit < x) & (x < limit)) 33 | G.plot(x, limits=[-limit, limit], ax=axes[0, i]) 34 | axes[0, i].set_axis_off() 35 | axes[0, i].set_title(f"$x^T L x = {G.dirichlet_energy(x):.2f}$") 36 | 37 | axes[1, i].plot(G.e, np.abs(x_hat), ".-") 38 | axes[1, i].set_xticks(range(0, 16, 4)) 39 | axes[1, i].set_xlabel(r"graph frequency $\lambda$") 40 | axes[1, i].set_ylim(-0.05, 0.95) 41 | 42 | axes[1, 0].set_ylabel(r"frequency content $\hat{x}(\lambda)$") 43 | 44 | # axes[0, 0].set_title(r'$x$: signal in the vertex domain') 45 | # axes[1, 0].set_title(r'$\hat{x}$: signal in the spectral domain') 46 | 47 | fig.tight_layout() 48 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/grid2dimgpatches.py: -------------------------------------------------------------------------------- 1 | # prevent circular import in Python < 3.5 2 | from ..graph import Graph 3 | from ..grid2d import Grid2d 4 | from .imgpatches import ImgPatches 5 | 6 | 7 | class Grid2dImgPatches(Graph): 8 | r"""Union of a patch graph with a 2D grid graph. 9 | 10 | Parameters 11 | ---------- 12 | img : array 13 | Input image. 14 | aggregate: callable, optional 15 | Function to aggregate the weights ``Wp`` of the patch graph and the 16 | ``Wg`` of the grid graph. Default is ``lambda Wp, Wg: Wp + Wg``. 17 | kwargs : dict 18 | Parameters passed to :class:`ImgPatches`. 
19 | 20 | See Also 21 | -------- 22 | ImgPatches 23 | Grid2d 24 | 25 | Examples 26 | -------- 27 | >>> import matplotlib.pyplot as plt 28 | >>> from skimage import data, img_as_float 29 | >>> img = img_as_float(data.camera()[::64, ::64]) 30 | >>> G = graphs.Grid2dImgPatches(img) 31 | >>> fig, axes = plt.subplots(1, 2) 32 | >>> _ = axes[0].spy(G.W, markersize=2) 33 | >>> _ = G.plot(ax=axes[1]) 34 | 35 | """ 36 | 37 | def __init__(self, img, aggregate=lambda Wp, Wg: Wp + Wg, **kwargs): 38 | self.Gg = Grid2d(img.shape[0], img.shape[1]) 39 | self.Gp = ImgPatches(img, **kwargs) 40 | 41 | W = aggregate(self.Gp.W, self.Gg.W) 42 | super().__init__(W, coords=self.Gg.coords, plotting=self.Gg.plotting) 43 | 44 | def _get_extra_repr(self): 45 | attrs = self.Gg._get_extra_repr() 46 | attrs.update(self.Gp._get_extra_repr()) 47 | return attrs 48 | -------------------------------------------------------------------------------- /examples/fourier_basis.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fourier basis 3 | ============= 4 | 5 | The eigenvectors of the graph Laplacian form the Fourier basis. 6 | The eigenvalues are a measure of variation of their corresponding eigenvector. 7 | The lower the eigenvalue, the smoother the eigenvector. They are hence a 8 | measure of "frequency". 9 | 10 | In classical signal processing, Fourier modes are completely delocalized, like 11 | on the grid graph. For general graphs however, Fourier modes might be 12 | localized. See :attr:`pygsp.graphs.Graph.coherence`. 
13 | """ 14 | 15 | import numpy as np 16 | from matplotlib import pyplot as plt 17 | 18 | import pygsp as pg 19 | 20 | n_eigenvectors = 7 21 | 22 | fig, axes = plt.subplots(2, 7, figsize=(15, 4)) 23 | 24 | 25 | def plot_eigenvectors(G, axes): 26 | G.compute_fourier_basis(n_eigenvectors) 27 | limits = [f(G.U) for f in (np.min, np.max)] 28 | for i, ax in enumerate(axes): 29 | G.plot(G.U[:, i], limits=limits, colorbar=False, vertex_size=50, ax=ax) 30 | energy = abs(G.dirichlet_energy(G.U[:, i])) 31 | ax.set_title(r"$u_{0}^\top L u_{0} = {1:.2f}$".format(i + 1, energy)) 32 | ax.set_axis_off() 33 | 34 | 35 | G = pg.graphs.Grid2d(10, 10) 36 | plot_eigenvectors(G, axes[0]) 37 | fig.subplots_adjust(hspace=0.5, right=0.8) 38 | cax = fig.add_axes([0.82, 0.60, 0.01, 0.26]) 39 | fig.colorbar(axes[0, -1].collections[1], cax=cax, ticks=[-0.2, 0, 0.2]) 40 | 41 | G = pg.graphs.Sensor(seed=42) 42 | plot_eigenvectors(G, axes[1]) 43 | fig.subplots_adjust(hspace=0.5, right=0.8) 44 | cax = fig.add_axes([0.82, 0.16, 0.01, 0.26]) 45 | _ = fig.colorbar(axes[1, -1].collections[1], cax=cax, ticks=[-0.4, 0, 0.4]) 46 | -------------------------------------------------------------------------------- /pygsp/graphs/minnesota.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class Minnesota(Graph): 10 | r"""Minnesota road network (from MatlabBGL). 11 | 12 | Parameters 13 | ---------- 14 | connected : bool 15 | If True, the adjacency matrix is adjusted so that all edge weights are 16 | equal to 1, and the graph is connected. Set to False to get the 17 | original disconnected graph. 18 | 19 | References 20 | ---------- 21 | See :cite:`gleich`. 
22 | 23 | Examples 24 | -------- 25 | >>> import matplotlib.pyplot as plt 26 | >>> G = graphs.Minnesota() 27 | >>> fig, axes = plt.subplots(1, 2) 28 | >>> _ = axes[0].spy(G.W, markersize=0.5) 29 | >>> _ = G.plot(ax=axes[1]) 30 | 31 | """ 32 | 33 | def __init__(self, connected=True, **kwargs): 34 | self.connected = connected 35 | 36 | data = utils.loadmat("pointclouds/minnesota") 37 | self.labels = data["labels"] 38 | A = data["A"] 39 | 40 | plotting = {"limits": np.array([-98, -89, 43, 50]), "vertex_size": 40} 41 | 42 | if connected: 43 | # Missing edges needed to connect the graph. 44 | A = sparse.lil_matrix(A) 45 | A[348, 354] = 1 46 | A[354, 348] = 1 47 | A = sparse.csc_matrix(A) 48 | 49 | # Binarize: 8 entries are equal to 2 instead of 1. 50 | A = (A > 0).astype(bool) 51 | 52 | super().__init__(A, coords=data["xy"], plotting=plotting, **kwargs) 53 | 54 | def _get_extra_repr(self): 55 | return dict(connected=self.connected) 56 | -------------------------------------------------------------------------------- /pygsp/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`pygsp` package is mainly organized around the following two modules: 3 | 4 | * :mod:`.graphs` to create and manipulate various kinds of graphs, 5 | * :mod:`.filters` to create and manipulate various graph filters. 6 | 7 | Moreover, the following modules provide additional functionality: 8 | 9 | * :mod:`.plotting` to plot, 10 | * :mod:`.reduction` to reduce a graph while keeping its structure, 11 | * :mod:`.features` to compute features on graphs, 12 | * :mod:`.learning` to solve learning problems, 13 | * :mod:`.optimization` to help solving convex optimization problems, 14 | * :mod:`.utils` for various utilities. 15 | 16 | """ 17 | 18 | from . import features # noqa: F401 19 | from . import filters # noqa: F401 20 | from . import graphs # noqa: F401 21 | from . import learning # noqa: F401 22 | from . 
import optimization # noqa: F401 23 | from . import plotting # noqa: F401 24 | from . import reduction # noqa: F401 25 | from . import utils # noqa: F401 26 | 27 | # Users only call the plot methods from the objects. 28 | # It's thus more convenient for them to have the doc there. 29 | # But it's more convenient for developers to have the doc alongside the code. 30 | try: 31 | filters.Filter.plot.__doc__ = plotting._plot_filter.__doc__ 32 | graphs.Graph.plot.__doc__ = plotting._plot_graph.__doc__ 33 | graphs.Graph.plot_spectrogram.__doc__ = plotting._plot_spectrogram.__doc__ 34 | except AttributeError: 35 | # For Python 2.7. 36 | filters.Filter.plot.__func__.__doc__ = plotting._plot_filter.__doc__ 37 | graphs.Graph.plot.__func__.__doc__ = plotting._plot_graph.__doc__ 38 | graphs.Graph.plot_spectrogram.__func__.__doc__ = plotting._plot_spectrogram.__doc__ 39 | 40 | __version__ = "0.6.1" 41 | __release_date__ = "2025-09-11" 42 | -------------------------------------------------------------------------------- /pygsp/graphs/linegraph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | logger = utils.build_logger(__name__) 9 | 10 | 11 | class LineGraph(Graph): 12 | r"""Build the line graph of a graph. 13 | 14 | Each vertex of the line graph represents an edge in the original graph. Two 15 | vertices are connected if the edges they represent share a vertex in the 16 | original graph. 
17 | 18 | Parameters 19 | ---------- 20 | graph : :class:`Graph` 21 | 22 | Examples 23 | -------- 24 | >>> import matplotlib.pyplot as plt 25 | >>> graph = graphs.Sensor(5, k=2, seed=10) 26 | >>> line_graph = graphs.LineGraph(graph) 27 | >>> fig, ax = plt.subplots() 28 | >>> fig, ax = graph.plot('blue', edge_color='blue', indices=True, ax=ax) 29 | >>> fig, ax = line_graph.plot('red', edge_color='red', indices=True, ax=ax) 30 | >>> _ = ax.set_title('graph and its line graph') 31 | 32 | """ 33 | 34 | def __init__(self, graph, **kwargs): 35 | if graph.is_weighted(): 36 | logger.warning( 37 | "Your graph is weighted, and is considered " 38 | "unweighted to build a binary line graph." 39 | ) 40 | 41 | graph.compute_differential_operator() 42 | # incidence = np.abs(graph.D) # weighted? 43 | incidence = graph.D != 0 44 | 45 | adjacency = incidence.T.dot(incidence).astype(int) 46 | adjacency -= sparse.identity(graph.n_edges, dtype=int) 47 | 48 | try: 49 | coords = incidence.T.dot(graph.coords) / 2 50 | except AttributeError: 51 | coords = None 52 | 53 | super().__init__(adjacency, coords=coords, plotting=graph.plotting, **kwargs) 54 | -------------------------------------------------------------------------------- /pygsp/filters/itersine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Itersine(Filter): 7 | r"""Design an itersine filter bank (tight frame). 8 | 9 | Create an itersine half overlap filter bank of Nf filters. 10 | Going from 0 to lambda_max. 11 | 12 | Parameters 13 | ---------- 14 | G : graph 15 | Nf : int (optional) 16 | Number of filters from 0 to lmax. (default = 6) 17 | overlap : int (optional) 18 | (default = 2) 19 | 20 | Examples 21 | -------- 22 | 23 | Filter bank's representation in Fourier and time (ring graph) domains. 
24 | 25 | >>> import matplotlib.pyplot as plt 26 | >>> G = graphs.Ring(N=20) 27 | >>> G.estimate_lmax() 28 | >>> G.set_coordinates('line1D') 29 | >>> g = filters.Itersine(G) 30 | >>> s = g.localize(G.N // 2) 31 | >>> fig, axes = plt.subplots(1, 2) 32 | >>> _ = g.plot(ax=axes[0]) 33 | >>> _ = G.plot(s, ax=axes[1]) 34 | 35 | """ 36 | 37 | def __init__(self, G, Nf=6, overlap=2): 38 | self.overlap = overlap 39 | self.mu = np.linspace(0, G.lmax, num=Nf) 40 | 41 | scales = G.lmax / (Nf - overlap + 1) * overlap 42 | 43 | def kernel(x): 44 | y = np.cos(x * np.pi) ** 2 45 | y = np.sin(0.5 * np.pi * y) 46 | return y * ((x >= -0.5) * (x <= 0.5)) 47 | 48 | kernels = [] 49 | for i in range(1, Nf + 1): 50 | 51 | def kernel_centered(x, i=i): 52 | y = kernel(x / scales - (i - overlap / 2) / overlap) 53 | return y * np.sqrt(2 / overlap) 54 | 55 | kernels.append(kernel_centered) 56 | 57 | super().__init__(G, kernels) 58 | 59 | def _get_extra_repr(self): 60 | return dict(overlap=f"{self.overlap:.2f}") 61 | -------------------------------------------------------------------------------- /pygsp/graphs/erdosrenyi.py: -------------------------------------------------------------------------------- 1 | # prevent circular import in Python < 3.5 2 | from .stochasticblockmodel import StochasticBlockModel 3 | 4 | 5 | class ErdosRenyi(StochasticBlockModel): 6 | r"""Erdos Renyi graph. 7 | 8 | The Erdos Renyi graph is constructed by randomly connecting nodes. Each 9 | edge is included in the graph with probability p, independently from any 10 | other edge. All edge weights are equal to 1. 11 | 12 | Parameters 13 | ---------- 14 | N : int 15 | Number of nodes (default is 100). 16 | p : float 17 | Probability to connect a node with another one. 18 | directed : bool 19 | Allow directed edges if True (default is False). 20 | self_loops : bool 21 | Allow self loops if True (default is False). 22 | connected : bool 23 | Force the graph to be connected (default is False). 
24 | n_try : int 25 | Maximum number of trials to get a connected graph (default is 10). 26 | seed : int 27 | Seed for the random number generator (for reproducible graphs). 28 | 29 | Examples 30 | -------- 31 | >>> import matplotlib.pyplot as plt 32 | >>> G = graphs.ErdosRenyi(N=64, seed=42) 33 | >>> G.set_coordinates(kind='spring', seed=42) 34 | >>> fig, axes = plt.subplots(1, 2) 35 | >>> _ = axes[0].spy(G.W, markersize=2) 36 | >>> _ = G.plot(ax=axes[1]) 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | N=100, 43 | p=0.1, 44 | directed=False, 45 | self_loops=False, 46 | connected=False, 47 | n_try=10, 48 | seed=None, 49 | **kwargs, 50 | ): 51 | super().__init__( 52 | N=N, 53 | k=1, 54 | p=p, 55 | directed=directed, 56 | self_loops=self_loops, 57 | connected=connected, 58 | n_try=n_try, 59 | seed=seed, 60 | **kwargs, 61 | ) 62 | -------------------------------------------------------------------------------- /pygsp/graphs/davidsensornet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .graph import Graph # prevent circular import in Python < 3.5 6 | 7 | 8 | class DavidSensorNet(Graph): 9 | r"""Sensor network. 10 | 11 | Parameters 12 | ---------- 13 | N : int 14 | Number of vertices (default = 64). Values of 64 and 500 yield 15 | pre-computed and saved graphs. Other values yield randomly generated 16 | graphs. 17 | seed : int 18 | Seed for the random number generator (for reproducible graphs). 
19 | 20 | Examples 21 | -------- 22 | >>> import matplotlib.pyplot as plt 23 | >>> G = graphs.DavidSensorNet() 24 | >>> fig, axes = plt.subplots(1, 2) 25 | >>> _ = axes[0].spy(G.W, markersize=2) 26 | >>> _ = G.plot(ax=axes[1]) 27 | 28 | """ 29 | 30 | def __init__(self, N=64, seed=None, **kwargs): 31 | self.seed = seed 32 | 33 | if N == 64: 34 | data = utils.loadmat("pointclouds/david64") 35 | assert data["N"][0, 0] == N 36 | W = data["W"] 37 | coords = data["coords"] 38 | 39 | elif N == 500: 40 | data = utils.loadmat("pointclouds/david500") 41 | assert data["N"][0, 0] == N 42 | W = data["W"] 43 | coords = data["coords"] 44 | 45 | else: 46 | coords = np.random.default_rng(seed).uniform(size=(N, 2)) 47 | 48 | target_dist_cutoff = -0.125 * N / 436.075 + 0.2183 49 | T = 0.6 50 | s = np.sqrt(-(target_dist_cutoff**2) / (2 * np.log(T))) 51 | d = utils.distanz(coords.T) 52 | W = np.exp(-np.power(d, 2) / (2.0 * s**2)) 53 | W[W < T] = 0 54 | W[np.diag_indices(N)] = 0 55 | 56 | plotting = {"limits": [0, 1, 0, 1]} 57 | 58 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 59 | 60 | def _get_extra_repr(self): 61 | return dict(seed=self.seed) 62 | -------------------------------------------------------------------------------- /examples/random_walk.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Random walks 3 | ============ 4 | 5 | Probability of a random walker to be on any given vertex after a given number 6 | of steps starting from a given distribution. 
7 | """ 8 | 9 | # sphinx_gallery_thumbnail_number = 2 10 | 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | from scipy import sparse 14 | 15 | import pygsp as pg 16 | 17 | N = 7 18 | steps = [0, 1, 2, 3] 19 | 20 | graph = pg.graphs.Grid2d(N) 21 | delta = np.zeros(graph.N) 22 | delta[N // 2 * N + N // 2] = 1 23 | 24 | probability = sparse.diags(graph.dw ** (-1)).dot(graph.W) 25 | 26 | fig, axes = plt.subplots(1, len(steps), figsize=(12, 3)) 27 | for step, ax in zip(steps, axes): 28 | state = (probability**step).__rmatmul__(delta) ## = delta @ probability**step 29 | graph.plot(state, ax=ax, title=rf"$\delta P^{step}$") 30 | ax.set_axis_off() 31 | 32 | fig.tight_layout() 33 | 34 | ############################################################################### 35 | # Stationary distribution. 36 | 37 | graphs = [ 38 | pg.graphs.Ring(10), 39 | pg.graphs.Grid2d(5), 40 | pg.graphs.Comet(8, 4), 41 | pg.graphs.BarabasiAlbert(20, seed=42), 42 | ] 43 | 44 | fig, axes = plt.subplots(1, len(graphs), figsize=(12, 3)) 45 | 46 | for graph, ax in zip(graphs, axes): 47 | if not hasattr(graph, "coords"): 48 | graph.set_coordinates(seed=10) 49 | 50 | P = sparse.diags(graph.dw ** (-1)).dot(graph.W) 51 | 52 | # e, u = np.linalg.eig(P.T.toarray()) 53 | # np.testing.assert_allclose(np.linalg.inv(u.T) @ np.diag(e) @ u.T, 54 | # P.toarray(), atol=1e-10) 55 | # np.testing.assert_allclose(np.abs(e[0]), 1) 56 | # stationary = np.abs(u.T[0]) 57 | 58 | e, u = sparse.linalg.eigs(P.T, k=1, which="LR") 59 | np.testing.assert_allclose(e, 1) 60 | stationary = np.abs(u).squeeze() 61 | assert np.all(stationary < 0.71) 62 | 63 | colorbar = False if type(graph) is pg.graphs.Ring else True 64 | graph.plot(stationary, colorbar=colorbar, ax=ax, title="$xP = x$") 65 | ax.set_axis_off() 66 | 67 | fig.tight_layout() 68 | -------------------------------------------------------------------------------- /pygsp/filters/simoncelli.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Simoncelli(Filter): 7 | r"""Design 2 filters with the Simoncelli construction (tight frame). 8 | 9 | This function creates a Parseval filter bank of 2 filters. 10 | The low-pass filter is defined by the function 11 | 12 | .. math:: f_{l}=\begin{cases} 1 & \mbox{if }x\leq a\\ 13 | \cos\left(\frac{\pi}{2}\frac{\log\left(\frac{x}{2}\right)}{\log(2)}\right) & \mbox{if }a2a \end{cases} 15 | 16 | The high pass filter is adapted to obtain a tight frame. 17 | 18 | Parameters 19 | ---------- 20 | G : graph 21 | a : float 22 | See above equation for this parameter. 23 | The spectrum is scaled between 0 and 2 (default = 2/3). 24 | 25 | Examples 26 | -------- 27 | 28 | Filter bank's representation in Fourier and time (ring graph) domains. 29 | 30 | >>> import matplotlib.pyplot as plt 31 | >>> G = graphs.Ring(N=20) 32 | >>> G.estimate_lmax() 33 | >>> G.set_coordinates('line1D') 34 | >>> g = filters.Simoncelli(G) 35 | >>> s = g.localize(G.N // 2) 36 | >>> fig, axes = plt.subplots(1, 2) 37 | >>> _ = g.plot(ax=axes[0]) 38 | >>> _ = G.plot(s, ax=axes[1]) 39 | 40 | """ 41 | 42 | def __init__(self, G, a=2 / 3): 43 | self.a = a 44 | 45 | def kernel(x, a): 46 | y = np.empty(np.shape(x)) 47 | l1 = a 48 | l2 = 2 * a 49 | 50 | r1ind = (x >= 0) * (x < l1) 51 | r2ind = (x >= l1) * (x < l2) 52 | r3ind = x >= l2 53 | 54 | y[r1ind] = 1 55 | y[r2ind] = np.cos(np.pi / 2 * np.log(x[r2ind] / a) / np.log(2)) 56 | y[r3ind] = 0 57 | 58 | return y 59 | 60 | simoncelli = Filter(G, lambda x: kernel(x * 2 / G.lmax, a)) 61 | complement = simoncelli.complement(frame_bound=1) 62 | kernels = simoncelli._kernels + complement._kernels 63 | 64 | super().__init__(G, kernels) 65 | 66 | def _get_extra_repr(self): 67 | return dict(a=f"{self.a:.2f}") 68 | 
-------------------------------------------------------------------------------- /pygsp/filters/papadakis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Papadakis(Filter): 7 | r"""Design 2 filters with the Papadakis construction (tight frame). 8 | 9 | This function creates a Parseval filter bank of 2 filters. 10 | The low-pass filter is defined by the function 11 | 12 | .. math:: f_{l}=\begin{cases} 1 & \mbox{if }x\leq a\\ 13 | \sqrt{1-\frac{\sin\left(\frac{3\pi}{2a}x\right)}{2}} & \mbox{if }a\frac{5a}{3} \end{cases} 15 | 16 | The high pass filter is adapted to obtain a tight frame. 17 | 18 | Parameters 19 | ---------- 20 | G : graph 21 | a : float 22 | See above equation for this parameter. 23 | The spectrum is scaled between 0 and 2 (default = 3/4). 24 | 25 | Examples 26 | -------- 27 | 28 | Filter bank's representation in Fourier and time (ring graph) domains. 
29 | 30 | >>> import matplotlib.pyplot as plt 31 | >>> G = graphs.Ring(N=20) 32 | >>> G.estimate_lmax() 33 | >>> G.set_coordinates('line1D') 34 | >>> g = filters.Papadakis(G) 35 | >>> s = g.localize(G.N // 2) 36 | >>> fig, axes = plt.subplots(1, 2) 37 | >>> _ = g.plot(ax=axes[0]) 38 | >>> _ = G.plot(s, ax=axes[1]) 39 | 40 | """ 41 | 42 | def __init__(self, G, a=0.75): 43 | self.a = a 44 | 45 | def kernel(x, a): 46 | y = np.empty(np.shape(x)) 47 | l1 = a 48 | l2 = a * 5 / 3 49 | 50 | r1ind = (x >= 0) * (x < l1) 51 | r2ind = (x >= l1) * (x < l2) 52 | r3ind = x >= l2 53 | 54 | y[r1ind] = 1 55 | y[r2ind] = np.sqrt((1 - np.sin(3 * np.pi / (2 * a) * x[r2ind])) / 2) 56 | y[r3ind] = 0 57 | 58 | return y 59 | 60 | papadakis = Filter(G, lambda x: kernel(x * 2 / G.lmax, a)) 61 | complement = papadakis.complement(frame_bound=1) 62 | kernels = papadakis._kernels + complement._kernels 63 | 64 | super().__init__(G, kernels) 65 | 66 | def _get_extra_repr(self): 67 | return dict(a=f"{self.a:.2f}") 68 | -------------------------------------------------------------------------------- /pygsp/filters/regular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Regular(Filter): 7 | r"""Design 2 filters with the regular construction (tight frame). 8 | 9 | This function creates a Parseval filter bank of 2 filters. 10 | The low-pass filter is defined by a function :math:`f_l(x)` 11 | between :math:`0` and :math:`2`. For :math:`d = 0`. 12 | 13 | .. math:: f_{l}= \sin\left( \frac{\pi}{4} x \right) 14 | 15 | For :math:`d = 1` 16 | 17 | .. math:: f_{l}= \sin\left( \frac{\pi}{4} \left( 1+ \sin\left(\frac{\pi}{2}(x-1)\right) \right) \right) 18 | 19 | For :math:`d = 2` 20 | 21 | .. 
math:: f_{l}= \sin\left( \frac{\pi}{4} \left( 1+ \sin\left(\frac{\pi}{2} \sin\left(\frac{\pi}{2}(x-1)\right)\right) \right) \right) 22 | 23 | And so forth for other degrees :math:`d`. 24 | 25 | The high pass filter is adapted to obtain a tight frame. 26 | 27 | Parameters 28 | ---------- 29 | G : graph 30 | degree : float 31 | Degree (default = 3). See above equations. 32 | 33 | Examples 34 | -------- 35 | 36 | Filter bank's representation in Fourier and time (ring graph) domains. 37 | 38 | >>> import matplotlib.pyplot as plt 39 | >>> G = graphs.Ring(N=20) 40 | >>> G.estimate_lmax() 41 | >>> G.set_coordinates('line1D') 42 | >>> g = filters.Regular(G) 43 | >>> s = g.localize(G.N // 2) 44 | >>> fig, axes = plt.subplots(1, 2) 45 | >>> _ = g.plot(ax=axes[0]) 46 | >>> _ = G.plot(s, ax=axes[1]) 47 | 48 | """ 49 | 50 | def __init__(self, G, degree=3): 51 | self.degree = degree 52 | 53 | def kernel(x, degree): 54 | if degree == 0: 55 | return np.sin(np.pi / 4 * x) 56 | else: 57 | output = np.sin(np.pi * (x - 1) / 2) 58 | for _ in range(2, degree): 59 | output = np.sin(np.pi * output / 2) 60 | return np.sin(np.pi / 4 * (1 + output)) 61 | 62 | regular = Filter(G, lambda x: kernel(x * 2 / G.lmax, degree)) 63 | complement = regular.complement(frame_bound=1) 64 | kernels = regular._kernels + complement._kernels 65 | 66 | super().__init__(G, kernels) 67 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/sphere.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .nngraph import NNGraph # prevent circular import in Python < 3.5 4 | 5 | 6 | class Sphere(NNGraph): 7 | r"""Spherical-shaped graph (NN-graph). 
8 | 9 | Parameters 10 | ---------- 11 | radius : float 12 | Radius of the sphere (default = 1) 13 | nb_pts : int 14 | Number of vertices (default = 300) 15 | nb_dim : int 16 | Dimension (default = 3) 17 | sampling : string 18 | Variance of the distance kernel (default = 'random') 19 | (Can now only be 'random') 20 | seed : int 21 | Seed for the random number generator (for reproducible graphs). 22 | 23 | Examples 24 | -------- 25 | >>> import matplotlib.pyplot as plt 26 | >>> G = graphs.Sphere(nb_pts=100, seed=42) 27 | >>> fig = plt.figure() 28 | >>> ax1 = fig.add_subplot(121) 29 | >>> ax2 = fig.add_subplot(122, projection='3d') 30 | >>> _ = ax1.spy(G.W, markersize=1.5) 31 | >>> _ = _ = G.plot(ax=ax2) 32 | 33 | """ 34 | 35 | def __init__( 36 | self, radius=1, nb_pts=300, nb_dim=3, sampling="random", seed=None, **kwargs 37 | ): 38 | self.radius = radius 39 | self.nb_pts = nb_pts 40 | self.nb_dim = nb_dim 41 | self.sampling = sampling 42 | self.seed = seed 43 | 44 | if self.sampling == "random": 45 | rs = np.random.RandomState(seed) 46 | pts = rs.normal(0, 1, (self.nb_pts, self.nb_dim)) 47 | 48 | for i in range(self.nb_pts): 49 | pts[i] /= np.linalg.norm(pts[i]) 50 | 51 | else: 52 | raise ValueError(f"Unknown sampling {sampling}") 53 | 54 | plotting = { 55 | "vertex_size": 80, 56 | } 57 | 58 | super().__init__( 59 | Xin=pts, k=10, center=False, rescale=False, plotting=plotting, **kwargs 60 | ) 61 | 62 | def _get_extra_repr(self): 63 | attrs = { 64 | "radius": f"{self.radius:.2f}", 65 | "nb_pts": self.nb_pts, 66 | "nb_dim": self.nb_dim, 67 | "sampling": self.sampling, 68 | "seed": self.seed, 69 | } 70 | attrs.update(super()._get_extra_repr()) 71 | return attrs 72 | -------------------------------------------------------------------------------- /pygsp/filters/held.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class 
Held(Filter): 7 | r"""Design 2 filters with the Held construction (tight frame). 8 | 9 | This function create a parseval filterbank of :math:`2` filters. 10 | The low-pass filter is defined by the function 11 | 12 | .. math:: f_{l}=\begin{cases} 1 & \mbox{if }x\leq a\\ 13 | \sin\left(2\pi\mu\left(\frac{x}{8a}\right)\right) & \mbox{if }a2a \end{cases} 15 | 16 | with 17 | 18 | .. math:: \mu(x) = -1+24x-144*x^2+256*x^3 19 | 20 | The high pass filter is adapted to obtain a tight frame. 21 | 22 | Parameters 23 | ---------- 24 | G : graph 25 | a : float 26 | See equation above for this parameter 27 | The spectrum is scaled between 0 and 2 (default = 2/3) 28 | 29 | Examples 30 | -------- 31 | 32 | Filter bank's representation in Fourier and time (ring graph) domains. 33 | 34 | >>> import matplotlib.pyplot as plt 35 | >>> G = graphs.Ring(N=20) 36 | >>> G.estimate_lmax() 37 | >>> G.set_coordinates('line1D') 38 | >>> g = filters.Held(G) 39 | >>> s = g.localize(G.N // 2) 40 | >>> fig, axes = plt.subplots(1, 2) 41 | >>> _ = g.plot(ax=axes[0]) 42 | >>> _ = G.plot(s, ax=axes[1]) 43 | 44 | """ 45 | 46 | def __init__(self, G, a=2.0 / 3): 47 | self.a = a 48 | 49 | def kernel(x, a): 50 | y = np.empty(np.shape(x)) 51 | l1 = a 52 | l2 = 2 * a 53 | 54 | r1ind = (x >= 0) * (x < l1) 55 | r2ind = (x >= l1) * (x < l2) 56 | r3ind = x >= l2 57 | 58 | def mu(x): 59 | return -1 + 24 * x - 144 * x**2 + 256 * x**3 60 | 61 | y[r1ind] = 1 62 | y[r2ind] = np.sin(2 * np.pi * mu(x[r2ind] / 8 / a)) 63 | y[r3ind] = 0 64 | 65 | return y 66 | 67 | held = Filter(G, lambda x: kernel(x * 2 / G.lmax, a)) 68 | complement = held.complement(frame_bound=1) 69 | kernels = held._kernels + complement._kernels 70 | 71 | super().__init__(G, kernels) 72 | 73 | def _get_extra_repr(self): 74 | return dict(a=f"{self.a:.2f}") 75 | -------------------------------------------------------------------------------- /doc/tutorials/pyramid.rst: -------------------------------------------------------------------------------- 1 | 
=================================== 2 | Graph multiresolution: Kron pyramid 3 | =================================== 4 | 5 | In this demonstration file, we show how to reduce a graph using the PyGSP. Then we apply the pyramid to simple signal. 6 | To start open a python shell (IPython is recommended here) and import the required packages. You would probably also import numpy as you will need it to create matrices and arrays. 7 | 8 | >>> import numpy as np 9 | >>> from pygsp import graphs, reduction 10 | 11 | For this demo we will be using a sensor graph with 400 nodes. 12 | 13 | >>> G = graphs.Sensor(400, distributed=True) 14 | >>> G.compute_fourier_basis() 15 | 16 | The function graph_multiresolution computes the graph pyramid for you: 17 | 18 | >>> levels = 5 19 | >>> Gs = reduction.graph_multiresolution(G, levels, sparsify=False) 20 | 21 | Next, we will compute the fourier basis of our different graph layers: 22 | 23 | >>> for gr in Gs: 24 | ... gr.compute_fourier_basis() 25 | 26 | Those that were already computed are returning with an error, meaning that nothing happened. 27 | Let's now create two signals and a filter, resp f, f2 and g: 28 | 29 | >>> f = np.ones((G.N)) 30 | >>> f[np.arange(G.N//2)] = -1 31 | >>> f = f + 10*Gs[0].U[:, 7] 32 | 33 | >>> f2 = np.ones((G.N, 2)) 34 | >>> f2[np.arange(G.N//2)] = -1 35 | 36 | >>> g = [lambda x: 5./(5 + x)] 37 | 38 | We will run the analysis of the two signals on the pyramid and obtain a coarse approximation for each layer, with decreasing number of nodes. 39 | Additionally, we will also get prediction errors at each node at every layer. 40 | 41 | >>> ca, pe = reduction.pyramid_analysis(Gs, f, h_filters=g, method='exact') 42 | >>> ca2, pe2 = reduction.pyramid_analysis(Gs, f2, h_filters=g, method='exact') 43 | 44 | Given the pyramid, the coarsest approximation and the prediction errors, we will now reconstruct the original signal on the full graph. 
45 | 46 | >>> f_pred, _ = reduction.pyramid_synthesis(Gs, ca[levels], pe, method='exact') 47 | >>> f_pred2, _ = reduction.pyramid_synthesis(Gs, ca2[levels], pe2, method='exact') 48 | 49 | Here are the final errors for each signal after reconstruction. 50 | 51 | >>> err = np.linalg.norm(f_pred-f)/np.linalg.norm(f) 52 | >>> err2 = np.linalg.norm(f_pred2-f2)/np.linalg.norm(f2) 53 | >>> assert (err < 1e-10) & (err2 < 1e-10) 54 | -------------------------------------------------------------------------------- /pygsp/graphs/barabasialbert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class BarabasiAlbert(Graph): 8 | r"""Barabasi-Albert preferential attachment. 9 | 10 | The Barabasi-Albert graph is constructed by connecting nodes in two steps. 11 | First, m0 nodes are created. Then, nodes are added one by one. 12 | 13 | By lack of clarity, we take the liberty to create it as follows: 14 | 15 | 1. the m0 initial nodes are disconnected, 16 | 2. each node is connected to m of the older nodes with a probability 17 | distribution depending of the node-degrees of the other nodes, 18 | :math:`p_n(i) = \frac{1 + k_i}{\sum_j{1 + k_j}}`. 19 | 20 | Parameters 21 | ---------- 22 | N : int 23 | Number of nodes (default is 1000) 24 | m0 : int 25 | Number of initial nodes (default is 1) 26 | m : int 27 | Number of connections at each step (default is 1) 28 | m can never be larger than m0. 29 | seed : int 30 | Seed for the random number generator (for reproducible graphs). 
31 | 32 | Examples 33 | -------- 34 | >>> import matplotlib.pyplot as plt 35 | >>> G = graphs.BarabasiAlbert(N=150, seed=42) 36 | >>> G.set_coordinates(kind='spring', seed=42) 37 | >>> fig, axes = plt.subplots(1, 2) 38 | >>> _ = axes[0].spy(G.W, markersize=2) 39 | >>> _ = G.plot(ax=axes[1]) 40 | 41 | """ 42 | 43 | def __init__(self, N=1000, m0=1, m=1, seed=None, **kwargs): 44 | if m > m0: 45 | raise ValueError("Parameter m cannot be above parameter m0.") 46 | 47 | self.m0 = m0 48 | self.m = m 49 | self.seed = seed 50 | 51 | W = sparse.lil_matrix((N, N)) 52 | rng = np.random.default_rng(seed) 53 | 54 | for i in range(m0, N): 55 | distr = W.sum(axis=1) 56 | distr += np.concatenate((np.ones((i, 1)), np.zeros((N - i, 1)))) 57 | 58 | connections = rng.choice( 59 | N, size=m, replace=False, p=np.ravel(distr / distr.sum()) 60 | ) 61 | for elem in connections: 62 | W[elem, i] = 1 63 | W[i, elem] = 1 64 | 65 | super().__init__(W, **kwargs) 66 | 67 | def _get_extra_repr(self): 68 | return dict(m0=self.m0, m=self.m, seed=self.seed) 69 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - id: check-case-conflict 12 | - id: check-merge-conflict 13 | - id: check-toml 14 | - id: debug-statements 15 | - id: detect-private-key 16 | 17 | - repo: https://github.com/psf/black 18 | rev: 23.12.1 19 | hooks: 20 | - id: black 21 | language_version: python3 22 | args: [--line-length=88] 23 | 24 | - repo: https://github.com/pycqa/isort 25 | rev: 5.13.2 26 | hooks: 27 | - id: isort 28 | args: [--profile=black, --line-length=88] 29 | 
30 | # - repo: https://github.com/pycqa/flake8 31 | # rev: 7.0.0 32 | # hooks: 33 | # - id: flake8 34 | # additional_dependencies: [flake8-docstrings] 35 | # args: [--max-line-length, "88", --extend-ignore, "E203,W503,D100,D101,D102,D103,D104,D105,D107,D200,D202,D205,D209,D400,D401,D402,D412,D414"] 36 | 37 | # - repo: https://github.com/PyCQA/bandit 38 | # rev: 1.7.5 39 | # hooks: 40 | # - id: bandit 41 | # args: [--skip=B101,B601] 42 | 43 | # - repo: https://github.com/pre-commit/mirrors-mypy 44 | # rev: v1.8.0 45 | # hooks: 46 | # - id: mypy 47 | # additional_dependencies: [types-all] 48 | # args: [--ignore-missing-imports, --no-strict-optional] 49 | # files: ^pygsp/ 50 | 51 | # - repo: https://github.com/pycqa/pydocstyle 52 | # rev: 6.3.0 53 | # hooks: 54 | # - id: pydocstyle 55 | # args: [--convention=numpy] 56 | # files: ^pygsp/ 57 | 58 | - repo: https://github.com/asottile/pyupgrade 59 | rev: v3.15.0 60 | hooks: 61 | - id: pyupgrade 62 | args: [--py38-plus] 63 | 64 | default_language_version: 65 | python: python3 66 | 67 | ci: 68 | autofix_commit_msg: | 69 | [pre-commit.ci] auto fixes from pre-commit.com hooks 70 | 71 | for more information, see https://pre-commit.ci 72 | autofix_prs: true 73 | autoupdate_branch: '' 74 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' 75 | autoupdate_schedule: weekly 76 | skip: [] 77 | submodules: false 78 | -------------------------------------------------------------------------------- /pygsp/graphs/comet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class Comet(Graph): 8 | r"""Comet graph. 9 | 10 | The comet is a path graph with a star of degree `k` at one end. 11 | Equivalently, the comet is a star made of `k` branches, where a branch of 12 | length `N-k` acts as the tail. 
13 | The central vertex has degree `k`, the others have degree 1 or 2. 14 | 15 | Parameters 16 | ---------- 17 | N : int 18 | Number of vertices. 19 | k : int 20 | Degree of central vertex. 21 | 22 | See Also 23 | -------- 24 | Path : Comet without star 25 | Star : Comet without tail (path) 26 | 27 | Examples 28 | -------- 29 | >>> import matplotlib.pyplot as plt 30 | >>> G = graphs.Comet(15, 10) 31 | >>> fig, axes = plt.subplots(1, 2) 32 | >>> _ = axes[0].spy(G.W) 33 | >>> _ = G.plot(ax=axes[1]) 34 | 35 | """ 36 | 37 | def __init__(self, N=32, k=12, **kwargs): 38 | if k > N - 1: 39 | raise ValueError( 40 | "The degree of the central vertex k={} must be " 41 | "smaller than the number of vertices N={}." 42 | "".format(k, N) 43 | ) 44 | 45 | self.k = k 46 | 47 | sources = np.concatenate( 48 | ( 49 | np.zeros(k), 50 | np.arange(k) + 1, # star 51 | np.arange(k, N - 1), 52 | np.arange(k + 1, N), # tail (path) 53 | ) 54 | ) 55 | targets = np.concatenate( 56 | ( 57 | np.arange(k) + 1, 58 | np.zeros(k), # star 59 | np.arange(k + 1, N), 60 | np.arange(k, N - 1), # tail (path) 61 | ) 62 | ) 63 | n_edges = N - 1 64 | weights = np.ones(2 * n_edges) 65 | W = sparse.csr_matrix((weights, (sources, targets)), shape=(N, N)) 66 | 67 | indices = np.arange(k) + 1 68 | coords = np.zeros((N, 2)) 69 | coords[1 : k + 1, 0] = np.cos(indices * 2 * np.pi / k) 70 | coords[1 : k + 1, 1] = np.sin(indices * 2 * np.pi / k) 71 | coords[k + 1 :, 0] = np.arange(1, N - k) + 1 72 | 73 | super().__init__(W, coords=coords, **kwargs) 74 | 75 | def _get_extra_repr(self): 76 | return dict(k=self.k) 77 | -------------------------------------------------------------------------------- /examples/filtering.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Filtering a signal 3 | ================== 4 | 5 | A graph signal is filtered by transforming it to the spectral domain (via the 6 | Fourier transform), performing a point-wise multiplication (motivated by the 7 |
convolution theorem), and transforming it back to the vertex domain (via the 8 | inverse graph Fourier transform). 9 | 10 | .. note:: 11 | 12 | In practice, filtering is implemented in the vertex domain to avoid the 13 | computationally expensive graph Fourier transform. To do so, filters are 14 | implemented as polynomials of the eigenvalues / Laplacian. Hence, filtering 15 | a signal reduces to multiplications by sparse matrices (the graph 16 | Laplacian). 17 | 18 | """ 19 | 20 | import numpy as np 21 | from matplotlib import pyplot as plt 22 | 23 | import pygsp as pg 24 | 25 | G = pg.graphs.Sensor(seed=42) 26 | G.compute_fourier_basis() 27 | 28 | # g = pg.filters.Rectangular(G, band_max=0.2) 29 | g = pg.filters.Expwin(G, band_max=0.5) 30 | 31 | fig, axes = plt.subplots(1, 3, figsize=(12, 4)) 32 | fig.subplots_adjust(hspace=0.5) 33 | 34 | x = np.random.default_rng(1).normal(size=G.N) 35 | # x = np.random.default_rng(42).uniform(-1, 1, size=G.N) 36 | x = 3 * x / np.linalg.norm(x) 37 | y = g.filter(x) 38 | x_hat = G.gft(x).squeeze() 39 | y_hat = G.gft(y).squeeze() 40 | 41 | limits = [x.min(), x.max()] 42 | 43 | G.plot(x, limits=limits, ax=axes[0], title="input signal $x$ in the vertex domain") 44 | axes[0].text(0, -0.1, f"$x^T L x = {G.dirichlet_energy(x):.2f}$") 45 | axes[0].set_axis_off() 46 | 47 | g.plot(ax=axes[1], alpha=1) 48 | line_filt = axes[1].lines[-2] 49 | (line_in,) = axes[1].plot(G.e, np.abs(x_hat), ".-") 50 | (line_out,) = axes[1].plot(G.e, np.abs(y_hat), ".-") 51 | # axes[1].set_xticks(range(0, 16, 4)) 52 | axes[1].set_xlabel(r"graph frequency $\lambda$") 53 | axes[1].set_ylabel(r"frequency content $\hat{x}(\lambda)$") 54 | axes[1].set_title(r"signals in the spectral domain") 55 | axes[1].legend([r"input signal $\hat{x}$"]) 56 | labels = [ 57 | r"input signal $\hat{x}$", 58 | "kernel $g$", 59 | r"filtered signal $\hat{y}$", 60 | ] 61 | axes[1].legend([line_in, line_filt, line_out], labels, loc="upper right") 62 | 63 | G.plot(y, limits=limits,
ax=axes[2], title="filtered signal $y$ in the vertex domain") 64 | axes[2].text(0, -0.1, f"$y^T L y = {G.dirichlet_energy(y):.2f}$") 65 | axes[2].set_axis_off() 66 | 67 | fig.tight_layout() 68 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | import pygsp 2 | 3 | extensions = [ 4 | "sphinx.ext.viewcode", 5 | "sphinx.ext.autosummary", 6 | "sphinx.ext.mathjax", 7 | "sphinx.ext.inheritance_diagram", 8 | ] 9 | 10 | extensions.append("sphinx.ext.autodoc") 11 | autodoc_default_options = { 12 | "members": True, 13 | "undoc-members": True, 14 | "member-order": "groupwise", # alphabetical, groupwise, bysource 15 | } 16 | 17 | extensions.append("sphinx.ext.intersphinx") 18 | intersphinx_mapping = { 19 | "python": ("https://docs.python.org/3", None), 20 | "numpy": ("https://numpy.org/doc/stable", None), 21 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 22 | "matplotlib": ("https://matplotlib.org/stable", None), 23 | "pyunlocbox": ("https://pyunlocbox.readthedocs.io/en/stable", None), 24 | "networkx": ("https://networkx.org/documentation/stable", None), 25 | "graph_tool": ("https://graph-tool.skewed.de/static/doc", None), 26 | } 27 | 28 | extensions.append("numpydoc") 29 | numpydoc_show_class_members = False 30 | numpydoc_use_plots = True # Add the plot directive whenever mpl is imported. 31 | 32 | extensions.append("matplotlib.sphinxext.plot_directive") 33 | plot_include_source = True 34 | plot_html_show_source_link = False 35 | plot_html_show_formats = False 36 | plot_working_directory = "." 
37 | plot_rcparams = {"figure.figsize": (10, 4)} 38 | plot_pre_code = """ 39 | import numpy as np 40 | from pygsp import graphs, filters, utils, plotting 41 | """ 42 | 43 | extensions.append("sphinx_gallery.gen_gallery") 44 | sphinx_gallery_conf = { 45 | "examples_dirs": "../examples", 46 | "gallery_dirs": "examples", 47 | "filename_pattern": "/", 48 | "reference_url": {"pygsp": None}, 49 | "backreferences_dir": "backrefs", 50 | "doc_module": "pygsp", 51 | "show_memory": True, 52 | } 53 | 54 | extensions.append("sphinx_copybutton") 55 | copybutton_prompt_text = ">>> " 56 | 57 | extensions.append("sphinxcontrib.bibtex") 58 | bibtex_bibfiles = ["references.bib"] 59 | 60 | exclude_patterns = ["_build"] 61 | source_suffix = ".rst" 62 | master_doc = "index" 63 | 64 | project = "PyGSP" 65 | version = pygsp.__version__ 66 | release = pygsp.__version__ 67 | copyright = "EPFL LTS2" 68 | 69 | pygments_style = "sphinx" 70 | html_theme = "sphinx_rtd_theme" 71 | html_theme_options = { 72 | "navigation_depth": 2, 73 | } 74 | latex_elements = { 75 | "papersize": "a4paper", 76 | "pointsize": "10pt", 77 | } 78 | latex_documents = [ 79 | ("index", "pygsp.tex", "PyGSP documentation", "EPFL LTS2", "manual"), 80 | ] 81 | -------------------------------------------------------------------------------- /examples/eigenvector_localization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Localization of Fourier modes 3 | ============================= 4 | 5 | The Fourier modes (the eigenvectors of the graph Laplacian) can be localized in 6 | the spatial domain. As a consequence, graph signals can be localized in both 7 | space and frequency (which is impossible for Euclidean domains or manifolds, by 8 | Heisenberg's uncertainty principle). 9 | 10 | This example demonstrates that the more isolated a node is, the more a Fourier 11 | mode will be localized on it.
12 | 13 | The mutual coherence between the basis of Kronecker deltas and the basis formed 14 | by the eigenvectors of the Laplacian, :attr:`pygsp.graphs.Graph.coherence`, is 15 | a measure of the localization of the Fourier modes. The larger the value, the 16 | more localized the eigenvectors can be. 17 | 18 | See `Global and Local Uncertainty Principles for Signals on Graphs 19 | `_ for details. 20 | """ 21 | 22 | import matplotlib as mpl 23 | import numpy as np 24 | from matplotlib import pyplot as plt 25 | 26 | import pygsp as pg 27 | 28 | fig, axes = plt.subplots(2, 2, figsize=(8, 8)) 29 | 30 | for w, ax in zip([10, 1, 0.1, 0.01], axes.flatten()): 31 | adjacency = [ 32 | [0, w, 0, 0], 33 | [w, 0, 1, 0], 34 | [0, 1, 0, 1], 35 | [0, 0, 1, 0], 36 | ] 37 | graph = pg.graphs.Graph(adjacency) 38 | graph.compute_fourier_basis() 39 | 40 | # Plot eigenvectors. 41 | ax.plot(graph.U) 42 | ax.set_ylim(-1, 1) 43 | ax.set_yticks([-1, 0, 1]) 44 | ax.legend( 45 | [ 46 | rf"$u_{i}(v)$, $\lambda_{i}={graph.e[i]:.1f}$" 47 | for i in range(graph.n_vertices) 48 | ], 49 | loc="upper right", 50 | ) 51 | 52 | ax.text( 53 | 0, 54 | -0.9, 55 | f"coherence = {graph.coherence:.2f}" 56 | rf"$\in [{1/np.sqrt(graph.n_vertices)}, 1]$", 57 | ) 58 | 59 | # Plot vertices. 60 | ax.set_xticks(range(graph.n_vertices)) 61 | ax.set_xticklabels([f"$v_{i}$" for i in range(graph.n_vertices)]) 62 | 63 | # Plot graph. 64 | x, y = np.arange(0, graph.n_vertices), -1.20 * np.ones(graph.n_vertices) 65 | line = mpl.lines.Line2D(x, y, lw=3, color="k", marker=".", markersize=20) 66 | line.set_clip_on(False) 67 | ax.add_line(line) 68 | 69 | # Plot edge weights. 
70 | for i in range(graph.n_vertices - 1): 71 | j = i + 1 72 | ax.text( 73 | i + 0.5, 74 | -1.15, 75 | f"$w_{{{i}{j}}} = {adjacency[i][j]}$", 76 | horizontalalignment="center", 77 | ) 78 | 79 | fig.tight_layout() 80 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/sensor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .nngraph import NNGraph # prevent circular import in Python < 3.5 4 | 5 | 6 | class Sensor(NNGraph): 7 | r"""Random sensor graph. 8 | 9 | The sensor graph is built by randomly picking ``N`` points on the [0, 1] x 10 | [0, 1] plane and connecting each to its ``k`` nearest neighbors. 11 | 12 | Parameters 13 | ---------- 14 | N : int 15 | Number of nodes. 16 | Must be a perfect square if ``distributed=True``. 17 | k : int 18 | Minimum number of neighbors. 19 | distributed : bool 20 | Whether to distribute the vertices more evenly on the plane. 21 | If False, coordinates are taken uniformly at random in a [0, 1] square. 22 | If True, the vertices are arranged on a perturbed grid. 23 | seed : int 24 | Seed for the random number generator (for reproducible graphs). 25 | **kwargs : 26 | Additional keyword arguments for :class:`NNGraph`. 27 | 28 | Notes 29 | ----- 30 | 31 | The definition of this graph changed in February 2019. 32 | See the `GitHub PR `_. 
33 | 34 | Examples 35 | -------- 36 | >>> import matplotlib.pyplot as plt 37 | >>> G = graphs.Sensor(N=64, seed=42) 38 | >>> fig, axes = plt.subplots(1, 2) 39 | >>> _ = axes[0].spy(G.W, markersize=2) 40 | >>> _ = G.plot(ax=axes[1]) 41 | 42 | >>> import matplotlib.pyplot as plt 43 | >>> G = graphs.Sensor(N=64, distributed=True, seed=42) 44 | >>> fig, axes = plt.subplots(1, 2) 45 | >>> _ = axes[0].spy(G.W, markersize=2) 46 | >>> _ = G.plot(ax=axes[1]) 47 | 48 | """ 49 | 50 | def __init__(self, N=64, k=6, distributed=False, seed=None, **kwargs): 51 | self.distributed = distributed 52 | self.seed = seed 53 | 54 | plotting = {"limits": np.array([0, 1, 0, 1])} 55 | 56 | rng = np.random.default_rng(self.seed) 57 | 58 | if distributed: 59 | m = np.sqrt(N) 60 | if not m.is_integer(): 61 | raise ValueError( 62 | "The number of vertices must be a " 63 | "perfect square if they are to be " 64 | "distributed on a grid." 65 | ) 66 | 67 | coords = np.mgrid[0 : 1 : 1 / m, 0 : 1 : 1 / m].reshape(2, -1).T 68 | coords += rng.uniform(0, 1 / m, (N, 2)) 69 | 70 | else: 71 | coords = rng.uniform(0, 1, (N, 2)) 72 | 73 | super().__init__( 74 | Xin=coords, k=k, rescale=False, center=False, plotting=plotting, **kwargs 75 | ) 76 | 77 | def _get_extra_repr(self): 78 | return {"k": self.k, "distributed": self.distributed, "seed": self.seed} 79 | -------------------------------------------------------------------------------- /pygsp/features.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`pygsp.features` module implements different feature extraction 3 | techniques based on :mod:`pygsp.graphs` and :mod:`pygsp.filters`. 4 | """ 5 | 6 | import numpy as np 7 | 8 | from pygsp import filters, utils 9 | 10 | 11 | def compute_avg_adj_deg(G): 12 | r""" 13 | Compute the average adjacency degree for each node. 14 | 15 | The average adjacency degree is the average of the degrees of a node and 16 | its neighbors. 
17 | 18 | Parameters 19 | ---------- 20 | G: Graph 21 | Graph on which the statistic is extracted 22 | """ 23 | return np.sum(np.dot(G.A, G.A), axis=1) / (np.sum(G.A, axis=1) + 1.0) 24 | 25 | 26 | @utils.filterbank_handler 27 | def compute_tig(g, **kwargs): 28 | r""" 29 | Compute the Tig for a given filter or filter bank. 30 | 31 | .. math:: T_ig(n) = g(L)_{i, n} 32 | 33 | Parameters 34 | ---------- 35 | g: Filter 36 | One of :mod:`pygsp.filters`. 37 | kwargs: dict 38 | Additional parameters to be passed to the 39 | :func:`pygsp.filters.Filter.filter` method. 40 | """ 41 | return g.compute_frame() 42 | 43 | 44 | @utils.filterbank_handler 45 | def compute_norm_tig(g, **kwargs): 46 | r""" 47 | Compute the :math:`\ell_2` norm of the Tig. 48 | See :func:`compute_tig`. 49 | 50 | Parameters 51 | ---------- 52 | g: Filter 53 | The filter or filter bank. 54 | kwargs: dict 55 | Additional parameters to be passed to the 56 | :func:`pygsp.filters.Filter.filter` method. 57 | """ 58 | tig = compute_tig(g, **kwargs) 59 | return np.linalg.norm(tig, axis=1, ord=2) 60 | 61 | 62 | def compute_spectrogram(G, atom=None, M=100, **kwargs): 63 | r""" 64 | Compute the norm of the Tig for all nodes with a kernel shifted along the 65 | spectral axis. 66 | 67 | Parameters 68 | ---------- 69 | G : Graph 70 | Graph on which to compute the spectrogram. 71 | atom : func 72 | Kernel to use in the spectrogram (default = exp(-M*(x/lmax)²)). 73 | M : int (optional) 74 | Number of samples on the spectral scale. (default = 100) 75 | kwargs: dict 76 | Additional parameters to be passed to the 77 | :func:`pygsp.filters.Filter.filter` method. 
78 | """ 79 | 80 | if not atom: 81 | 82 | def atom(x): 83 | return np.exp(-M * (x / G.lmax) ** 2) 84 | 85 | scale = np.linspace(0, G.lmax, M) 86 | spectr = np.empty((G.N, M)) 87 | 88 | for shift_idx in range(M): 89 | shift_filter = filters.Filter(G, lambda x: atom(x - scale[shift_idx])) 90 | tig = compute_norm_tig(shift_filter, **kwargs).squeeze() ** 2 91 | spectr[:, shift_idx] = tig 92 | 93 | G.spectr = spectr 94 | return spectr 95 | -------------------------------------------------------------------------------- /pygsp/graphs/path.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class Path(Graph): 8 | r"""Path graph. 9 | 10 | A signal on the path graph is akin to a 1-dimensional signal in classical 11 | signal processing. 12 | 13 | On the path graph, the graph Fourier transform (GFT) is the classical 14 | discrete cosine transform (DCT_). 15 | As the type-II DCT, the GFT assumes even boundary conditions on both sides. 16 | 17 | .. _DCT: https://en.wikipedia.org/wiki/Discrete_cosine_transform 18 | 19 | Parameters 20 | ---------- 21 | N : int 22 | Number of vertices. 23 | 24 | See Also 25 | -------- 26 | Ring : 1D line with periodic boundary conditions 27 | Grid2d : Kronecker product of two path graphs 28 | Comet : Generalization with a star at one end 29 | 30 | References 31 | ---------- 32 | :cite:`strang1999dct` shows that each DCT basis contains the eigenvectors 33 | of a symmetric "second difference" matrix. 34 | They get the eight types of DCTs by varying the boundary conditions. 35 | 36 | Examples 37 | -------- 38 | >>> import matplotlib.pyplot as plt 39 | >>> fig, axes = plt.subplots(2, 2, figsize=(10, 8)) 40 | >>> for i, directed in enumerate([False, True]): 41 | ... G = graphs.Path(N=10, directed=directed) 42 | ... _ = axes[i, 0].spy(G.W) 43 | ... 
_ = G.plot(ax=axes[i, 1]) 44 | 45 | The GFT of the path graph is the classical DCT. 46 | 47 | >>> from matplotlib import pyplot as plt 48 | >>> n_eigenvectors = 4 49 | >>> graph = graphs.Path(30) 50 | >>> fig, axes = plt.subplots(1, 2) 51 | >>> graph.set_coordinates('line1D') 52 | >>> graph.compute_fourier_basis() 53 | >>> _ = graph.plot(graph.U[:, :n_eigenvectors], ax=axes[0]) 54 | >>> _ = axes[0].legend(range(n_eigenvectors)) 55 | >>> _ = axes[1].plot(graph.e, '.') 56 | 57 | """ 58 | 59 | def __init__(self, N=16, directed=False, **kwargs): 60 | self.directed = directed 61 | if directed: 62 | sources = np.arange(0, N - 1) 63 | targets = np.arange(1, N) 64 | n_edges = N - 1 65 | else: 66 | sources = np.concatenate((np.arange(0, N - 1), np.arange(1, N))) 67 | targets = np.concatenate((np.arange(1, N), np.arange(0, N - 1))) 68 | n_edges = 2 * (N - 1) 69 | weights = np.ones(n_edges) 70 | W = sparse.csr_matrix((weights, (sources, targets)), shape=(N, N)) 71 | plotting = {"limits": np.array([-1, N, -1, 1])} 72 | 73 | super().__init__(W, plotting=plotting, **kwargs) 74 | 75 | self.set_coordinates("line2D") 76 | 77 | def _get_extra_repr(self): 78 | return dict(directed=self.directed) 79 | -------------------------------------------------------------------------------- /pygsp/filters/mexicanhat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .filter import Filter # prevent circular import in Python < 3.5 6 | 7 | 8 | class MexicanHat(Filter): 9 | r"""Design a filter bank of Mexican hat wavelets. 10 | 11 | The Mexican hat wavelet is the second order derivative of a Gaussian. Since 12 | we express the filter in the Fourier domain, we find: 13 | 14 | .. math:: \hat{g}_b(x) = x \exp(-x) 15 | 16 | for the band-pass filter. Note that in our convention the eigenvalues of 17 | the Laplacian are equivalent to the square of graph frequencies, 18 | i.e. :math:`x = \lambda^2`.
19 | 20 | The low-pass filter is given by 21 | 22 | .. math:: \hat{g}_l(x) = \exp(-x^4). 23 | 24 | Parameters 25 | ---------- 26 | G : graph 27 | Nf : int 28 | Number of filters to cover the interval [0, lmax]. 29 | lpfactor : float 30 | Low-pass factor. lmin=lmax/lpfactor will be used to determine scales. 31 | The scaling function will be created to fill the low-pass gap. 32 | scales : array_like 33 | Scales to be used. 34 | By default, initialized with :func:`pygsp.utils.compute_log_scales`. 35 | normalize : bool 36 | Whether to normalize the wavelet by the factor ``sqrt(scales)``. 37 | 38 | Examples 39 | -------- 40 | 41 | Filter bank's representation in Fourier and time (ring graph) domains. 42 | 43 | >>> import matplotlib.pyplot as plt 44 | >>> G = graphs.Ring(N=20) 45 | >>> G.estimate_lmax() 46 | >>> G.set_coordinates('line1D') 47 | >>> g = filters.MexicanHat(G) 48 | >>> s = g.localize(G.N // 2) 49 | >>> fig, axes = plt.subplots(1, 2) 50 | >>> _ = g.plot(ax=axes[0]) 51 | >>> _ = G.plot(s, ax=axes[1]) 52 | 53 | """ 54 | 55 | def __init__(self, G, Nf=6, lpfactor=20, scales=None, normalize=False): 56 | self.lpfactor = lpfactor 57 | self.normalize = normalize 58 | 59 | lmin = G.lmax / lpfactor 60 | 61 | if scales is None: 62 | scales = utils.compute_log_scales(lmin, G.lmax, Nf - 1) 63 | self.scales = scales 64 | 65 | if len(scales) != Nf - 1: 66 | raise ValueError("len(scales) should be Nf-1.") 67 | 68 | def band_pass(x): 69 | return x * np.exp(-x) 70 | 71 | def low_pass(x): 72 | return np.exp(-(x**4)) 73 | 74 | kernels = [lambda x: 1.2 * np.exp(-1) * low_pass(x / 0.4 / lmin)] 75 | 76 | for i in range(Nf - 1): 77 | 78 | def kernel(x, i=i): 79 | norm = np.sqrt(scales[i]) if normalize else 1 80 | return norm * band_pass(scales[i] * x) 81 | 82 | kernels.append(kernel) 83 | 84 | super().__init__(G, kernels) 85 | 86 | def _get_extra_repr(self): 87 | return dict(lpfactor=f"{self.lpfactor:.2f}", normalize=self.normalize) 88 |
-------------------------------------------------------------------------------- /pygsp/graphs/randomring.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class RandomRing(Graph): 10 | r"""Ring graph with randomly sampled vertices. 11 | 12 | Parameters 13 | ---------- 14 | N : int 15 | Number of vertices. 16 | angles : array_like, optional 17 | The angular coordinate, in :math:`[0, 2\pi]`, of the vertices. 18 | seed : int 19 | Seed for the random number generator (for reproducible graphs). 20 | 21 | Examples 22 | -------- 23 | >>> import matplotlib.pyplot as plt 24 | >>> G = graphs.RandomRing(N=10, seed=42) 25 | >>> fig, axes = plt.subplots(1, 2) 26 | >>> _ = axes[0].spy(G.W) 27 | >>> _ = G.plot(ax=axes[1]) 28 | >>> _ = axes[1].set_xlim(-1.1, 1.1) 29 | >>> _ = axes[1].set_ylim(-1.1, 1.1) 30 | 31 | """ 32 | 33 | def __init__(self, N=64, angles=None, seed=None, **kwargs): 34 | self.seed = seed 35 | 36 | if angles is None: 37 | rng = np.random.default_rng(seed) 38 | angles = np.sort(rng.uniform(0, 2 * np.pi, size=N), axis=0) 39 | else: 40 | angles = np.asanyarray(angles) 41 | angles.sort() # Need to be sorted to take the difference. 42 | N = len(angles) 43 | if np.any(angles < 0) or np.any(angles >= 2 * np.pi): 44 | raise ValueError("Angles should be in [0, 2 pi]") 45 | self.angles = angles 46 | 47 | if N < 3: 48 | # Asymmetric graph needed for 2 as 2 distances connect them. 49 | raise ValueError("There should be at least 3 vertices.") 50 | 51 | rows = range(0, N - 1) 52 | cols = range(1, N) 53 | weights = np.diff(angles) 54 | 55 | # Close the loop. 
56 | rows = np.concatenate((rows, [0])) 57 | cols = np.concatenate((cols, [N - 1])) 58 | weights = np.concatenate((weights, [2 * np.pi + angles[0] - angles[-1]])) 59 | 60 | W = sparse.coo_matrix((weights, (rows, cols)), shape=(N, N)) 61 | W = utils.symmetrize(W, method="triu") 62 | 63 | # Width as the expected angle. All angles are equal to that value when 64 | # the ring is uniformly sampled. 65 | width = 2 * np.pi / N 66 | assert (W.data.mean() - width) < 1e-10 67 | # TODO: why this kernel ? It empirically produces eigenvectors closer 68 | # to the sines and cosines. 69 | W.data = width / W.data 70 | 71 | coords = np.stack([np.cos(angles), np.sin(angles)], axis=1) 72 | plotting = {"limits": np.array([-1, 1, -1, 1])} 73 | 74 | # TODO: save angle and 2D position as graph signals 75 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 76 | 77 | def _get_extra_repr(self): 78 | return dict(seed=self.seed) 79 | -------------------------------------------------------------------------------- /pygsp/filters/expwin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Expwin(Filter): 7 | r"""Design an exponential window filter. 8 | 9 | The window has the shape of a (half) triangle with the corners smoothed by 10 | an exponential function. 11 | 12 | Parameters 13 | ---------- 14 | G : graph 15 | band_min : float 16 | Minimum relative band. The filter evaluates at 0.5 at this frequency. 17 | Zero corresponds to the smallest eigenvalue (which is itself equal to 18 | zero), one corresponds to the largest eigenvalue. 19 | If None, the filter is high-pass. 20 | band_max : float 21 | Maximum relative band. The filter evaluates at 0.5 at this frequency. 22 | If None, the filter is low-pass. 23 | slope : float 24 | The slope at cut-off. 
25 | 26 | Examples 27 | -------- 28 | 29 | Filter bank's representation in Fourier and time (ring graph) domains. 30 | 31 | >>> import matplotlib.pyplot as plt 32 | >>> G = graphs.Ring(N=20) 33 | >>> G.estimate_lmax() 34 | >>> G.set_coordinates('line1D') 35 | >>> g = filters.Expwin(G, band_min=0.1, band_max=0.7, slope=5) 36 | >>> s = g.localize(G.N // 2) 37 | >>> fig, axes = plt.subplots(1, 2) 38 | >>> _ = g.plot(ax=axes[0]) 39 | >>> _ = G.plot(s, ax=axes[1]) 40 | 41 | """ 42 | 43 | def __init__(self, G, band_min=None, band_max=0.2, slope=1): 44 | self.band_min = band_min 45 | self.band_max = band_max 46 | self.slope = slope 47 | 48 | def exp(x): 49 | """Exponential function with canary to avoid division by zero and 50 | overflow.""" 51 | y = np.where(x <= 0, -1, x) 52 | y = np.exp(-slope / y) 53 | return np.where(x <= 0, 0, y) 54 | 55 | def h(x): 56 | y = exp(x) 57 | z = exp(1 - x) 58 | return y / (y + z) 59 | 60 | def kernel_lowpass(x): 61 | return h(0.5 - x / G.lmax + band_max) 62 | 63 | def kernel_highpass(x): 64 | return h(0.5 + x / G.lmax - band_min) 65 | 66 | if (band_min is None) and (band_max is None): 67 | kernel = lambda x: np.ones_like(x) 68 | elif band_min is None: 69 | kernel = kernel_lowpass 70 | elif band_max is None: 71 | kernel = kernel_highpass 72 | else: 73 | kernel = lambda x: kernel_lowpass(x) * kernel_highpass(x) 74 | 75 | super().__init__(G, kernel) 76 | 77 | def _get_extra_repr(self): 78 | attrs = dict() 79 | if self.band_min is not None: 80 | attrs.update(band_min=f"{self.band_min:.2f}") 81 | if self.band_max is not None: 82 | attrs.update(band_max=f"{self.band_max:.2f}") 83 | attrs.update(slope=f"{self.slope:.0f}") 84 | return attrs 85 | -------------------------------------------------------------------------------- /pygsp/graphs/grid2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # 
prevent circular import in Python < 3.5 7 | 8 | 9 | class Grid2d(Graph): 10 | r"""2-dimensional grid graph. 11 | 12 | On the 2D grid, the graph Fourier transform (GFT) is the Kronecker product 13 | between the GFT of two :class:`~pygsp.graphs.Path` graphs. 14 | 15 | Parameters 16 | ---------- 17 | N1 : int 18 | Number of vertices along the first dimension. 19 | N2 : int 20 | Number of vertices along the second dimension. Default is ``N1``. 21 | diagonal : float 22 | Value of the diagnal edges. Default is ``0.0`` 23 | 24 | See Also 25 | -------- 26 | Path : 1D line with even boundary conditions 27 | Torus : Kronecker product of two ring graphs 28 | Grid2dImgPatches 29 | 30 | Examples 31 | -------- 32 | >>> import matplotlib.pyplot as plt 33 | >>> G = graphs.Grid2d(N1=5, N2=4) 34 | >>> fig, axes = plt.subplots(1, 2) 35 | >>> _ = axes[0].spy(G.W) 36 | >>> _ = G.plot(ax=axes[1]) 37 | 38 | """ 39 | 40 | def __init__(self, N1=16, N2=None, diagonal=0.0, **kwargs): 41 | if N2 is None: 42 | N2 = N1 43 | 44 | self.N1 = N1 45 | self.N2 = N2 46 | 47 | N = N1 * N2 48 | 49 | # Filling up the weight matrix this way is faster than 50 | # looping through all the grid points: 51 | diag_1 = np.ones(N - 1) 52 | diag_1[(N2 - 1) :: N2] = 0 53 | diag_2 = np.ones(N - N2) 54 | 55 | W = sparse.diags( 56 | diagonals=[diag_1, diag_2], 57 | offsets=[-1, -N2], 58 | shape=(N, N), 59 | format="csr", 60 | dtype="float", 61 | ) 62 | 63 | if min(N1, N2) > 1 and diagonal != 0.0: 64 | # Connecting node with they diagonal neighbours 65 | diag_3 = np.full(N - N2 - 1, diagonal) 66 | diag_4 = np.full(N - N2 + 1, diagonal) 67 | diag_3[N2 - 1 :: N2] = 0 68 | diag_4[0::N2] = 0 69 | D = sparse.diags( 70 | diagonals=[diag_3, diag_4], 71 | offsets=[-N2 - 1, -N2 + 1], 72 | shape=(N, N), 73 | format="csr", 74 | dtype="float", 75 | ) 76 | W += D 77 | 78 | W = utils.symmetrize(W, method="tril") 79 | 80 | x = np.kron(np.ones((N1, 1)), (np.arange(N2) / float(N2)).reshape(N2, 1)) 81 | y = np.kron(np.ones((N2, 1)), 
np.arange(N1) / float(N1)).reshape(N, 1) 82 | y = np.sort(y, axis=0)[::-1] 83 | coords = np.concatenate((x, y), axis=1) 84 | 85 | plotting = { 86 | "limits": np.array([-1.0 / N2, 1 + 1.0 / N2, 1.0 / N1, 1 + 1.0 / N1]) 87 | } 88 | 89 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 90 | 91 | def _get_extra_repr(self): 92 | return dict(N1=self.N1, N2=self.N2) 93 | -------------------------------------------------------------------------------- /pygsp/filters/rectangular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Rectangular(Filter): 7 | r"""Design a rectangular filter. 8 | 9 | The filter evaluates at one in the interval [band_min, band_max] and zero 10 | everywhere else. 11 | 12 | The rectangular kernel is defined as 13 | 14 | .. math:: g(\lambda) = \begin{cases} 15 | 0 & \text{if } \lambda < \text{band}_\text{min}, \\ 16 | 1 & \text{if band}_\text{min} \leq \lambda \leq \text{band}_\text{max}, \\ 17 | 0 & \text{if } \lambda > \text{band}_\text{max}, \\ 18 | \end{cases} 19 | 20 | where :math:`\lambda \in [0, 1]` corresponds to the normalized graph 21 | eigenvalues. 22 | 23 | Parameters 24 | ---------- 25 | G : graph 26 | band_min : float 27 | Minimum relative band. The filter evaluates at 1 at this frequency. 28 | Zero corresponds to the smallest eigenvalue (which is itself equal to 29 | zero), one corresponds to the largest eigenvalue. 30 | If None, the filter has no lower bound (which corresponds to 31 | :math:`\text{band}_\text{min} = -\infty`) and is high-pass. 32 | band_max : float 33 | Maximum relative band. The filter evaluates at 1 at this frequency. 34 | If None, the filter has no upper bound (which corresponds to 35 | :math:`\text{band}_\text{min} = \infty`) and is high-pass. 
36 | 37 | Examples 38 | -------- 39 | 40 | Filter bank's representation in Fourier and time (ring graph) domains. 41 | 42 | >>> import matplotlib.pyplot as plt 43 | >>> G = graphs.Ring(N=20) 44 | >>> G.estimate_lmax() 45 | >>> G.set_coordinates('line1D') 46 | >>> g = filters.Rectangular(G, band_min=0.1, band_max=0.5) 47 | >>> s = g.localize(G.N // 2) 48 | >>> fig, axes = plt.subplots(1, 2) 49 | >>> _ = g.plot(ax=axes[0]) 50 | >>> _ = G.plot(s, ax=axes[1]) 51 | 52 | """ 53 | 54 | def __init__(self, G, band_min=None, band_max=0.2): 55 | self.band_min = band_min 56 | self.band_max = band_max 57 | 58 | def kernel_lowpass(x): 59 | x = x / G.lmax 60 | return x <= band_max 61 | 62 | def kernel_highpass(x): 63 | x = x / G.lmax 64 | return x >= band_min 65 | 66 | if (band_min is None) and (band_max is None): 67 | kernel = lambda x: np.ones_like(x) 68 | elif band_min is None: 69 | kernel = kernel_lowpass 70 | elif band_max is None: 71 | kernel = kernel_highpass 72 | else: 73 | kernel = lambda x: kernel_lowpass(x) * kernel_highpass(x) 74 | 75 | super().__init__(G, kernel) 76 | 77 | def _get_extra_repr(self): 78 | attrs = dict() 79 | if self.band_min is not None: 80 | attrs.update(band_min=f"{self.band_min:.2f}") 81 | if self.band_max is not None: 82 | attrs.update(band_max=f"{self.band_max:.2f}") 83 | return attrs 84 | -------------------------------------------------------------------------------- /pygsp/filters/meyer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .filter import Filter # prevent circular import in Python < 3.5 6 | 7 | 8 | class Meyer(Filter): 9 | r"""Design a filter bank of Meyer wavelets (tight frame). 10 | 11 | Parameters 12 | ---------- 13 | G : graph 14 | Nf : int 15 | Number of filters from 0 to lmax (default = 6). 16 | scales : ndarray 17 | Vector of scales to be used (default: log scale). 
18 | 19 | References 20 | ---------- 21 | Use of this kernel for SGWT proposed by Nora Leonardi and Dimitri Van De 22 | Ville in :cite:`leonardi2011wavelet`. 23 | 24 | Examples 25 | -------- 26 | 27 | Filter bank's representation in Fourier and time (ring graph) domains. 28 | 29 | >>> import matplotlib.pyplot as plt 30 | >>> G = graphs.Ring(N=20) 31 | >>> G.estimate_lmax() 32 | >>> G.set_coordinates('line1D') 33 | >>> g = filters.Meyer(G) 34 | >>> s = g.localize(G.N // 2) 35 | >>> fig, axes = plt.subplots(1, 2) 36 | >>> _ = g.plot(ax=axes[0]) 37 | >>> _ = G.plot(s, ax=axes[1]) 38 | 39 | """ 40 | 41 | def __init__(self, G, Nf=6, scales=None): 42 | if scales is None: 43 | scales = (4.0 / (3 * G.lmax)) * np.power(2.0, np.arange(Nf - 2, -1, -1)) 44 | self.scales = scales 45 | 46 | if len(scales) != Nf - 1: 47 | raise ValueError("len(scales) should be Nf-1.") 48 | 49 | kernels = [lambda x: kernel(scales[0] * x, "scaling_function")] 50 | 51 | for i in range(Nf - 1): 52 | kernels.append(lambda x, i=i: kernel(scales[i] * x, "wavelet")) 53 | 54 | def kernel(x, kernel_type): 55 | r""" 56 | Evaluates Meyer function and scaling function 57 | 58 | * meyer wavelet kernel: supported on [2/3,8/3] 59 | * meyer scaling function kernel: supported on [0,4/3] 60 | """ 61 | 62 | x = np.asanyarray(x) 63 | 64 | l1 = 2 / 3.0 65 | l2 = 4 / 3.0 # 2*l1 66 | l3 = 8 / 3.0 # 4*l1 67 | 68 | def v(x): 69 | return x**4 * (35 - 84 * x + 70 * x**2 - 20 * x**3) 70 | 71 | r1ind = x < l1 72 | r2ind = (x >= l1) * (x < l2) 73 | r3ind = (x >= l2) * (x < l3) 74 | 75 | # as we initialize r with zero, computed function will implicitly 76 | # be zero for all x not in one of the three regions defined above 77 | r = np.zeros(x.shape) 78 | if kernel_type == "scaling_function": 79 | r[r1ind] = 1 80 | r[r2ind] = np.cos((np.pi / 2) * v(np.abs(x[r2ind]) / l1 - 1)) 81 | elif kernel_type == "wavelet": 82 | r[r2ind] = np.sin((np.pi / 2) * v(np.abs(x[r2ind]) / l1 - 1)) 83 | r[r3ind] = np.cos((np.pi / 2) * 
v(np.abs(x[r3ind]) / l2 - 1)) 84 | else: 85 | raise ValueError(f"Unknown kernel type {kernel_type}") 86 | 87 | return r 88 | 89 | super().__init__(G, kernels) 90 | -------------------------------------------------------------------------------- /pygsp/graphs/lowstretchtree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class LowStretchTree(Graph): 8 | r"""Low stretch tree. 9 | 10 | Build the root of a low stretch tree on a grid of points. There are 11 | :math:`2k` points on each side of the grid, and therefore :math:`2^{2k}` 12 | vertices in total. The edge weights are all equal to 1. 13 | 14 | Parameters 15 | ---------- 16 | k : int 17 | :math:`2^k` points on each side of the grid of vertices. 18 | 19 | Examples 20 | -------- 21 | >>> import matplotlib.pyplot as plt 22 | >>> G = graphs.LowStretchTree(k=2) 23 | >>> fig, axes = plt.subplots(1, 2) 24 | >>> _ = axes[0].spy(G.W) 25 | >>> _ = G.plot(ax=axes[1]) 26 | 27 | """ 28 | 29 | def __init__(self, k=6, **kwargs): 30 | self.k = k 31 | 32 | XCoords = np.array([1, 2, 1, 2], dtype=int) 33 | YCoords = np.array([1, 1, 2, 2], dtype=int) 34 | 35 | ii = np.array([0, 0, 1, 2, 2, 3], dtype=int) 36 | jj = np.array([1, 2, 1, 3, 0, 2], dtype=int) 37 | 38 | for p in range(1, k): 39 | ii = np.concatenate( 40 | ( 41 | ii, 42 | ii + 4**p, 43 | ii + 2 * 4**p, 44 | ii + 3 * 4**p, 45 | [4**p - 1], 46 | [4**p - 1], 47 | [4**p + (4 ** (p + 1) + 2) // 3 - 1], 48 | [(5 * 4**p + 1) // 3 - 1], 49 | [4**p + (4 ** (p + 1) + 2) // 3 - 1], 50 | [3 * 4**p], 51 | ) 52 | ) 53 | jj = np.concatenate( 54 | ( 55 | jj, 56 | jj + 4**p, 57 | jj + 2 * 4**p, 58 | jj + 3 * 4**p, 59 | [(5 * 4**p + 1) // 3 - 1], 60 | [4**p + (4 ** (p + 1) + 2) // 3 - 1], 61 | [3 * 4**p], 62 | [4**p - 1], 63 | [4**p - 1], 64 | [4**p + (4 ** (p + 1) + 2) // 3 - 1], 65 | ) 66 | ) 67 | 68 | YCoords = 
np.kron(np.ones((2), dtype=int), YCoords) 69 | YCoords = np.concatenate((YCoords, YCoords + 2**p)) 70 | 71 | XCoords = np.concatenate((XCoords, XCoords + 2**p)) 72 | XCoords = np.kron(np.ones((2), dtype=int), XCoords) 73 | 74 | W = sparse.csc_matrix((np.ones_like(ii), (ii, jj))) 75 | coords = np.concatenate( 76 | (XCoords[:, np.newaxis], YCoords[:, np.newaxis]), axis=1 77 | ) 78 | 79 | self.root = 4 ** (k - 1) 80 | 81 | plotting = { 82 | "edges_width": 1.25, 83 | "vertex_size": 75, 84 | "limits": np.array([0, 2**k + 1, 0, 2**k + 1]), 85 | } 86 | 87 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 88 | 89 | def _get_extra_repr(self): 90 | return dict(k=self.k) 91 | -------------------------------------------------------------------------------- /pygsp/filters/simpletight.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .filter import Filter # prevent circular import in Python < 3.5 6 | 7 | 8 | class SimpleTight(Filter): 9 | r"""Design a simple tight frame filter bank (tight frame). 10 | 11 | These filters have been designed to be a simple tight frame wavelet filter 12 | bank. The kernel is similar to Meyer, but simpler. The function is 13 | essentially :math:`\sin^2(x)` in ascending part and :math:`\cos^2` in 14 | descending part. 15 | 16 | Parameters 17 | ---------- 18 | G : graph 19 | Nf : int 20 | Number of filters to cover the interval [0, lmax]. 21 | scales : array_like 22 | Scales to be used. Defaults to log scale. 23 | 24 | Examples 25 | -------- 26 | 27 | Filter bank's representation in Fourier and time (ring graph) domains. 
28 | 29 | >>> import matplotlib.pyplot as plt 30 | >>> G = graphs.Ring(N=20) 31 | >>> G.estimate_lmax() 32 | >>> G.set_coordinates('line1D') 33 | >>> g = filters.SimpleTight(G) 34 | >>> s = g.localize(G.N // 2) 35 | >>> fig, axes = plt.subplots(1, 2) 36 | >>> _ = g.plot(ax=axes[0]) 37 | >>> _ = G.plot(s, ax=axes[1]) 38 | 39 | """ 40 | 41 | def __init__(self, G, Nf=6, scales=None): 42 | def kernel(x, kerneltype): 43 | r""" 44 | Evaluates 'simple' tight-frame kernel. 45 | 46 | * simple tf wavelet kernel: supported on [1/4, 1] 47 | * simple tf scaling function kernel: supported on [0, 1/2] 48 | 49 | Parameters 50 | ---------- 51 | x : ndarray 52 | Array of independent variable values 53 | kerneltype : str 54 | Can be either 'sf' or 'wavelet' 55 | 56 | Returns 57 | ------- 58 | r : ndarray 59 | 60 | """ 61 | 62 | l1 = 0.25 63 | l2 = 0.5 64 | l3 = 1.0 65 | 66 | def h(x): 67 | return np.sin(np.pi * x / 2.0) ** 2 68 | 69 | r1ind = x < l1 70 | r2ind = (x >= l1) * (x < l2) 71 | r3ind = (x >= l2) * (x < l3) 72 | 73 | r = np.zeros(x.shape) 74 | if kerneltype == "sf": 75 | r[r1ind] = 1.0 76 | r[r2ind] = np.sqrt(1 - h(4 * x[r2ind] - 1) ** 2) 77 | elif kerneltype == "wavelet": 78 | r[r2ind] = h(4 * (x[r2ind] - 1 / 4.0)) 79 | r[r3ind] = np.sqrt(1 - h(2 * x[r3ind] - 1) ** 2) 80 | else: 81 | raise TypeError("Unknown kernel type", kerneltype) 82 | 83 | return r 84 | 85 | if not scales: 86 | scales = 1.0 / (2.0 * G.lmax) * np.power(2, np.arange(Nf - 2, -1, -1)) 87 | self.scales = scales 88 | 89 | if len(scales) != Nf - 1: 90 | raise ValueError("len(scales) should be Nf-1.") 91 | 92 | kernels = [lambda x: kernel(scales[0] * x, "sf")] 93 | 94 | for i in range(Nf - 1): 95 | kernels.append(lambda x, i=i: kernel(scales[i] * x, "wavelet")) 96 | 97 | super().__init__(G, kernels) 98 | -------------------------------------------------------------------------------- /pygsp/optimization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The 
:mod:`pygsp.optimization` module provides tools to solve convex 3 | optimization problems on graphs. 4 | """ 5 | 6 | from pygsp import utils 7 | 8 | logger = utils.build_logger(__name__) 9 | 10 | 11 | def _import_pyunlocbox(): 12 | try: 13 | from pyunlocbox import functions, solvers 14 | except Exception as e: 15 | raise ImportError( 16 | "Cannot import pyunlocbox, which is needed to solve " 17 | "this optimization problem. Try to install it with " 18 | "pip (or conda) install pyunlocbox. " 19 | "Original exception: {}".format(e) 20 | ) 21 | return functions, solvers 22 | 23 | 24 | def prox_tv(x, gamma, G, A=None, At=None, nu=1, tol=10e-4, maxit=200, use_matrix=True): 25 | r""" 26 | Total Variation proximal operator for graphs. 27 | 28 | This function computes the TV proximal operator for graphs. The TV norm 29 | is the one norm of the gradient. The gradient is defined in the 30 | function :meth:`pygsp.graphs.Graph.grad`. 31 | This function requires the PyUNLocBoX to be executed. 32 | 33 | This function solves: 34 | 35 | :math:`sol = \min_{z} \frac{1}{2} \|x - z\|_2^2 + \gamma \|x\|_{TV}` 36 | 37 | Parameters 38 | ---------- 39 | x: int 40 | Input signal 41 | gamma: ndarray 42 | Regularization parameter 43 | G: graph object 44 | Graphs structure 45 | A: lambda function 46 | Forward operator, this parameter allows to solve the following problem: 47 | :math:`sol = \min_{z} \frac{1}{2} \|x - z\|_2^2 + \gamma \| A x\|_{TV}` 48 | (default = Id) 49 | At: lambda function 50 | Adjoint operator. (default = Id) 51 | nu: float 52 | Bound on the norm of the operator (default = 1) 53 | tol: float 54 | Stops criterion for the loop. The algorithm will stop if : 55 | :math:`\frac{n(t) - n(t - 1)} {n(t)} < tol` 56 | where :math:`n(t) = f(x) + 0.5 \|x-y\|_2^2` is the objective function at iteration :math:`t` 57 | (default = :math:`10e-4`) 58 | maxit: int 59 | Maximum iteration. (default = 200) 60 | use_matrix: bool 61 | If a matrix should be used. 
(default = True) 62 | 63 | Returns 64 | ------- 65 | sol: solution 66 | 67 | Examples 68 | -------- 69 | 70 | """ 71 | if A is None: 72 | 73 | def A(x): 74 | return x 75 | 76 | if At is None: 77 | 78 | def At(x): 79 | return x 80 | 81 | tight = 0 82 | l1_nu = 2 * G.lmax * nu 83 | 84 | if use_matrix: 85 | 86 | def l1_a(x): 87 | return G.Diff * A(x) 88 | 89 | def l1_at(x): 90 | return G.Diff * At(D.T * x) 91 | 92 | else: 93 | 94 | def l1_a(x): 95 | return G.grad(A(x)) 96 | 97 | def l1_at(x): 98 | return G.div(x) 99 | 100 | functions, _ = _import_pyunlocbox() 101 | functions.norm_l1( 102 | x, gamma, A=l1_a, At=l1_at, tight=tight, maxit=maxit, verbose=verbose, tol=tol 103 | ) 104 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/cube.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .nngraph import NNGraph # prevent circular import in Python < 3.5 4 | 5 | 6 | class Cube(NNGraph): 7 | r"""Hyper-cube (NN-graph). 8 | 9 | Parameters 10 | ---------- 11 | radius : float 12 | Edge lenght (default = 1) 13 | nb_pts : int 14 | Number of vertices (default = 300) 15 | nb_dim : int 16 | Dimension (default = 3) 17 | sampling : string 18 | Variance of the distance kernel (default = 'random') 19 | (Can now only be 'random') 20 | seed : int 21 | Seed for the random number generator (for reproducible graphs). 
22 | 23 | Examples 24 | -------- 25 | >>> import matplotlib.pyplot as plt 26 | >>> G = graphs.Cube(seed=42) 27 | >>> fig = plt.figure() 28 | >>> ax1 = fig.add_subplot(121) 29 | >>> ax2 = fig.add_subplot(122, projection='3d') 30 | >>> _ = ax1.spy(G.W, markersize=0.5) 31 | >>> _ = G.plot(ax=ax2) 32 | 33 | """ 34 | 35 | def __init__( 36 | self, radius=1, nb_pts=300, nb_dim=3, sampling="random", seed=None, **kwargs 37 | ): 38 | self.radius = radius 39 | self.nb_pts = nb_pts 40 | self.nb_dim = nb_dim 41 | self.sampling = sampling 42 | self.seed = seed 43 | rs = np.random.RandomState(seed) 44 | 45 | if self.nb_dim > 3: 46 | raise NotImplementedError("Dimension > 3 not supported yet!") 47 | 48 | if self.sampling == "random": 49 | if self.nb_dim == 2: 50 | pts = rs.rand(self.nb_pts, self.nb_dim) 51 | 52 | elif self.nb_dim == 3: 53 | n = self.nb_pts // 6 54 | 55 | pts = np.zeros((n * 6, 3)) 56 | pts[:n, 1:] = rs.rand(n, 2) 57 | pts[n : 2 * n, :] = np.concatenate( 58 | (np.ones((n, 1)), rs.rand(n, 2)), axis=1 59 | ) 60 | 61 | pts[2 * n : 3 * n, :] = np.concatenate( 62 | (rs.rand(n, 1), np.zeros((n, 1)), rs.rand(n, 1)), axis=1 63 | ) 64 | pts[3 * n : 4 * n, :] = np.concatenate( 65 | (rs.rand(n, 1), np.ones((n, 1)), rs.rand(n, 1)), axis=1 66 | ) 67 | 68 | pts[4 * n : 5 * n, :2] = rs.rand(n, 2) 69 | pts[5 * n : 6 * n, :] = np.concatenate( 70 | (rs.rand(n, 2), np.ones((n, 1))), axis=1 71 | ) 72 | 73 | else: 74 | raise ValueError("Unknown sampling !") 75 | 76 | plotting = { 77 | "vertex_size": 80, 78 | "elevation": 15, 79 | "azimuth": 0, 80 | "distance": 9, 81 | } 82 | 83 | super().__init__( 84 | Xin=pts, k=10, center=False, rescale=False, plotting=plotting, **kwargs 85 | ) 86 | 87 | def _get_extra_repr(self): 88 | attrs = { 89 | "radius": f"{self.radius:.2f}", 90 | "nb_pts": self.nb_pts, 91 | "nb_dim": self.nb_dim, 92 | "sampling": self.sampling, 93 | "seed": self.seed, 94 | } 95 | attrs.update(super()._get_extra_repr()) 96 | return attrs 97 | 
-------------------------------------------------------------------------------- /pygsp/graphs/ring.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class Ring(Graph): 8 | r"""K-regular ring graph. 9 | 10 | A signal on the ring graph is akin to a 1-dimensional periodic signal in 11 | classical signal processing. 12 | 13 | On the ring graph, the graph Fourier transform (GFT) is the classical 14 | discrete Fourier transform (DFT_). 15 | Actually, the Laplacian of the ring graph is a `circulant matrix`_, and any 16 | circulant matrix is diagonalized by the DFT. 17 | 18 | .. _DFT: https://en.wikipedia.org/wiki/Discrete_Fourier_transform 19 | .. _circulant matrix: https://en.wikipedia.org/wiki/Circulant_matrix 20 | 21 | Parameters 22 | ---------- 23 | N : int 24 | Number of vertices. 25 | k : int 26 | Number of neighbors in each direction. 27 | 28 | See Also 29 | -------- 30 | Path : 1D line with even boundary conditions 31 | Torus : Kronecker product of two ring graphs 32 | 33 | Examples 34 | -------- 35 | >>> import matplotlib.pyplot as plt 36 | >>> G = graphs.Ring(N=10) 37 | >>> fig, axes = plt.subplots(1, 2) 38 | >>> _ = axes[0].spy(G.W) 39 | >>> _ = G.plot(ax=axes[1]) 40 | 41 | The GFT of the ring graph is the classical DFT. 42 | 43 | >>> from matplotlib import pyplot as plt 44 | >>> n_eigenvectors = 4 45 | >>> graph = graphs.Ring(30) 46 | >>> fig, axes = plt.subplots(1, 2) 47 | >>> graph.set_coordinates('line1D') 48 | >>> graph.compute_fourier_basis() 49 | >>> _ = graph.plot(graph.U[:, :n_eigenvectors], ax=axes[0]) 50 | >>> _ = axes[0].legend(range(n_eigenvectors)) 51 | >>> _ = axes[1].plot(graph.e, '.') 52 | 53 | """ 54 | 55 | def __init__(self, N=64, k=1, **kwargs): 56 | self.k = k 57 | 58 | if N < 3: 59 | # Asymmetric graph needed for 2 as 2 distances connect them. 
60 | raise ValueError("There should be at least 3 vertices.") 61 | 62 | if 2 * k > N: 63 | raise ValueError("Too many neighbors requested.") 64 | 65 | if 2 * k == N: 66 | num_edges = N * (k - 1) + k 67 | else: 68 | num_edges = N * k 69 | 70 | i_inds = np.zeros(2 * num_edges) 71 | j_inds = np.zeros(2 * num_edges) 72 | 73 | tmpN = np.arange(N, dtype=int) 74 | for i in range(min(k, (N - 1) // 2)): 75 | i_inds[2 * i * N + tmpN] = tmpN 76 | j_inds[2 * i * N + tmpN] = np.remainder(tmpN + i + 1, N) 77 | i_inds[(2 * i + 1) * N + tmpN] = np.remainder(tmpN + i + 1, N) 78 | j_inds[(2 * i + 1) * N + tmpN] = tmpN 79 | 80 | if 2 * k == N: 81 | i_inds[2 * N * (k - 1) + tmpN] = tmpN 82 | i_inds[2 * N * (k - 1) + tmpN] = np.remainder(tmpN + k + 1, N) 83 | 84 | W = sparse.csc_matrix((np.ones(2 * num_edges), (i_inds, j_inds)), shape=(N, N)) 85 | 86 | plotting = {"limits": np.array([-1, 1, -1, 1])} 87 | 88 | super().__init__(W, plotting=plotting, **kwargs) 89 | 90 | self.set_coordinates("ring2D") 91 | 92 | def _get_extra_repr(self): 93 | return dict(k=self.k) 94 | -------------------------------------------------------------------------------- /examples/playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Playing with the PyGSP\n", 8 | "" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "%matplotlib inline\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from pygsp import graphs, filters\n", 22 | "\n", 23 | "plt.rcParams['figure.figsize'] = (17, 5)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## 1 Example\n", 31 | "\n", 32 | "The following demonstrates how to instantiate a graph and a filter, the two main objects of the package." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "G = graphs.Logo()\n", 42 | "G.estimate_lmax()\n", 43 | "g = filters.Heat(G, tau=100)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "Let's now create a graph signal: a set of three Kronecker deltas for that example. We can now look at one step of heat diffusion by filtering the deltas with the above defined filter. Note how the diffusion follows the local structure!" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "DELTAS = [20, 30, 1090]\n", 60 | "s = np.zeros(G.N)\n", 61 | "s[DELTAS] = 1\n", 62 | "s = g.filter(s)\n", 63 | "G.plot(s, highlight=DELTAS, backend='matplotlib')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## 2 Tutorials and examples\n", 71 | "\n", 72 | "Try our [tutorials](https://pygsp.readthedocs.io/en/stable/tutorials/index.html) or [examples](https://pygsp.readthedocs.io/en/stable/examples/index.html)." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Your code here." 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## 3 Playground\n", 89 | "\n", 90 | "Try something of your own!\n", 91 | "The [API reference](https://pygsp.readthedocs.io/en/stable/reference/index.html) is your friend." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Your code here." 
101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "If you miss a package, you can install it with:" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "%pip install numpy" 117 | ] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3", 123 | "language": "python", 124 | "name": "python3" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 4 129 | } 130 | -------------------------------------------------------------------------------- /pygsp/graphs/swissroll.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .graph import Graph # prevent circular import in Python < 3.5 6 | 7 | 8 | class SwissRoll(Graph): 9 | r"""Sampled Swiss roll manifold. 10 | 11 | Parameters 12 | ---------- 13 | N : int 14 | Number of vertices (default = 400) 15 | a : int 16 | (default = 1) 17 | b : int 18 | (default = 4) 19 | dim : int 20 | (default = 3) 21 | thresh : float 22 | (default = 1e-6) 23 | s : float 24 | sigma (default = sqrt(2./N)) 25 | noise : bool 26 | Wether to add noise or not (default = False) 27 | srtype : str 28 | Swiss roll Type, possible arguments are 'uniform' or 'classic' 29 | (default = 'uniform') 30 | seed : int 31 | Seed for the random number generator (for reproducible graphs). 
32 | 33 | Examples 34 | -------- 35 | >>> import matplotlib.pyplot as plt 36 | >>> G = graphs.SwissRoll(N=200, seed=42) 37 | >>> fig = plt.figure() 38 | >>> ax1 = fig.add_subplot(121) 39 | >>> ax2 = fig.add_subplot(122, projection='3d') 40 | >>> _ = ax1.spy(G.W, markersize=1) 41 | >>> _ = G.plot(ax=ax2) 42 | 43 | """ 44 | 45 | def __init__( 46 | self, 47 | N=400, 48 | a=1, 49 | b=4, 50 | dim=3, 51 | thresh=1e-6, 52 | s=None, 53 | noise=False, 54 | srtype="uniform", 55 | seed=None, 56 | **kwargs, 57 | ): 58 | if s is None: 59 | s = np.sqrt(2.0 / N) 60 | 61 | self.a = a 62 | self.b = b 63 | self.dim = dim 64 | self.thresh = thresh 65 | self.s = s 66 | self.noise = noise 67 | self.srtype = srtype 68 | self.seed = seed 69 | 70 | rng = np.random.default_rng(seed) 71 | y1 = rng.uniform(size=N) 72 | y2 = rng.uniform(size=N) 73 | 74 | if srtype == "uniform": 75 | tt = np.sqrt((b * b - a * a) * y1 + a * a) 76 | elif srtype == "classic": 77 | tt = (b - a) * y1 + a 78 | tt *= np.pi 79 | 80 | if dim == 2: 81 | x = np.array((tt * np.cos(tt), tt * np.sin(tt))) 82 | elif dim == 3: 83 | x = np.array((tt * np.cos(tt), 21 * y2, tt * np.sin(tt))) 84 | 85 | if noise: 86 | x += rng.normal(size=x.shape) 87 | 88 | self.x = x 89 | self.dim = dim 90 | 91 | coords = utils.rescale_center(x) 92 | dist = utils.distanz(coords) 93 | W = np.exp(-np.power(dist, 2) / (2.0 * s**2)) 94 | W -= np.diag(np.diag(W)) 95 | W[W < thresh] = 0 96 | 97 | plotting = { 98 | "vertex_size": 60, 99 | "limits": np.array([-1, 1, -1, 1, -1, 1]), 100 | "elevation": 15, 101 | "azimuth": -90, 102 | "distance": 7, 103 | } 104 | 105 | super().__init__(W, coords=coords.T, plotting=plotting, **kwargs) 106 | 107 | def _get_extra_repr(self): 108 | return { 109 | "a": self.a, 110 | "b": self.b, 111 | "dim": self.dim, 112 | "thresh": f"{self.thresh:.0e}", 113 | "s": f"{self.s:.2f}", 114 | "noise": self.noise, 115 | "srtype": self.srtype, 116 | "seed": self.seed, 117 | } 118 | 
-------------------------------------------------------------------------------- /pygsp/filters/gabor.py: -------------------------------------------------------------------------------- 1 | from pygsp import utils 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Gabor(Filter): 7 | r"""Design a filter bank with a kernel centered at each frequency. 8 | 9 | Design a filter bank from translated versions of a mother filter. 10 | The mother filter is translated to each eigenvalue of the Laplacian. 11 | That is equivalent to convolutions with deltas placed at those eigenvalues. 12 | 13 | In classical image processing, a Gabor filter is a sinusoidal wave 14 | multiplied by a Gaussian function (here, the kernel). It analyzes whether 15 | there are any specific frequency content in the image in specific 16 | directions in a localized region around the point of analysis. This 17 | implementation for graph signals allows arbitrary (but isotropic) kernels. 18 | 19 | This filter bank can be used to compute the frequency content of a signal 20 | at each vertex. After filtering, one obtains a vertex-frequency 21 | representation :math:`Sf(i,k)` of a signal :math:`f` as 22 | 23 | .. math:: Sf(i, k) = \langle g_{i,k}, f \rangle, 24 | 25 | where :math:`g_{i,k}` is the mother kernel centered on eigenvalue 26 | :math:`\lambda_k` and localized on vertex :math:`v_i`. 27 | 28 | While :math:`g_{i,k}` should ideally be localized in both the spectral and 29 | vertex domains, that is impossible for some graphs due to the localization 30 | of some eigenvectors. See :attr:`pygsp.graphs.Graph.coherence`. 31 | 32 | Parameters 33 | ---------- 34 | graph : :class:`pygsp.graphs.Graph` 35 | kernel : :class:`pygsp.filters.Filter` 36 | Kernel function to be centered at each graph frequency (eigenvalue of 37 | the graph Laplacian). 38 | 39 | See Also 40 | -------- 41 | Modulation : Another way to translate a filter in the spectral domain. 
42 | 43 | Notes 44 | ----- 45 | The eigenvalues of the graph Laplacian (i.e., the Fourier basis) are needed 46 | to center the kernels. 47 | 48 | Examples 49 | -------- 50 | 51 | Filter bank's representation in Fourier and time (path graph) domains. 52 | 53 | >>> import matplotlib.pyplot as plt 54 | >>> G = graphs.Path(N=7) 55 | >>> G.compute_fourier_basis() 56 | >>> G.set_coordinates('line1D') 57 | >>> 58 | >>> g1 = filters.Expwin(G, band_min=None, band_max=0, slope=3) 59 | >>> g2 = filters.Rectangular(G, band_min=-0.05, band_max=0.05) 60 | >>> g3 = filters.Heat(G, scale=10) 61 | >>> 62 | >>> fig, axes = plt.subplots(3, 2, figsize=(10, 10)) 63 | >>> for g, ax in zip([g1, g2, g3], axes): 64 | ... g = filters.Gabor(G, g) 65 | ... s = g.localize(G.N // 2, method='exact') 66 | ... _ = g.plot(ax=ax[0], sum=False) 67 | ... _ = G.plot(s, ax=ax[1]) 68 | >>> fig.tight_layout() 69 | 70 | """ 71 | 72 | def __init__(self, graph, kernel): 73 | if kernel.n_filters != 1: 74 | raise ValueError( 75 | "A kernel must be one filter. The passed " 76 | "filter bank {} has {}.".format(kernel, kernel.n_filters) 77 | ) 78 | if kernel.G is not graph: 79 | raise ValueError( 80 | "The graph passed to this filter bank must " 81 | "be the one used to build the mother kernel." 
82 | ) 83 | 84 | kernels = [] 85 | for i in range(graph.n_vertices): 86 | kernels.append(lambda x, i=i: kernel.evaluate(x - graph.e[i])) 87 | 88 | super().__init__(graph, kernels) 89 | 90 | def filter(self, s, method="exact", order=None): 91 | """TODO: indirection will be removed when poly filtering is merged.""" 92 | return super().filter(s, method="exact") 93 | -------------------------------------------------------------------------------- /pygsp/filters/abspline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import optimize 3 | 4 | from pygsp import utils 5 | 6 | from .filter import Filter # prevent circular import in Python < 3.5 7 | 8 | 9 | class Abspline(Filter): 10 | r"""Design an A B cubic spline wavelet filter bank. 11 | 12 | Parameters 13 | ---------- 14 | G : graph 15 | Nf : int 16 | Number of filters from 0 to lmax (default = 6) 17 | lpfactor : float 18 | Low-pass factor lmin=lmax/lpfactor will be used to determine scales, 19 | the scaling function will be created to fill the lowpass gap. 20 | (default = 20) 21 | scales : ndarray 22 | Vector of scales to be used. 23 | By default, initialized with :func:`pygsp.utils.compute_log_scales`. 24 | 25 | Examples 26 | -------- 27 | 28 | Filter bank's representation in Fourier and time (ring graph) domains. 
29 | 30 | >>> import matplotlib.pyplot as plt 31 | >>> G = graphs.Ring(N=20) 32 | >>> G.estimate_lmax() 33 | >>> G.set_coordinates('line1D') 34 | >>> g = filters.Abspline(G) 35 | >>> s = g.localize(G.N // 2) 36 | >>> fig, axes = plt.subplots(1, 2) 37 | >>> _ = g.plot(ax=axes[0]) 38 | >>> _ = G.plot(s, ax=axes[1]) 39 | 40 | """ 41 | 42 | def __init__(self, G, Nf=6, lpfactor=20, scales=None): 43 | def kernel_abspline3(x, alpha, beta, t1, t2): 44 | M = np.array( 45 | [ 46 | [1, t1, t1**2, t1**3], 47 | [1, t2, t2**2, t2**3], 48 | [0, 1, 2 * t1, 3 * t1**2], 49 | [0, 1, 2 * t2, 3 * t2**2], 50 | ] 51 | ) 52 | v = np.array( 53 | [ 54 | 1, 55 | 1, 56 | t1 ** (-alpha) * alpha * t1 ** (alpha - 1), 57 | -beta * t2 ** (-beta - 1) * t2**beta, 58 | ] 59 | ) 60 | a = np.linalg.solve(M, v) 61 | 62 | r1 = x <= t1 63 | r2 = (x >= t1) * (x < t2) 64 | r3 = x >= t2 65 | 66 | if isinstance(x, np.float64): 67 | if r1: 68 | r = x[r1] ** alpha * t1 ** (-alpha) 69 | if r2: 70 | r = a[0] + a[1] * x + a[2] * x**2 + a[3] * x**3 71 | if r3: 72 | r = x[r3] ** (-beta) * t2**beta 73 | 74 | else: 75 | r = np.zeros(x.shape) 76 | 77 | x2 = x[r2] 78 | 79 | r[r1] = x[r1] ** alpha * t1 ** (-alpha) 80 | r[r2] = a[0] + a[1] * x2 + a[2] * x2**2 + a[3] * x2**3 81 | r[r3] = x[r3] ** (-beta) * t2**beta 82 | 83 | return r 84 | 85 | self.lpfactor = lpfactor 86 | 87 | lmin = G.lmax / lpfactor 88 | 89 | if scales is None: 90 | scales = utils.compute_log_scales(lmin, G.lmax, Nf - 1) 91 | self.scales = scales 92 | 93 | gb = lambda x: kernel_abspline3(x, 2, 2, 1, 2) 94 | gl = lambda x: np.exp(-np.power(x, 4)) 95 | 96 | lminfac = 0.4 * lmin 97 | 98 | g = [lambda x: 1.2 * np.exp(-1) * gl(x / lminfac)] 99 | for i in range(0, Nf - 1): 100 | g.append(lambda x, i=i: gb(self.scales[i] * x)) 101 | 102 | f = lambda x: -gb(x) 103 | xstar = optimize.minimize_scalar(f, bounds=(1, 2), method="bounded") 104 | gamma_l = -f(xstar.x) 105 | lminfac = 0.6 * lmin 106 | g[0] = lambda x: gamma_l * gl(x / lminfac) 107 | 108 | 
super().__init__(G, g) 109 | 110 | def _get_extra_repr(self): 111 | return dict(lpfactor=f"{self.lpfactor:.2f}") 112 | -------------------------------------------------------------------------------- /pygsp/graphs/torus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from .graph import Graph # prevent circular import in Python < 3.5 5 | 6 | 7 | class Torus(Graph): 8 | r"""Sampled torus manifold. 9 | 10 | On the torus, the graph Fourier transform (GFT) is the Kronecker product 11 | between the GFT of two :class:`~pygsp.graphs.Ring` graphs. 12 | 13 | Parameters 14 | ---------- 15 | Nv : int 16 | Number of vertices along the first dimension. 17 | Mv : int 18 | Number of vertices along the second dimension. Default is ``Nv``. 19 | 20 | See Also 21 | -------- 22 | Ring : 1D line with periodic boundary conditions 23 | Grid2d : Kronecker product of two path graphs 24 | 25 | Examples 26 | -------- 27 | >>> import matplotlib.pyplot as plt 28 | >>> G = graphs.Torus(10) 29 | >>> fig = plt.figure() 30 | >>> ax1 = fig.add_subplot(121) 31 | >>> ax2 = fig.add_subplot(122, projection='3d') 32 | >>> _ = ax1.spy(G.W, markersize=1.5) 33 | >>> _ = G.plot(ax=ax2) 34 | >>> _ = ax2.set_zlim(-1.5, 1.5) 35 | 36 | """ 37 | 38 | def __init__(self, Nv=16, Mv=None, **kwargs): 39 | if Mv is None: 40 | Mv = Nv 41 | 42 | self.Nv = Nv 43 | self.Mv = Mv 44 | 45 | # Create weighted adjancency matrix 46 | K = 2 * Nv 47 | J = 2 * Mv 48 | i_inds = np.zeros((K * Mv + J * Nv), dtype=float) 49 | j_inds = np.zeros((K * Mv + J * Nv), dtype=float) 50 | 51 | tmpK = np.arange(K, dtype=int) 52 | tmpNv1 = np.arange(Nv - 1) 53 | tmpNv = np.arange(Nv) 54 | 55 | for i in range(Mv): 56 | i_inds[i * K + tmpK] = i * Nv + np.concatenate( 57 | (np.array([Nv - 1]), tmpNv1, tmpNv) 58 | ) 59 | 60 | j_inds[i * K + tmpK] = i * Nv + np.concatenate( 61 | (tmpNv, np.array([Nv - 1]), tmpNv1) 62 | ) 63 | 64 | tmp2Nv = np.arange(2 * 
Nv, dtype=int) 65 | 66 | for i in range(Mv - 1): 67 | i_inds[K * Mv + i * 2 * Nv + tmp2Nv] = np.concatenate( 68 | (i * Nv + tmpNv, (i + 1) * Nv + tmpNv) 69 | ) 70 | 71 | j_inds[K * Mv + i * 2 * Nv + tmp2Nv] = np.concatenate( 72 | ((i + 1) * Nv + tmpNv, i * Nv + tmpNv) 73 | ) 74 | 75 | i_inds[K * Mv + (Mv - 1) * 2 * Nv + tmp2Nv] = np.concatenate( 76 | (tmpNv, (Mv - 1) * Nv + tmpNv) 77 | ) 78 | 79 | j_inds[K * Mv + (Mv - 1) * 2 * Nv + tmp2Nv] = np.concatenate( 80 | ((Mv - 1) * Nv + tmpNv, tmpNv) 81 | ) 82 | 83 | W = sparse.csc_matrix( 84 | (np.ones(K * Mv + J * Nv), (i_inds, j_inds)), shape=(Mv * Nv, Mv * Nv) 85 | ) 86 | 87 | # Create coordinate 88 | T = 1.5 + np.sin(np.arange(Mv) * 2 * np.pi / Mv).reshape(1, Mv) 89 | U = np.cos(np.arange(Mv) * 2 * np.pi / Mv).reshape(1, Mv) 90 | xtmp = np.cos(np.arange(Nv).reshape(Nv, 1) * 2 * np.pi / Nv) * T 91 | ytmp = np.sin(np.arange(Nv).reshape(Nv, 1) * 2 * np.pi / Nv) * T 92 | ztmp = np.kron(np.ones((Nv, 1)), U) 93 | coords = np.concatenate( 94 | ( 95 | np.reshape(xtmp, (Mv * Nv, 1), order="F"), 96 | np.reshape(ytmp, (Mv * Nv, 1), order="F"), 97 | np.reshape(ztmp, (Mv * Nv, 1), order="F"), 98 | ), 99 | axis=1, 100 | ) 101 | 102 | plotting = { 103 | "vertex_size": 60, 104 | "limits": np.array([-2.5, 2.5, -2.5, 2.5, -2.5, 2.5]), 105 | } 106 | 107 | super().__init__(W, coords=coords, plotting=plotting, **kwargs) 108 | 109 | def _get_extra_repr(self): 110 | return dict(Nv=self.Nv, Mv=self.Mv) 111 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/imgpatches.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .nngraph import NNGraph # prevent circular import in Python < 3.5 4 | 5 | 6 | class ImgPatches(NNGraph): 7 | r"""NN-graph between patches of an image. 
8 | 9 | Extract a feature vector in the form of a patch for every pixel of an 10 | image, then construct a nearest-neighbor graph between these feature 11 | vectors. The feature matrix, i.e. the patches, can be found in :attr:`Xin`. 12 | 13 | Parameters 14 | ---------- 15 | img : array 16 | Input image. 17 | patch_shape : tuple, optional 18 | Dimensions of the patch window. Syntax: (height, width), or (height,), 19 | in which case width = height. 20 | kwargs : dict 21 | Parameters passed to :class:`NNGraph`. 22 | 23 | See Also 24 | -------- 25 | Grid2dImgPatches 26 | 27 | Notes 28 | ----- 29 | The feature vector of a pixel `i` will consist of the stacking of the 30 | intensity values of all pixels in the patch centered at `i`, for all color 31 | channels. So, if the input image has `d` color channels, the dimension of 32 | the feature vector of each pixel is (patch_shape[0] * patch_shape[1] * d). 33 | 34 | Examples 35 | -------- 36 | >>> import matplotlib.pyplot as plt 37 | >>> from skimage import data, img_as_float 38 | >>> img = img_as_float(data.camera()[::64, ::64]) 39 | >>> G = graphs.ImgPatches(img, patch_shape=(3, 3)) 40 | >>> print('{} nodes ({} x {} pixels)'.format(G.Xin.shape[0], *img.shape)) 41 | 64 nodes (8 x 8 pixels) 42 | >>> print('{} features per node'.format(G.Xin.shape[1])) 43 | 9 features per node 44 | >>> G.set_coordinates(kind='spring', seed=42) 45 | >>> fig, axes = plt.subplots(1, 2) 46 | >>> _ = axes[0].spy(G.W, markersize=2) 47 | >>> _ = G.plot(ax=axes[1]) 48 | 49 | """ 50 | 51 | def __init__(self, img, patch_shape=(3, 3), **kwargs): 52 | self.img = img 53 | self.patch_shape = patch_shape 54 | 55 | try: 56 | h, w, d = img.shape 57 | except ValueError: 58 | try: 59 | h, w = img.shape 60 | d = 0 61 | except ValueError: 62 | print("Image should be at least a 2D array.") 63 | 64 | try: 65 | r, c = patch_shape 66 | except ValueError: 67 | r = patch_shape[0] 68 | c = r 69 | 70 | pad_width = [ 71 | (int((r - 0.5) / 2.0), int((r + 0.5) / 2.0)), 72 | 
(int((c - 0.5) / 2.0), int((c + 0.5) / 2.0)), 73 | ] 74 | 75 | if d == 0: 76 | window_shape = (r, c) 77 | d = 1 # For the reshape in the return call 78 | else: 79 | pad_width += [(0, 0)] 80 | window_shape = (r, c, d) 81 | 82 | # Pad the image. 83 | img = np.pad(img, pad_width=pad_width, mode="symmetric") 84 | 85 | # Extract patches as node features. 86 | # Alternative: sklearn.feature_extraction.image.extract_patches_2d. 87 | # sklearn has much less dependencies than skimage. 88 | try: 89 | import skimage 90 | except Exception as e: 91 | raise ImportError( 92 | "Cannot import skimage, which is needed to " 93 | "extract patches. Try to install it with " 94 | "pip (or conda) install scikit-image. " 95 | "Original exception: {}".format(e) 96 | ) 97 | patches = skimage.util.view_as_windows(img, window_shape=window_shape) 98 | patches = patches.reshape((h * w, r * c * d)) 99 | 100 | super().__init__(patches, **kwargs) 101 | 102 | def _get_extra_repr(self): 103 | attrs = dict(patch_shape=self.patch_shape) 104 | attrs.update(super()._get_extra_repr()) 105 | return attrs 106 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "PyGSP" 7 | version = "0.6.1" 8 | description = "Graph Signal Processing in Python" 9 | readme = "README.rst" 10 | requires-python = ">=3.8" 11 | license = {text = "BSD"} 12 | keywords = ["graph", "signal", "processing"] 13 | authors = [ 14 | {name = "EPFL LTS2"} 15 | ] 16 | maintainers = [ 17 | {name = "EPFL LTS2"} 18 | ] 19 | classifiers = [ 20 | "Development Status :: 4 - Beta", 21 | "Topic :: Scientific/Engineering", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Education", 24 | "Intended Audience :: Science/Research", 25 | "License :: OSI Approved :: BSD License", 
26 | "Operating System :: OS Independent", 27 | "Programming Language :: Python", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Programming Language :: Python :: 3.11", 34 | "Programming Language :: Python :: 3.12", 35 | "Programming Language :: Python :: 3.13", 36 | ] 37 | 38 | dependencies = [ 39 | "numpy", 40 | "scipy", 41 | ] 42 | 43 | [project.optional-dependencies] 44 | dev = [ 45 | # Import and export 46 | "networkx", 47 | # Construct patch graphs from images 48 | "scikit-image", 49 | # Fallback for nearest neighbors when pyflann is not available 50 | "scikit-learn", 51 | # Approximate nearest neighbors for kNN graphs 52 | "pyflann3", 53 | # Convex optimization 54 | "pyunlocbox>=0.6.1", 55 | # Plot graphs, signals, and filters 56 | "matplotlib", 57 | # Interactive graph visualization 58 | "pyqtgraph", 59 | "PyOpenGL", 60 | "PyQt6", 61 | # Run the tests 62 | "flake8", 63 | "coverage", 64 | "coveralls", 65 | "pytest", 66 | "pytest-cov", 67 | # Build the documentation 68 | "sphinx", 69 | "numpydoc", 70 | "sphinxcontrib-bibtex", 71 | "sphinx-gallery", 72 | "memory_profiler", 73 | "sphinx-rtd-theme", 74 | "sphinx-copybutton", 75 | # Build and upload packages 76 | "build", 77 | "wheel", 78 | "twine", 79 | "pre-commit", 80 | ] 81 | 82 | test = [ 83 | "pytest", 84 | "pytest-cov", 85 | "flake8", 86 | "coverage", 87 | ] 88 | 89 | docs = [ 90 | "sphinx", 91 | "numpydoc", 92 | "sphinxcontrib-bibtex", 93 | "sphinx-gallery", 94 | "memory_profiler", 95 | "sphinx-rtd-theme", 96 | "sphinx-copybutton", 97 | ] 98 | 99 | plot = [ 100 | "matplotlib", 101 | "pyqtgraph", 102 | "PyOpenGL", 103 | "PyQt6", 104 | ] 105 | 106 | [project.urls] 107 | Homepage = "https://github.com/epfl-lts2/pygsp" 108 | Documentation = "https://pygsp.readthedocs.io" 109 | Download = "https://pypi.org/project/PyGSP" 110
| "Source Code" = "https://github.com/epfl-lts2/pygsp" 111 | "Bug Tracker" = "https://github.com/epfl-lts2/pygsp/issues" 112 | "Try It Online" = "https://mybinder.org/v2/gh/epfl-lts2/pygsp/master?urlpath=lab/tree/examples/playground.ipynb" 113 | 114 | [tool.setuptools.packages.find] 115 | include = ["pygsp*"] 116 | 117 | [tool.setuptools.package-data] 118 | pygsp = ["data/pointclouds/*.mat"] 119 | 120 | [tool.pytest.ini_options] 121 | testpaths = ["pygsp/tests"] 122 | python_files = ["test_*.py"] 123 | python_classes = ["Test*"] 124 | python_functions = ["test_*"] 125 | 126 | [tool.coverage.run] 127 | source = ["pygsp"] 128 | omit = ["*/tests/*"] 129 | 130 | [tool.coverage.report] 131 | exclude_lines = [ 132 | "pragma: no cover", 133 | "def __repr__", 134 | "if self.debug:", 135 | "if settings.DEBUG", 136 | "raise AssertionError", 137 | "raise NotImplementedError", 138 | "if 0:", 139 | "if __name__ == .__main__.:", 140 | ] 141 | 142 | [tool.flake8] 143 | max-line-length = 88 144 | extend-ignore = ["E203", "W503"] 145 | 146 | [dependency-groups] 147 | dev = [ 148 | "pre-commit>=3.5.0", 149 | ] 150 | -------------------------------------------------------------------------------- /pygsp/filters/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`pygsp.filters` module implements methods used for filtering and 3 | defines commonly used filters that can be applied to :mod:`pygsp.graphs`. A 4 | filter is associated to a graph and is defined with one or several functions. 5 | We define by filter bank a list of filters, usually centered around different 6 | frequencies, applied to a single graph. 7 | 8 | Interface 9 | --------- 10 | 11 | The :class:`Filter` base class implements a common interface to all filters: 12 | 13 | .. 
autosummary:: 14 | 15 | Filter.evaluate 16 | Filter.filter 17 | Filter.analyze 18 | Filter.synthesize 19 | Filter.complement 20 | Filter.inverse 21 | Filter.compute_frame 22 | Filter.estimate_frame_bounds 23 | Filter.plot 24 | Filter.localize 25 | 26 | Filters 27 | ------- 28 | 29 | Then, derived classes implement various common graph filters. 30 | 31 | **Filters that solve differential equations** 32 | 33 | The following filters solve partial differential equations (PDEs) on graphs, 34 | which model processes such as heat diffusion or wave propagation. 35 | 36 | .. autosummary:: 37 | 38 | Heat 39 | Wave 40 | 41 | **Low-pass filters** 42 | 43 | .. autosummary:: 44 | 45 | Heat 46 | 47 | **Band-pass filters** 48 | 49 | These filters can be configured to be low-pass, high-pass, or band-pass. 50 | 51 | .. autosummary:: 52 | 53 | Expwin 54 | Rectangular 55 | 56 | **Filter banks of two filters: a low-pass and a high-pass** 57 | 58 | .. autosummary:: 59 | 60 | Regular 61 | Held 62 | Simoncelli 63 | Papadakis 64 | 65 | **Filter banks composed of dilated or translated filters** 66 | 67 | .. autosummary:: 68 | 69 | Abspline 70 | HalfCosine 71 | Itersine 72 | MexicanHat 73 | Meyer 74 | SimpleTight 75 | 76 | **Filter banks for vertex-frequency analyses** 77 | 78 | These filter banks are composed of shifted versions of a mother filter, one per 79 | graph frequency (Laplacian eigenvalue). They can analyze frequency content 80 | locally, as a windowed graph Fourier transform. 81 | 82 | .. autosummary:: 83 | 84 | Gabor 85 | Modulation 86 | 87 | Approximations 88 | -------------- 89 | 90 | Moreover, two approximation methods are provided for fast filtering. The 91 | computational complexity of filtering with those approximations is linear with 92 |
The complexity of the exact solution, which is to use the 93 | Fourier basis, is quadratic with the number of nodes (without taking into 94 | account the cost of the necessary eigendecomposition of the graph Laplacian). 95 | 96 | **Chebyshev polynomials** 97 | 98 | .. autosummary:: 99 | 100 | compute_cheby_coeff 101 | compute_jackson_cheby_coeff 102 | cheby_op 103 | cheby_rect 104 | 105 | **Lanczos algorithm** 106 | 107 | .. autosummary:: 108 | 109 | lanczos 110 | lanczos_op 111 | 112 | """ 113 | 114 | from .abspline import Abspline # noqa: F401 115 | from .approximations import cheby_op # noqa: F401 116 | from .approximations import cheby_rect # noqa: F401 117 | from .approximations import compute_cheby_coeff # noqa: F401 118 | from .approximations import compute_jackson_cheby_coeff # noqa: F401 119 | from .approximations import lanczos # noqa: F401 120 | from .approximations import lanczos_op # noqa: F401 121 | from .expwin import Expwin # noqa: F401 122 | from .filter import Filter # noqa: F401 123 | from .gabor import Gabor # noqa: F401 124 | from .halfcosine import HalfCosine # noqa: F401 125 | from .heat import Heat # noqa: F401 126 | from .held import Held # noqa: F401 127 | from .itersine import Itersine # noqa: F401 128 | from .mexicanhat import MexicanHat # noqa: F401 129 | from .meyer import Meyer # noqa: F401 130 | from .modulation import Modulation # noqa: F401 131 | from .papadakis import Papadakis # noqa: F401 132 | from .rectangular import Rectangular # noqa: F401 133 | from .regular import Regular # noqa: F401 134 | from .simoncelli import Simoncelli # noqa: F401 135 | from .simpletight import SimpleTight # noqa: F401 136 | from .wave import Wave # noqa: F401 137 | -------------------------------------------------------------------------------- /pygsp/graphs/nngraphs/twomoons.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pygsp import utils 4 | 5 | from .nngraph import 
NNGraph # prevent circular import in Python < 3.5 6 | 7 | 8 | class TwoMoons(NNGraph): 9 | r"""Two Moons (NN-graph). 10 | 11 | Parameters 12 | ---------- 13 | moontype : 'standard' or 'synthesized' 14 | Whether to load the stored two_moons point cloud or to sample a 15 | synthesized one (default is 'standard'). 16 | 'standard' : Load the stored point cloud (1000 points per moon). 17 | 'synthesized' : Sample N noisy points on two interleaved half-circles. 18 | sigmag : float 19 | Variance of the distance kernel, passed as ``sigma`` to NNGraph (default = 0.05) 20 | dim : int 21 | The dimensionality of the points (default = 2). 22 | Only valid for moontype == 'standard'. 23 | N : int 24 | Number of vertices (default = 400) 25 | Only valid for moontype == 'synthesized'. 26 | sigmad : float 27 | Scale of the random noise added to each point (do not set it too high or the moons will overlap) 28 | (default = 0.07) 29 | Only valid for moontype == 'synthesized'. 30 | distance : float 31 | Distance between the two moons (default = 0.5) 32 | Only valid for moontype == 'synthesized'. 33 | seed : int 34 | Seed for the random number generator (for reproducible graphs).
35 | 36 | Examples 37 | -------- 38 | >>> import matplotlib.pyplot as plt 39 | >>> G = graphs.TwoMoons() 40 | >>> fig, axes = plt.subplots(1, 2) 41 | >>> _ = axes[0].spy(G.W, markersize=0.5) 42 | >>> _ = G.plot(edges=True, ax=axes[1]) 43 | 44 | """ 45 | 46 | def _create_arc_moon(self, N, sigmad, distance, number, seed): 47 | rng = np.random.default_rng(seed) 48 | phi = rng.uniform(size=(N, 1)) * np.pi 49 | r = 1 50 | rb = sigmad * rng.normal(size=(N, 1)) 51 | ab = rng.uniform(size=(N, 1)) * 2 * np.pi 52 | b = rb * np.exp(1j * ab) 53 | bx = np.real(b) 54 | by = np.imag(b) 55 | 56 | if number == 1: 57 | moonx = np.cos(phi) * r + bx + 0.5 58 | moony = -np.sin(phi) * r + by - (distance - 1) / 2.0 59 | elif number == 2: 60 | moonx = np.cos(phi) * r + bx - 0.5 61 | moony = np.sin(phi) * r + by + (distance - 1) / 2.0 62 | 63 | return np.concatenate((moonx, moony), axis=1) 64 | 65 | def __init__( 66 | self, 67 | moontype="standard", 68 | dim=2, 69 | sigmag=0.05, 70 | N=400, 71 | sigmad=0.07, 72 | distance=0.5, 73 | seed=None, 74 | **kwargs, 75 | ): 76 | self.moontype = moontype 77 | self.dim = dim 78 | self.sigmag = sigmag 79 | self.sigmad = sigmad 80 | self.distance = distance 81 | self.seed = seed 82 | 83 | if moontype == "standard": 84 | N1, N2 = 1000, 1000 85 | data = utils.loadmat("pointclouds/two_moons") 86 | Xin = data["features"][:dim].T 87 | 88 | elif moontype == "synthesized": 89 | N1 = N // 2 90 | N2 = N - N1 91 | 92 | coords1 = self._create_arc_moon(N1, sigmad, distance, 1, seed) 93 | coords2 = self._create_arc_moon(N2, sigmad, distance, 2, seed) 94 | 95 | Xin = np.concatenate((coords1, coords2)) 96 | 97 | else: 98 | raise ValueError(f"Unknown moontype {moontype}") 99 | 100 | self.labels = np.concatenate((np.zeros(N1), np.ones(N2))) 101 | 102 | plotting = { 103 | "vertex_size": 30, 104 | } 105 | 106 | super().__init__( 107 | Xin=Xin, 108 | sigma=sigmag, 109 | k=5, 110 | center=False, 111 | rescale=False, 112 | plotting=plotting, 113 | **kwargs, 114 | ) 115 | 
116 | def _get_extra_repr(self): 117 | attrs = { 118 | "moontype": self.moontype, 119 | "dim": self.dim, 120 | "sigmag": f"{self.sigmag:.2f}", 121 | "sigmad": f"{self.sigmad:.2f}", 122 | "distance": f"{self.distance:.2f}", 123 | "seed": self.seed, 124 | } 125 | attrs.update(super()._get_extra_repr()) 126 | return attrs 127 | -------------------------------------------------------------------------------- /pygsp/filters/heat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .filter import Filter # prevent circular import in Python < 3.5 4 | 5 | 6 | class Heat(Filter): 7 | r"""Design a filter bank of heat kernels. 8 | 9 | The (low-pass) heat kernel is defined in the spectral domain as 10 | 11 | .. math:: g_\tau(\lambda) = \exp(-\tau \lambda), 12 | 13 | where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the 14 | graph Laplacian, and :math:`\tau` is a parameter that captures both time 15 | and thermal diffusivity. 16 | 17 | The heat kernel is the fundamental solution to the heat equation 18 | 19 | .. math:: - \tau L f(t) = \partial_t f(t), 20 | 21 | where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` is the heat 22 | distribution over the graph at time :math:`t`. Given the initial condition 23 | :math:`f(0)`, the solution of the heat equation is expressed as 24 | 25 | .. math:: f(t) = e^{-\tau t L} f(0) 26 | = U e^{-\tau t \Lambda} U^\top f(0) 27 | = g_{\tau t}(L) f(0). 28 | 29 | The above is, by definition, the convolution of the signal :math:`f(0)` 30 | with the kernel :math:`g_{\tau t}(\lambda) = \exp(-\tau t \lambda)`. 31 | Hence, applying this filter to a signal simulates heat diffusion. 32 | 33 | Since the kernel is applied to the graph eigenvalues :math:`\lambda`, which 34 | can be interpreted as squared frequencies, it can also be considered as a 35 | generalization of the Gaussian kernel on graphs. 
36 | 37 | Parameters 38 | ---------- 39 | G : graph 40 | scale : float or iterable 41 | Scaling parameter. When solving heat diffusion, it encompasses both 42 | time and thermal diffusivity. 43 | If iterable, creates a filter bank with one filter per value. 44 | normalize : bool 45 | Whether to normalize the kernel to have unit L2 norm. 46 | The normalization needs the eigenvalues of the graph Laplacian. 47 | 48 | Examples 49 | -------- 50 | 51 | Filter bank's representation in Fourier and time (ring graph) domains. 52 | 53 | >>> import matplotlib.pyplot as plt 54 | >>> G = graphs.Ring(N=20) 55 | >>> G.estimate_lmax() 56 | >>> G.set_coordinates('line1D') 57 | >>> g = filters.Heat(G, scale=[5, 10, 100]) 58 | >>> s = g.localize(G.N // 2) 59 | >>> fig, axes = plt.subplots(1, 2) 60 | >>> _ = g.plot(ax=axes[0]) 61 | >>> _ = G.plot(s, ax=axes[1]) 62 | 63 | Heat diffusion from two sources on a grid. 64 | 65 | >>> import matplotlib.pyplot as plt 66 | >>> n_side = 11 67 | >>> graph = graphs.Grid2d(n_side) 68 | >>> graph.estimate_lmax() 69 | >>> sources = [ 70 | ... (n_side//4 * n_side) + (n_side//4), 71 | ... (n_side*3//4 * n_side) + (n_side*3//4), 72 | ... ] 73 | >>> delta = np.zeros(graph.n_vertices) 74 | >>> delta[sources] = 5 75 | >>> steps = np.array([1, 5]) 76 | >>> diffusivity = 10 77 | >>> g = filters.Heat(graph, scale=diffusivity*steps) 78 | >>> diffused = g.filter(delta) 79 | >>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4)) 80 | >>> _ = fig.suptitle('Heat diffusion', fontsize=16) 81 | >>> for i, ax in enumerate(axes): 82 | ... _ = graph.plot(diffused[:, i], highlight=sources, 83 | ... title='step {}'.format(steps[i]), ax=ax) 84 | ... ax.set_aspect('equal', 'box') 85 | ... ax.set_axis_off() 86 | 87 | Normalized heat kernel. 
88 | 89 | >>> G = graphs.Logo() 90 | >>> G.compute_fourier_basis() 91 | >>> g = filters.Heat(G, scale=5) 92 | >>> y = g.evaluate(G.e) 93 | >>> print('norm: {:.2f}'.format(np.linalg.norm(y[0]))) 94 | norm: 9.76 95 | >>> g = filters.Heat(G, scale=5, normalize=True) 96 | >>> y = g.evaluate(G.e) 97 | >>> print('norm: {:.2f}'.format(np.linalg.norm(y[0]))) 98 | norm: 1.00 99 | 100 | """ 101 | 102 | def __init__(self, G, scale=10, normalize=False): 103 | try: 104 | iter(scale) 105 | except TypeError: 106 | scale = [scale] 107 | 108 | self.scale = scale 109 | self.normalize = normalize 110 | 111 | def kernel(x, scale): 112 | return np.minimum(np.exp(-scale * x / G.lmax), 1) 113 | 114 | kernels = [] 115 | for s in scale: 116 | norm = np.linalg.norm(kernel(G.e, s)) if normalize else 1 117 | kernels.append(lambda x, s=s, norm=norm: kernel(x, s) / norm) 118 | 119 | super().__init__(G, kernels) 120 | 121 | def _get_extra_repr(self): 122 | scale = "[" + ", ".join(f"{s:.2f}" for s in self.scale) + "]" 123 | return dict(scale=scale, normalize=self.normalize) 124 | -------------------------------------------------------------------------------- /pygsp/graphs/randomregular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class RandomRegular(Graph): 10 | r"""Random k-regular graph. 11 | 12 | The random regular graph has the property that every node is connected to 13 | k other nodes. That graph is simple (without loops or double edges), 14 | k-regular (each vertex is adjacent to k nodes), and undirected. 
15 | 16 | Parameters 17 | ---------- 18 | N : int 19 | Number of nodes (default is 64) 20 | k : int 21 | Number of connections, or degree, of each node (default is 6) 22 | max_iter : int 23 | Maximum number of iterations (default is 10) 24 | seed : int 25 | Seed for the random number generator (for reproducible graphs). 26 | 27 | Notes 28 | ----- 29 | The *pairing model* algorithm works as follows. First create n*d *half 30 | edges*. Then repeat as long as possible: pick a pair of half edges and if 31 | it's legal (doesn't create a loop nor a double edge) add it to the graph. 32 | 33 | References 34 | ---------- 35 | See :cite:`kim2003randomregulargraphs`. 36 | This code has been adapted from matlab to python. 37 | 38 | Examples 39 | -------- 40 | >>> import matplotlib.pyplot as plt 41 | >>> G = graphs.RandomRegular(N=64, k=5, seed=42) 42 | >>> G.set_coordinates(kind='spring', seed=42) 43 | >>> fig, axes = plt.subplots(1, 2) 44 | >>> _ = axes[0].spy(G.W, markersize=2) 45 | >>> _ = G.plot(ax=axes[1]) 46 | 47 | """ 48 | 49 | def __init__(self, N=64, k=6, max_iter=10, seed=None, **kwargs): 50 | self.k = k 51 | self.max_iter = max_iter 52 | self.seed = seed 53 | 54 | self.logger = utils.build_logger(__name__) 55 | 56 | rng = np.random.default_rng(seed) 57 | 58 | # continue until a proper graph is formed 59 | if (N * k) % 2 == 1: 60 | raise ValueError("input error: N*d must be even!") 61 | 62 | # a list of open half-edges 63 | U = np.kron(np.ones(k), np.arange(N)).astype(int) 64 | 65 | # the graphs adjacency matrix 66 | A = sparse.lil_matrix(np.zeros((N, N))) 67 | 68 | edgesTested = 0 69 | repetition = 1 70 | 71 | while np.size(U) and repetition < max_iter: 72 | edgesTested += 1 73 | 74 | if edgesTested % 5000 == 0: 75 | self.logger.debug( 76 | "createRandRegGraph() progress: edges= " 77 | "{}/{}.".format(edgesTested, N * k / 2) 78 | ) 79 | 80 | # chose at random 2 half edges 81 | i1 = rng.integers(0, U.shape[0]) 82 | i2 = rng.integers(0, U.shape[0]) 83 | v1 = U[i1] 84 
| v2 = U[i2] 85 | 86 | # check that there are no loops nor parallel edges 87 | if v1 == v2 or A[v1, v2] == 1: 88 | # restart process if needed 89 | if edgesTested == N * k: 90 | repetition = repetition + 1 91 | edgesTested = 0 92 | U = np.kron(np.ones(k), np.arange(N)) 93 | A = sparse.lil_matrix(np.zeros((N, N))) 94 | else: 95 | # add edge to graph 96 | A[v1, v2] = 1 97 | A[v2, v1] = 1 98 | 99 | # remove used half-edges 100 | v = sorted([i1, i2]) 101 | U = np.concatenate((U[: v[0]], U[v[0] + 1 : v[1]], U[v[1] + 1 :])) 102 | 103 | super().__init__(A, **kwargs) 104 | 105 | self.is_regular() 106 | 107 | def is_regular(self): 108 | r""" 109 | Troubleshoot a given regular graph. 110 | 111 | """ 112 | warn = False 113 | msg = "The given matrix" 114 | 115 | # check symmetry 116 | if np.abs(self.A - self.A.T).sum() > 0: 117 | warn = True 118 | msg = f"{msg} is not symmetric," 119 | 120 | # check parallel edged 121 | if self.A.max(axis=None) > 1: 122 | warn = True 123 | msg = f"{msg} has parallel edges," 124 | 125 | # check that d is d-regular 126 | if np.min(self.d) != np.max(self.d): 127 | warn = True 128 | msg = f"{msg} is not d-regular," 129 | 130 | # check that g doesn't contain any self-loop 131 | if self.A.diagonal().any(): 132 | warn = True 133 | msg = f"{msg} has self loop." 134 | 135 | if warn: 136 | self.logger.warning(f"{msg[:-1]}.") 137 | 138 | def _get_extra_repr(self): 139 | return dict(k=self.k, seed=self.seed) 140 | -------------------------------------------------------------------------------- /pygsp/filters/wave.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | 5 | from .filter import Filter # prevent circular import in Python < 3.5 6 | 7 | 8 | class Wave(Filter): 9 | r"""Design a filter bank of wave kernels. 10 | 11 | The wave kernel is defined in the spectral domain as 12 | 13 | .. 
math:: g_{\tau, t}(\lambda) = \cos \left( t 14 | \arccos \left( 1 - \frac{\tau^2}{2} \lambda \right) \right), 15 | 16 | where :math:`\lambda \in [0, 1]` are the normalized eigenvalues of the 17 | graph Laplacian, :math:`t` is time, and :math:`\tau` is the propagation 18 | speed. 19 | 20 | The wave kernel is the fundamental solution to the wave equation 21 | 22 | .. math:: - \tau^2 L f(t) = \partial_{tt} f(t), 23 | 24 | where :math:`f: \mathbb{R}_+ \rightarrow \mathbb{R}^N` models, for example, 25 | the mechanical displacement of a wave on a graph. Given the initial 26 | condition :math:`f(0)` and assuming a vanishing initial velocity, i.e., the 27 | first derivative in time of the initial distribution equals zero, the 28 | solution of the wave equation is expressed as 29 | 30 | .. math:: f(t) = U g_{\tau, t}(\Lambda) U^\top f(0) 31 | = g_{\tau, t}(L) f(0). 32 | 33 | The above is, by definition, the convolution of the signal :math:`f(0)` 34 | with the kernel :math:`g_{\tau, t}`. 35 | Hence, applying this filter to a signal simulates wave propagation. 36 | 37 | Parameters 38 | ---------- 39 | G : graph 40 | time : float or iterable 41 | Time step. 42 | If iterable, creates a filter bank with one filter per value. 43 | speed : float or iterable 44 | Propagation speed, bounded by 0 (included) and 2 (excluded). 45 | If iterable, creates a filter bank with one filter per value. 46 | 47 | References 48 | ---------- 49 | :cite:`grassi2016timevertex`, :cite:`grassi2018timevertex` 50 | 51 | Examples 52 | -------- 53 | 54 | Filter bank's representation in Fourier and time (ring graph) domains. 55 | 56 | >>> import matplotlib.pyplot as plt 57 | >>> G = graphs.Ring(N=20) 58 | >>> G.estimate_lmax() 59 | >>> G.set_coordinates('line1D') 60 | >>> g = filters.Wave(G, time=[5, 15], speed=1) 61 | >>> s = g.localize(G.N // 2) 62 | >>> fig, axes = plt.subplots(1, 2) 63 | >>> _ = g.plot(ax=axes[0]) 64 | >>> _ = G.plot(s, ax=axes[1]) 65 | 66 | Wave propagation from two sources on a grid. 
67 | 68 | >>> import matplotlib.pyplot as plt 69 | >>> n_side = 11 70 | >>> graph = graphs.Grid2d(n_side) 71 | >>> graph.estimate_lmax() 72 | >>> sources = [ 73 | ... (n_side//4 * n_side) + (n_side//4), 74 | ... (n_side*3//4 * n_side) + (n_side*3//4), 75 | ... ] 76 | >>> delta = np.zeros(graph.n_vertices) 77 | >>> delta[sources] = 5 78 | >>> steps = np.array([5, 10]) 79 | >>> g = filters.Wave(graph, time=steps, speed=1) 80 | >>> propagated = g.filter(delta) 81 | >>> fig, axes = plt.subplots(1, len(steps), figsize=(10, 4)) 82 | >>> _ = fig.suptitle('Wave propagation', fontsize=16) 83 | >>> for i, ax in enumerate(axes): 84 | ... _ = graph.plot(propagated[:, i], highlight=sources, 85 | ... title='step {}'.format(steps[i]), ax=ax) 86 | ... ax.set_aspect('equal', 'box') 87 | ... ax.set_axis_off() 88 | 89 | """ 90 | 91 | def __init__(self, G, time=10, speed=1): 92 | try: 93 | iter(time) 94 | except TypeError: 95 | time = [time] 96 | try: 97 | iter(speed) 98 | except TypeError: 99 | speed = [speed] 100 | 101 | self.time = time 102 | self.speed = speed 103 | 104 | if len(time) != len(speed): 105 | if len(speed) == 1: 106 | speed = speed * len(time) 107 | elif len(time) == 1: 108 | time = time * len(speed) 109 | else: 110 | raise ValueError( 111 | "If both parameters are iterable, " 112 | "they should have the same length." 
113 | ) 114 | 115 | if np.any(np.asanyarray(speed) >= 2): 116 | raise ValueError("The wave propagation speed should be in [0, 2[") 117 | 118 | def kernel(x, time, speed): 119 | return np.cos(time * np.arccos(1 - speed**2 * x / G.lmax / 2)) 120 | 121 | kernels = [partial(kernel, time=t, speed=s) for t, s in zip(time, speed)] 122 | 123 | super().__init__(G, kernels) 124 | 125 | def _get_extra_repr(self): 126 | time = "[" + ", ".join(f"{t:.2f}" for t in self.time) + "]" 127 | speed = "[" + ", ".join(f"{s:.2f}" for s in self.speed) + "]" 128 | return dict(time=time, speed=speed) 129 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main, master] 6 | pull_request: 7 | branches: [ main, master] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: true 14 | matrix: 15 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up conda 21 | uses: conda-incubator/setup-miniconda@v3 22 | with: 23 | auto-update-conda: true 24 | python-version: ${{ matrix.python-version }} 25 | channels: conda-forge,defaults 26 | channel-priority: strict 27 | 28 | - name: Install system dependencies 29 | run: | 30 | # Install system packages for Qt6 and OpenGL 31 | sudo apt-get update 32 | sudo apt-get install -y \ 33 | qt6-base-dev \ 34 | libqt6gui6 \ 35 | libqt6widgets6 \ 36 | libqt6opengl6-dev \ 37 | xvfb \ 38 | libegl1-mesa-dev \ 39 | libgl1-mesa-dev \ 40 | libglu1-mesa-dev \ 41 | freeglut3-dev \ 42 | libxkbcommon-x11-0 \ 43 | libxcb-icccm4 \ 44 | libxcb-image0 \ 45 | libxcb-keysyms1 \ 46 | libxcb-randr0 \ 47 | libxcb-render-util0 \ 48 | libxcb-shape0 \ 49 | libxcb-xfixes0 \ 50 | libxcb-xinerama0 51 | 52 | - name: Install graph-tool via conda 53 | shell: bash -l {0} 54 | run: | 55 | conda install -c 
conda-forge graph-tool 56 | 57 | - name: Verify graph-tool installation 58 | shell: bash -l {0} 59 | run: | 60 | python -c "import graph_tool.all as gt; print('✅ graph-tool imported successfully')" 61 | 62 | - name: Cache pip packages 63 | uses: actions/cache@v4 64 | with: 65 | path: ~/.cache/pip 66 | key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} 67 | restore-keys: | 68 | ${{ runner.os }}-pip-${{ matrix.python-version }}- 69 | ${{ runner.os }}-pip- 70 | 71 | - name: Install Python dependencies 72 | shell: bash -l {0} 73 | run: | 74 | python -m pip install --upgrade pip setuptools wheel 75 | pip install --upgrade --upgrade-strategy eager .[dev] 76 | 77 | - name: Set up virtual display 78 | run: | 79 | export DISPLAY=:99 80 | Xvfb :99 -screen 0 800x600x24 > /dev/null 2>&1 & 81 | echo "DISPLAY=:99" >> $GITHUB_ENV 82 | echo "QT_QPA_PLATFORM=offscreen" >> $GITHUB_ENV 83 | echo "QT_QPA_FONTDIR=/usr/share/fonts" >> $GITHUB_ENV 84 | echo "MESA_GL_VERSION_OVERRIDE=3.3" >> $GITHUB_ENV 85 | echo "MESA_GLSL_VERSION_OVERRIDE=330" >> $GITHUB_ENV 86 | 87 | # - name: Run linting 88 | # shell: bash -l {0} 89 | # run: | 90 | # flake8 --doctests --exclude=doc 91 | 92 | - name: Run tests with coverage 93 | shell: bash -l {0} 94 | run: | 95 | export MPLBACKEND=agg 96 | coverage run --branch --source pygsp -m pytest 97 | coverage report 98 | coverage xml 99 | 100 | - name: Upload coverage to Coveralls 101 | if: matrix.python-version == '3.11' # Only upload once 102 | uses: coverallsapp/github-action@v2 103 | with: 104 | github-token: ${{ secrets.GITHUB_TOKEN }} 105 | file: coverage.xml 106 | 107 | docs: 108 | runs-on: ubuntu-latest 109 | needs: test 110 | 111 | steps: 112 | - uses: actions/checkout@v4 113 | 114 | - name: Set up conda 115 | uses: conda-incubator/setup-miniconda@v3 116 | with: 117 | auto-update-conda: true 118 | python-version: "3.11" 119 | channels: conda-forge,defaults 120 | channel-priority: strict 121 | 122 | - name: Install 
graph-tool via conda 123 | shell: bash -l {0} 124 | run: | 125 | conda install -c conda-forge graph-tool 126 | 127 | - name: Verify graph-tool installation 128 | shell: bash -l {0} 129 | run: | 130 | python -c "import graph_tool.all as gt; print('✅ graph-tool imported successfully')" 131 | 132 | - name: Install Python dependencies 133 | shell: bash -l {0} 134 | run: | 135 | python -m pip install --upgrade pip setuptools wheel 136 | pip install --upgrade --upgrade-strategy eager .[dev,doc] 137 | 138 | - name: Build documentation 139 | shell: bash -l {0} 140 | run: | 141 | export MPLBACKEND=agg 142 | sphinx-build -b html -d doc/_build/doctrees doc doc/_build/html 143 | sphinx-build -b linkcheck -d doc/_build/doctrees doc doc/_build/linkcheck 144 | 145 | - name: Upload documentation artifacts 146 | uses: actions/upload-artifact@v4 147 | with: 148 | name: documentation 149 | path: doc/_build/html/ 150 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | Contributions are welcome, and they are greatly appreciated! The development of 6 | this package takes place on `GitHub `_. 7 | Issues, bugs, and feature requests should be reported `there 8 | `_. 9 | Code and documentation can be improved by submitting a `pull request 10 | `_. Please add documentation and 11 | tests for any new code. 12 | 13 | The package can be set up (ideally in a fresh virtual environment) for local 14 | development with the following:: 15 | 16 | $ git clone https://github.com/epfl-lts2/pygsp.git 17 | $ cd pygsp 18 | $ make install 19 | 20 | The ``make install`` command (which runs ``uv sync --all-extras``) ensures that 21 | all dependencies required for development (to run the test suite and build the 22 | documentation) are installed. 
Only `graph-tool `_ 23 | will be missing: install it manually (e.g., ``conda install -c conda-forge graph-tool``) as it cannot be installed by uv. 24 | 25 | You can improve or add functionality in the ``pygsp`` folder, along with 26 | corresponding unit tests in ``pygsp/tests/test_*.py`` (with reasonable 27 | coverage). 28 | If you have a nice example to demonstrate the use of the introduced 29 | functionality, please consider adding a tutorial in ``doc/tutorials`` or a 30 | short example in ``examples``. 31 | 32 | Update ``README.rst`` and ``CHANGELOG.rst`` if applicable. 33 | 34 | After making any change, please check the style, run the tests, and build the 35 | documentation with the following (the tests and the documentation build are also run by GitHub Actions):: 36 | 37 | $ make lint 38 | $ make test 39 | $ make doc 40 | 41 | Check the generated coverage report at ``htmlcov/index.html`` to make sure the 42 | tests reasonably cover the changes you've introduced. 43 | 44 | To iterate faster, you can partially run the test suite, at various degrees of 45 | granularity, as follows:: 46 | 47 | $ python -m pytest pygsp/tests/test_docstrings.py 48 | $ python -m pytest pygsp/tests/test_graphs.py::TestGraphs 49 | $ python -m pytest pygsp/tests/test_graphs.py::test_save_load 50 | 51 | Making a release 52 | ---------------- 53 | 54 | #. Update the version number in ``pyproject.toml`` and ``pygsp/__init__.py``, 55 | and the version and release date in ``CHANGELOG.rst``. 56 | #. Create a git tag with ``git tag -a v0.5.0 -m "PyGSP v0.5.0"``. 57 | #. Push the tag to GitHub with ``git push github v0.5.0``. The tag should now 58 | appear in the releases and tags tab. 59 | #. `Create a release `_ on 60 | GitHub and select the created tag. A DOI should then be issued by Zenodo. 61 | #. Go to Zenodo and fix the metadata if necessary. 62 | #. Build the distribution with ``make dist`` and check that the 63 | ``dist/PyGSP-0.5.0.tar.gz`` source archive contains all required files. The 64 | binary wheel should be found as ``dist/PyGSP-0.5.0-py3-none-any.whl``. 65 | #.
Test the upload and installation process:: 66 | 67 | $ uv publish --publish-url https://test.pypi.org/legacy/ dist/* 68 | $ pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple pygsp 69 | 70 | Log in as the LTS2 user. 71 | #. Build and upload the distribution to the real PyPI with ``make release``. 72 | #. Update the conda feedstock (at least the version number and sha256 in 73 | ``recipe/meta.yaml``) by sending a PR to 74 | `conda-forge `_. 75 | 76 | Repository organization 77 | ----------------------- 78 | 79 | :: 80 | 81 | LICENSE.txt Project license 82 | *.rst Important documentation 83 | Makefile Targets for make 84 | setup.py Meta information about the package (published on PyPI) 85 | .gitignore Files ignored by the git revision control system 86 | .github/workflows/ Defines testing on GitHub Actions continuous integration 87 | 88 | pygsp/ Contains the modules (the actual toolbox implementation) 89 | __init__.py Load modules at package import 90 | *.py One file per module 91 | 92 | pygsp/tests/ Contains the test suites (will be distributed to end users) 93 | __init__.py Load modules at package import 94 | test_*.py One test suite per module 95 | test_docstrings.py Test the examples in the docstrings (reference doc) 96 | test_tutorials.py Test the tutorials in doc/tutorials 97 | test_all.py Launch all the tests (docstrings, tutorials, modules) 98 | 99 | doc/ Package documentation 100 | conf.py Sphinx configuration 101 | index.rst Documentation entry page 102 | *.rst Include doc files from root directory 103 | 104 | doc/reference/ Reference documentation 105 | index.rst Reference entry page 106 | *.rst Only directives, the actual doc is alongside the code 107 | 108 | doc/tutorials/ 109 | index.rst Tutorials entry page 110 | *.rst One file per tutorial 111 | -------------------------------------------------------------------------------- /pygsp/tests/test_learning.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Test suite for the learning module of the pygsp package. 3 | 4 | """ 5 | 6 | import numpy as np 7 | 8 | from pygsp import filters, graphs, learning 9 | 10 | 11 | def test_regression_tikhonov_trivial(): 12 | """Solve a trivial regression problem.""" 13 | G = graphs.Ring(N=8) 14 | signal = np.array([0, np.nan, 4, np.nan, 4, np.nan, np.nan, np.nan]) 15 | signal_bak = signal.copy() 16 | mask = np.array([True, False, True, False, True, False, False, False]) 17 | truth = np.array([0, 2, 4, 4, 4, 3, 2, 1]) 18 | recovery = learning.regression_tikhonov(G, signal, mask, tau=0) 19 | np.testing.assert_allclose(recovery, truth) 20 | 21 | # Test the numpy solution. 22 | G = graphs.Graph(G.W.toarray()) 23 | recovery = learning.regression_tikhonov(G, signal, mask, tau=0) 24 | np.testing.assert_allclose(recovery, truth) 25 | np.testing.assert_allclose(signal_bak, signal) 26 | 27 | 28 | def test_regression_tikhonov_constrained(): 29 | """Solve a regression problem with a constraint.""" 30 | G = graphs.Sensor(100) 31 | G.estimate_lmax() 32 | 33 | # Create a smooth signal. 34 | filt = filters.Filter(G, lambda x: 1 / (1 + 10 * x)) 35 | rng = np.random.default_rng(1) 36 | signal = filt.analyze(rng.normal(size=(G.n_vertices, 5))) 37 | 38 | # Make the input signal. 39 | mask = rng.uniform(0, 1, G.n_vertices) > 0.5 40 | measures = signal.copy() 41 | measures[~mask] = np.nan 42 | measures_bak = measures.copy() 43 | 44 | # Solve the problem. 
45 | recovery0 = learning.regression_tikhonov(G, measures, mask, tau=0) 46 | np.testing.assert_allclose(measures_bak, measures) 47 | 48 | recovery1 = np.zeros_like(recovery0) 49 | for i in range(recovery0.shape[1]): 50 | recovery1[:, i] = learning.regression_tikhonov(G, measures[:, i], mask, tau=0) 51 | np.testing.assert_allclose(measures_bak, measures) 52 | 53 | G = graphs.Graph(G.W.toarray()) 54 | recovery2 = learning.regression_tikhonov(G, measures, mask, tau=0) 55 | recovery3 = np.zeros_like(recovery0) 56 | for i in range(recovery0.shape[1]): 57 | recovery3[:, i] = learning.regression_tikhonov(G, measures[:, i], mask, tau=0) 58 | 59 | np.testing.assert_allclose(recovery1, recovery0) 60 | np.testing.assert_allclose(recovery2, recovery0) 61 | np.testing.assert_allclose(recovery3, recovery0) 62 | np.testing.assert_allclose(measures_bak, measures) 63 | 64 | 65 | def test_regression_tikhonov_relaxed(tau=3.5): 66 | """Solve a relaxed regression problem.""" 67 | G = graphs.Sensor(100) 68 | G.estimate_lmax() 69 | 70 | # Create a smooth signal. 71 | filt = filters.Filter(G, lambda x: 1 / (1 + 10 * x)) 72 | rng = np.random.default_rng(1) 73 | signal = filt.analyze(rng.normal(size=(G.n_vertices, 6))) 74 | 75 | # Make the input signal. 76 | mask = rng.uniform(0, 1, G.n_vertices) > 0.5 77 | measures = signal.copy() 78 | measures[~mask] = 18 79 | measures_bak = measures.copy() 80 | 81 | L = G.L.toarray() 82 | recovery = np.matmul( 83 | np.linalg.inv(np.diag(1 * mask) + tau * L), (mask * measures.T).T 84 | ) 85 | 86 | # Solve the problem. 
87 | recovery0 = learning.regression_tikhonov(G, measures, mask, tau=tau) 88 | np.testing.assert_allclose(measures_bak, measures) 89 | recovery1 = np.zeros_like(recovery0) 90 | for i in range(recovery0.shape[1]): 91 | recovery1[:, i] = learning.regression_tikhonov(G, measures[:, i], mask, tau) 92 | np.testing.assert_allclose(measures_bak, measures) 93 | 94 | G = graphs.Graph(G.W.toarray()) 95 | recovery2 = learning.regression_tikhonov(G, measures, mask, tau) 96 | recovery3 = np.zeros_like(recovery0) 97 | for i in range(recovery0.shape[1]): 98 | recovery3[:, i] = learning.regression_tikhonov(G, measures[:, i], mask, tau) 99 | 100 | np.testing.assert_allclose(recovery0, recovery, atol=1e-5) 101 | np.testing.assert_allclose(recovery1, recovery, atol=1e-5) 102 | np.testing.assert_allclose(recovery2, recovery, atol=1e-5) 103 | np.testing.assert_allclose(recovery3, recovery, atol=1e-5) 104 | np.testing.assert_allclose(measures_bak, measures) 105 | 106 | 107 | def test_classification_tikhonov(): 108 | """Solve a classification problem.""" 109 | G = graphs.Logo() 110 | signal = np.zeros([G.n_vertices], dtype=int) 111 | signal[G.info["idx_s"]] = 1 112 | signal[G.info["idx_p"]] = 2 113 | 114 | # Make the input signal. 115 | rng = np.random.default_rng(2) 116 | mask = rng.uniform(size=G.n_vertices) > 0.3 117 | 118 | measures = signal.copy() 119 | measures[~mask] = -1 120 | measures_bak = measures.copy() 121 | 122 | # Solve the classification problem. 123 | recovery = learning.classification_tikhonov(G, measures, mask, tau=0) 124 | recovery = np.argmax(recovery, axis=1) 125 | 126 | np.testing.assert_array_equal(recovery, signal) 127 | 128 | # Test the function with the simplex projection. 129 | recovery = learning.classification_tikhonov_simplex(G, measures, mask, tau=0.1) 130 | 131 | # Assert that the probabilities sums to 1 132 | np.testing.assert_allclose(np.sum(recovery, axis=1), 1) 133 | 134 | # Check the quality of the solution. 
135 | recovery = np.argmax(recovery, axis=1) 136 | np.testing.assert_allclose(signal, recovery) 137 | np.testing.assert_allclose(measures_bak, measures) 138 | -------------------------------------------------------------------------------- /doc/tutorials/wavelet.rst: -------------------------------------------------------------------------------- 1 | ======================================= 2 | Introduction to spectral graph wavelets 3 | ======================================= 4 | 5 | This tutorial will show you how to easily construct a wavelet_ frame, a kind of 6 | filter bank, and apply it to a signal. This tutorial will walk you into 7 | computing the wavelet coefficients of a graph, visualizing filters in the 8 | vertex domain, and using the wavelets to estimate the curvature of a 3D shape. 9 | 10 | .. _wavelet: https://en.wikipedia.org/wiki/Wavelet 11 | 12 | As usual, we first have to import some packages. 13 | 14 | .. plot:: 15 | :context: reset 16 | 17 | >>> import numpy as np 18 | >>> import matplotlib.pyplot as plt 19 | >>> from pygsp import graphs, filters, plotting, utils 20 | 21 | Then we can load a graph. The graph we'll use is a nearest-neighbor graph of a 22 | point cloud of the Stanford bunny. It will allow us to get interesting visual 23 | results using wavelets. 24 | 25 | .. plot:: 26 | :context: close-figs 27 | 28 | >>> G = graphs.Bunny() 29 | 30 | .. note:: 31 | At this stage we could compute the Fourier basis using 32 | :meth:`pygsp.graphs.Graph.compute_fourier_basis`, but this would take some 33 | time, and can be avoided with a Chebychev polynomials approximation to 34 | graph filtering. See the documentation of the 35 | :meth:`pygsp.filters.Filter.filter` filtering function and 36 | :cite:`hammond2011wavelets` for details on how it is down. 37 | 38 | Simple filtering: heat diffusion 39 | -------------------------------- 40 | 41 | Before tackling wavelets, let's observe the effect of a single filter on a 42 | graph signal. 
We first design a few heat kernel filters, each with a different 43 | scale. 44 | 45 | .. plot:: 46 | :context: close-figs 47 | 48 | >>> taus = [10, 25, 50] 49 | >>> g = filters.Heat(G, taus) 50 | 51 | Let's create a signal as a Kronecker delta located on one vertex, e.g. the 52 | vertex 20. That signal is our heat source. 53 | 54 | .. plot:: 55 | :context: close-figs 56 | 57 | >>> s = np.zeros(G.N) 58 | >>> DELTA = 20 59 | >>> s[DELTA] = 1 60 | 61 | We can now simulate heat diffusion by filtering our signal ``s`` with each of our 62 | heat kernels. 63 | 64 | .. plot:: 65 | :context: close-figs 66 | 67 | >>> s = g.filter(s, method='chebyshev') 68 | 69 | Finally, we plot the filtered signals, showing heat diffusion at different 70 | scales. 71 | 72 | .. plot:: 73 | :context: close-figs 74 | 75 | >>> fig = plt.figure(figsize=(10, 3)) 76 | >>> for i in range(g.Nf): 77 | ... ax = fig.add_subplot(1, g.Nf, i+1, projection='3d') 78 | ... title = r'Heat diffusion, $\tau={}$'.format(taus[i]) 79 | ... _ = G.plot(s[:, i], colorbar=False, title=title, ax=ax) 80 | ... ax.set_axis_off() 81 | >>> fig.tight_layout() 82 | 83 | .. note:: 84 | The :meth:`pygsp.filters.Filter.localize` method can be used to visualize a 85 | filter in the vertex domain instead of doing it manually. 86 | 87 | Visualizing wavelet atoms 88 | -------------------------- 89 | 90 | Let's now replace the Heat filter with a filter bank of wavelets. We can create a 91 | filter bank using one of the predefined filters, such as 92 | :class:`pygsp.filters.MexicanHat` to design a set of `Mexican hat wavelets`_. 93 | 94 | .. _Mexican hat wavelets: 95 | https://en.wikipedia.org/wiki/Mexican_hat_wavelet 96 | 97 | .. plot:: 98 | :context: close-figs 99 | 100 | >>> g = filters.MexicanHat(G, Nf=6) # Nf = 6 filters in the filter bank. 101 | 102 | Then plot the frequency response of those filters. 103 | 104 | ..
plot:: 105 | :context: close-figs 106 | 107 | >>> fig, ax = plt.subplots(figsize=(10, 5)) 108 | >>> _ = g.plot(title='Filter bank of mexican hat wavelets', ax=ax) 109 | 110 | .. note:: 111 | We can see that the wavelet atoms are stacked on the low-frequency part of 112 | the spectrum. Better coverage could be obtained by adapting the filter 113 | bank with :class:`pygsp.filters.WarpedTranslates` or by using another 114 | filter bank like :class:`pygsp.filters.Itersine`. 115 | 116 | We can visualize the atoms as we did with the heat kernel, by filtering 117 | a Kronecker delta placed at one specific vertex. 118 | 119 | .. plot:: 120 | :context: close-figs 121 | 122 | >>> s = g.localize(DELTA) 123 | >>> 124 | >>> fig = plt.figure(figsize=(10, 2.5)) 125 | >>> for i in range(3): 126 | ... ax = fig.add_subplot(1, 3, i+1, projection='3d') 127 | ... _ = G.plot(s[:, i], title='Wavelet {}'.format(i+1), ax=ax) 128 | ... ax.set_axis_off() 129 | >>> fig.tight_layout() 130 | 131 | Curvature estimation 132 | -------------------- 133 | 134 | As a last and more applied example, let us try to estimate the curvature of the 135 | underlying 3D model by only using spectral filtering on the nearest-neighbor 136 | graph formed by its point cloud. 137 | 138 | A simple way to accomplish that is to use the coordinates map :math:`[x, y, z]` 139 | and filter it using the above-defined wavelets. Doing so gives us a 140 | 3-dimensional signal 141 | :math:`[g_i(L)x, g_i(L)y, g_i(L)z], \ i \in \{0, \ldots, N_f - 1\}` 142 | which describes variation along the 3 coordinates. 143 | 144 | .. plot:: 145 | :context: close-figs 146 | 147 | >>> s = G.coords 148 | >>> s = g.filter(s) 149 | 150 | The curvature is then estimated by taking the :math:`\ell_1` or :math:`\ell_2` 151 | norm across the 3 coordinates. 152 | 153 | ..
plot:: 154 | :context: close-figs 155 | 156 | >>> s = np.linalg.norm(s, ord=2, axis=1) 157 | 158 | Let's finally plot the result to observe that we indeed have a measure of the 159 | curvature at different scales. 160 | 161 | .. plot:: 162 | :context: close-figs 163 | 164 | >>> fig = plt.figure(figsize=(10, 7)) 165 | >>> for i in range(4): 166 | ... ax = fig.add_subplot(2, 2, i+1, projection='3d') 167 | ... title = 'Curvature estimation (scale {})'.format(i+1) 168 | ... _ = G.plot(s[:, i], title=title, ax=ax) 169 | ... ax.set_axis_off() 170 | >>> fig.tight_layout() 171 | -------------------------------------------------------------------------------- /pygsp/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`pygsp.graphs` module implements the graph class hierarchy. A graph 3 | object is either constructed from an adjacency matrix, or by instantiating one 4 | of the built-in graph models. 5 | 6 | Interface 7 | ========= 8 | 9 | The :class:`Graph` base class allows one to construct a graph object from any 10 | adjacency matrix and provides a common interface to that object. Derived 11 | classes then allow one to instantiate various standard graph models. 12 | 13 | Attributes 14 | ---------- 15 | 16 | **Matrix operators** 17 | 18 | .. autosummary:: 19 | 20 | Graph.W 21 | Graph.L 22 | Graph.U 23 | Graph.D 24 | 25 | **Vectors** 26 | 27 | .. autosummary:: 28 | 29 | Graph.d 30 | Graph.dw 31 | Graph.e 32 | 33 | **Scalars** 34 | 35 | .. autosummary:: 36 | 37 | Graph.lmax 38 | Graph.coherence 39 | 40 | Attributes computation 41 | ---------------------- 42 | 43 | .. autosummary:: 44 | 45 | Graph.compute_laplacian 46 | Graph.estimate_lmax 47 | Graph.compute_fourier_basis 48 | Graph.compute_differential_operator 49 | 50 | Differential operators 51 | ---------------------- 52 | 53 | .. autosummary:: 54 | 55 | Graph.grad 56 | Graph.div 57 | Graph.dirichlet_energy 58 | 59 | Transforms 60 | ---------- 61 | 62 | ..
autosummary:: 63 | 64 | Graph.gft 65 | Graph.igft 66 | 67 | Vertex-frequency transforms are implemented as filter banks and are found in 68 | :mod:`pygsp.filters` (such as :class:`~pygsp.filters.Gabor` and 69 | :class:`~pygsp.filters.Modulation`). 70 | 71 | Checks 72 | ------ 73 | 74 | .. autosummary:: 75 | 76 | Graph.is_weighted 77 | Graph.is_connected 78 | Graph.is_directed 79 | Graph.has_loops 80 | 81 | Plotting 82 | -------- 83 | 84 | .. autosummary:: 85 | 86 | Graph.plot 87 | Graph.plot_spectrogram 88 | 89 | Import and export (I/O) 90 | ----------------------- 91 | 92 | We provide import and export facility to two well-known Python packages for 93 | network analysis: NetworkX_ and graph-tool_. 94 | Those packages and the PyGSP are fundamentally different in their goals (graph 95 | analysis versus graph signal analysis) and graph representations (if in the 96 | PyGSP everything is an ndarray, in NetworkX everything is a dictionary). 97 | Those tools are complementary and good interoperability is necessary to exploit 98 | the strengths of each tool. 99 | We ourselves leverage NetworkX and graph-tool to save and load graphs. 100 | 101 | Note: to tie a signal with the graph, such that they are exported together, 102 | attach it first with :meth:`Graph.set_signal`. 103 | 104 | .. _NetworkX: https://networkx.org 105 | .. _graph-tool: https://graph-tool.skewed.de 106 | 107 | .. autosummary:: 108 | 109 | Graph.load 110 | Graph.save 111 | Graph.from_networkx 112 | Graph.to_networkx 113 | Graph.from_graphtool 114 | Graph.to_graphtool 115 | 116 | Others 117 | ------ 118 | 119 | .. 
autosummary:: 120 | 121 | Graph.get_edge_list 122 | Graph.set_signal 123 | Graph.set_coordinates 124 | Graph.subgraph 125 | Graph.extract_components 126 | 127 | Graph models 128 | ============ 129 | 130 | In addition to the below graphs, useful resources are the random graph 131 | generators from NetworkX (see `NetworkX's documentation`_) and graph-tool (see 132 | :mod:`graph_tool.generation`), as well as graph-tool's assortment of standard 133 | networks (see :mod:`graph_tool.collection`). 134 | Any graph created by NetworkX or graph-tool can be imported in the PyGSP with 135 | :meth:`Graph.from_networkx` and :meth:`Graph.from_graphtool`. 136 | 137 | .. _NetworkX's documentation: https://networkx.org/documentation/stable/reference/generators.html 138 | 139 | Graphs built from other graphs 140 | ------------------------------ 141 | 142 | .. autosummary:: 143 | 144 | LineGraph 145 | 146 | Generated graphs 147 | ---------------- 148 | 149 | .. autosummary:: 150 | 151 | Airfoil 152 | BarabasiAlbert 153 | Comet 154 | Community 155 | DavidSensorNet 156 | ErdosRenyi 157 | FullConnected 158 | Grid2d 159 | Logo 160 | LowStretchTree 161 | Minnesota 162 | Path 163 | RandomRegular 164 | RandomRing 165 | Ring 166 | Star 167 | StochasticBlockModel 168 | SwissRoll 169 | Torus 170 | 171 | Nearest-neighbors graphs constructed from point clouds 172 | ------------------------------------------------------ 173 | 174 | .. 
autosummary:: 175 | 176 | NNGraph 177 | Bunny 178 | Cube 179 | ImgPatches 180 | Grid2dImgPatches 181 | Sensor 182 | Sphere 183 | TwoMoons 184 | 185 | """ 186 | 187 | from .airfoil import Airfoil # noqa: F401 188 | from .barabasialbert import BarabasiAlbert # noqa: F401 189 | from .comet import Comet # noqa: F401 190 | from .community import Community # noqa: F401 191 | from .davidsensornet import DavidSensorNet # noqa: F401 192 | from .erdosrenyi import ErdosRenyi # noqa: F401 193 | from .fullconnected import FullConnected # noqa: F401 194 | from .graph import Graph # noqa: F401 195 | from .grid2d import Grid2d # noqa: F401 196 | from .linegraph import LineGraph # noqa: F401 197 | from .logo import Logo # noqa: F401 198 | from .lowstretchtree import LowStretchTree # noqa: F401 199 | from .minnesota import Minnesota # noqa: F401 200 | from .nngraphs.bunny import Bunny # noqa: F401 201 | from .nngraphs.cube import Cube # noqa: F401 202 | from .nngraphs.grid2dimgpatches import Grid2dImgPatches # noqa: F401 203 | from .nngraphs.imgpatches import ImgPatches # noqa: F401 204 | from .nngraphs.nngraph import NNGraph # noqa: F401 205 | from .nngraphs.sensor import Sensor # noqa: F401 206 | from .nngraphs.sphere import Sphere # noqa: F401 207 | from .nngraphs.twomoons import TwoMoons # noqa: F401 208 | from .path import Path # noqa: F401 209 | from .randomregular import RandomRegular # noqa: F401 210 | from .randomring import RandomRing # noqa: F401 211 | from .ring import Ring # noqa: F401 212 | from .star import Star # noqa: F401 213 | from .stochasticblockmodel import StochasticBlockModel # noqa: F401 214 | from .swissroll import SwissRoll # noqa: F401 215 | from .torus import Torus # noqa: F401 216 | -------------------------------------------------------------------------------- /pygsp/graphs/stochasticblockmodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | from pygsp 
import utils 5 | 6 | from .graph import Graph # prevent circular import in Python < 3.5 7 | 8 | 9 | class StochasticBlockModel(Graph): 10 | r"""Stochastic Block Model (SBM). 11 | 12 | The Stochastic Block Model graph is constructed by connecting nodes with a 13 | probability which depends on the cluster of the two nodes. One can define 14 | the clustering association of each node, denoted by vector z, but also the 15 | probability matrix M. All edge weights are equal to 1. By default, Mii > 16 | Mjk and nodes are uniformly clusterized. 17 | 18 | Parameters 19 | ---------- 20 | N : int 21 | Number of nodes (default is 1024). 22 | k : float 23 | Number of classes (default is 5). 24 | z : array_like 25 | the vector of length N containing the association between nodes and 26 | classes (default is random uniform). 27 | M : array_like 28 | the k by k matrix containing the probability of connecting nodes based 29 | on their class belonging (default using p and q). 30 | p : float or array_like 31 | the diagonal value(s) for the matrix M. If scalar they all have the 32 | same value. Otherwise expect a length k vector (default is p = 0.7). 33 | q : float or array_like 34 | the off-diagonal value(s) for the matrix M. If scalar they all have the 35 | same value. Otherwise expect a k x k matrix, diagonal will be 36 | discarded (default is q = 0.3/k). 37 | directed : bool 38 | Allow directed edges if True (default is False). 39 | self_loops : bool 40 | Allow self loops if True (default is False). 41 | connected : bool 42 | Force the graph to be connected (default is False). 43 | n_try : int or None 44 | Maximum number of trials to get a connected graph. If None, it will try 45 | forever. 46 | seed : int 47 | Seed for the random number generator (for reproducible graphs). 48 | 49 | Examples 50 | -------- 51 | >>> import matplotlib.pyplot as plt 52 | >>> G = graphs.StochasticBlockModel( 53 | ... 
100, k=3, p=[0.4, 0.6, 0.3], q=0.02, seed=42) 54 | >>> G.set_coordinates(kind='spring', seed=42) 55 | >>> fig, axes = plt.subplots(1, 2) 56 | >>> _ = axes[0].spy(G.W, markersize=0.8) 57 | >>> _ = G.plot(ax=axes[1]) 58 | 59 | """ 60 | 61 | def __init__( 62 | self, 63 | N=1024, 64 | k=5, 65 | z=None, 66 | M=None, 67 | p=0.7, 68 | q=None, 69 | directed=False, 70 | self_loops=False, 71 | connected=False, 72 | n_try=10, 73 | seed=None, 74 | **kwargs, 75 | ): 76 | self.k = k 77 | self.directed = directed 78 | self.self_loops = self_loops 79 | self.connected = connected 80 | self.n_try = n_try 81 | self.seed = seed 82 | 83 | rng = np.random.default_rng(seed) 84 | 85 | if z is None: 86 | z = rng.integers(0, k, N) 87 | z.sort() # Sort for nice spy plot of W, where blocks are apparent. 88 | self.z = z 89 | 90 | if M is None: 91 | self.p = p 92 | p = np.asanyarray(p) 93 | if p.size == 1: 94 | p = p * np.ones(k) 95 | if p.shape != (k,): 96 | raise ValueError( 97 | "Optional parameter p is neither a scalar " 98 | "nor a vector of length k." 99 | ) 100 | 101 | if q is None: 102 | q = 0.3 / k 103 | self.q = q 104 | q = np.asanyarray(q) 105 | if q.size == 1: 106 | q = q * np.ones((k, k)) 107 | if q.shape != (k, k): 108 | raise ValueError( 109 | "Optional parameter q is neither a scalar " 110 | "nor a matrix of size k x k." 111 | ) 112 | 113 | M = q 114 | M.flat[:: k + 1] = p # edit the diagonal terms 115 | 116 | self.M = M 117 | 118 | if (M < 0).any() or (M > 1).any(): 119 | raise ValueError("Probabilities should be in [0, 1].") 120 | 121 | # TODO: higher memory, lesser computation alternative. 122 | # Along the lines of np.random.uniform(size=(N, N)) < p. 123 | # Or similar to sparse.random(N, N, p, data_rvs=lambda n: np.ones(n)). 
124 | 125 | while (n_try is None) or (n_try > 0): 126 | nb_row, nb_col = 0, 0 127 | csr_data, csr_i, csr_j = [], [], [] 128 | for _ in range(N**2): 129 | if nb_row != nb_col or self_loops: 130 | if nb_row >= nb_col or directed: 131 | if rng.uniform() < M[z[nb_row], z[nb_col]]: 132 | csr_data.append(1) 133 | csr_i.append(nb_row) 134 | csr_j.append(nb_col) 135 | if nb_row < N - 1: 136 | nb_row += 1 137 | else: 138 | nb_row = 0 139 | nb_col += 1 140 | 141 | W = sparse.csr_matrix((csr_data, (csr_i, csr_j)), shape=(N, N)) 142 | 143 | if not directed: 144 | W = utils.symmetrize(W, method="tril") 145 | 146 | if not connected: 147 | break 148 | if Graph(W).is_connected(): 149 | break 150 | if n_try is not None: 151 | n_try -= 1 152 | if connected and n_try == 0: 153 | raise ValueError( 154 | "The graph could not be connected after {} " 155 | "trials. Increase the connection probability " 156 | "or the number of trials.".format(self.n_try) 157 | ) 158 | 159 | self.info = { 160 | "node_com": z, 161 | "comm_sizes": np.bincount(z), 162 | "world_rad": np.sqrt(N), 163 | } 164 | 165 | super().__init__(W, **kwargs) 166 | 167 | def _get_extra_repr(self): 168 | attrs = {"k": self.k} 169 | if type(self.p) is float: 170 | attrs["p"] = f"{self.p:.2f}" 171 | if type(self.q) is float: 172 | attrs["q"] = f"{self.q:.2f}" 173 | attrs.update( 174 | { 175 | "directed": self.directed, 176 | "self_loops": self.self_loops, 177 | "connected": self.connected, 178 | "seed": self.seed, 179 | } 180 | ) 181 | return attrs 182 | -------------------------------------------------------------------------------- /pygsp/tests/test_docstrings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test suite for the docstrings of the pygsp package. 
3 | 4 | """ 5 | 6 | import doctest 7 | import os 8 | 9 | import pytest 10 | 11 | 12 | def gen_recursive_file(root, ext, exclude_patterns=None): 13 | """Generate files recursively with given extension, excluding certain patterns.""" 14 | exclude_patterns = exclude_patterns or [] 15 | 16 | for root_dir, _, filenames in os.walk(root): 17 | for name in filenames: 18 | if name.lower().endswith(ext): 19 | full_path = os.path.join(root_dir, name) 20 | # Skip files matching exclude patterns 21 | if any(pattern in full_path for pattern in exclude_patterns): 22 | continue 23 | yield full_path 24 | 25 | 26 | @pytest.fixture(scope="session") 27 | def doctest_namespace(): 28 | """Provide global namespace for doctests.""" 29 | import numpy 30 | 31 | import pygsp 32 | 33 | return { 34 | "graphs": pygsp.graphs, 35 | "filters": pygsp.filters, 36 | "utils": pygsp.utils, 37 | "np": numpy, 38 | } 39 | 40 | 41 | def pytest_collection_modifyitems(config, items): 42 | """Close matplotlib figures after doctests to avoid warning and save memory.""" 43 | import pygsp 44 | 45 | def finalize(): 46 | pygsp.plotting.close_all() 47 | 48 | config.add_cleanup(finalize) 49 | 50 | 51 | def test_api_docstrings(): 52 | """Test docstrings from PyGSP API reference.""" 53 | # Setup namespace for doctests 54 | import numpy 55 | 56 | import pygsp 57 | 58 | globs = { 59 | "graphs": pygsp.graphs, 60 | "filters": pygsp.filters, 61 | "utils": pygsp.utils, 62 | "np": numpy, 63 | } 64 | 65 | # Only test PyGSP files, not external dependencies 66 | files = list(gen_recursive_file("pygsp", ".py")) 67 | 68 | failure_count = 0 69 | test_count = 0 70 | failures = [] 71 | 72 | for filename in files: 73 | try: 74 | # Use more permissive doctest options to handle formatting differences 75 | result = doctest.testfile( 76 | filename, 77 | module_relative=False, 78 | globs=globs, 79 | verbose=False, 80 | optionflags=( 81 | doctest.ELLIPSIS 82 | | doctest.NORMALIZE_WHITESPACE 83 | | doctest.IGNORE_EXCEPTION_DETAIL 84 | ), 85 
| ) 86 | if result.failed > 0: 87 | failures.append( 88 | f"{filename}: {result.failed}/{result.attempted} failed" 89 | ) 90 | failure_count += result.failed 91 | test_count += result.attempted 92 | except Exception as e: 93 | failures.append(f"{filename}: Error running doctest: {e}") 94 | 95 | # Only fail if there are significant failures (allow a few minor formatting issues) 96 | significant_failures = ( 97 | failure_count > test_count * 0.1 98 | ) # Allow up to 10% failures for formatting 99 | 100 | if significant_failures and failure_count > 0: 101 | failure_details = "\n".join(failures[:10]) # Show first 10 failures 102 | pytest.fail( 103 | "PyGSP docstring tests failed: " 104 | + f"{failure_count}/{test_count}\n{failure_details}" 105 | ) 106 | 107 | 108 | def test_tutorial_docstrings(): 109 | """Test docstrings from PyGSP tutorials only.""" 110 | # Only test PyGSP-specific documentation, exclude external library docs 111 | exclude_patterns = [ 112 | ".venv/", 113 | "site-packages/", 114 | "sklearn/", 115 | "numpy/", 116 | "scipy/", 117 | ] 118 | 119 | tutorial_files = [] 120 | for filename in gen_recursive_file(".", ".rst", exclude_patterns): 121 | # Only include PyGSP tutorial documentation 122 | if "doc/tutorial" in filename or ( 123 | filename.startswith("./doc/") and "tutorial" in filename 124 | ): 125 | tutorial_files.append(filename) 126 | 127 | if not tutorial_files: 128 | pytest.skip("No PyGSP tutorial files found") 129 | 130 | failure_count = 0 131 | test_count = 0 132 | failures = [] 133 | 134 | for filename in tutorial_files: 135 | try: 136 | result = doctest.testfile( 137 | filename, 138 | module_relative=False, 139 | verbose=False, 140 | optionflags=( 141 | doctest.ELLIPSIS 142 | | doctest.NORMALIZE_WHITESPACE 143 | | doctest.IGNORE_EXCEPTION_DETAIL 144 | ), 145 | ) 146 | if result.failed > 0: 147 | failures.append( 148 | f"{filename}: {result.failed}/{result.attempted} failed" 149 | ) 150 | failure_count += result.failed 151 | test_count += 
result.attempted 152 | except Exception as e: 153 | failures.append(f"{filename}: Error running doctest: {e}") 154 | 155 | # Only fail if there are significant failures 156 | if ( 157 | failure_count > 0 and failure_count > test_count * 0.2 158 | ): # Allow up to 20% failures for tutorials 159 | failure_details = "\n".join(failures[:5]) # Show first 5 failures 160 | pytest.fail( 161 | "PyGSP tutorial docstring tests failed: " 162 | + f"{failure_count}/{test_count}\n{failure_details}" 163 | ) 164 | 165 | 166 | # Optional: Test specific known-good examples 167 | def test_basic_docstring_examples(): 168 | """Test a few basic examples that should always work.""" 169 | import doctest 170 | 171 | import pygsp.graphs 172 | 173 | # Test a simple, known example 174 | example_code = """ 175 | >>> from pygsp import graphs 176 | >>> G = graphs.Ring(N=10) 177 | >>> G.N 178 | 10 179 | >>> G.n_vertices 180 | 10 181 | """ 182 | 183 | # Run this simple test 184 | globs = {"graphs": pygsp.graphs} 185 | parser = doctest.DocTestParser() 186 | examples = parser.get_examples(example_code) 187 | 188 | if examples: 189 | runner = doctest.DocTestRunner(verbose=False) 190 | test = doctest.DocTest( 191 | examples, globs, "basic_test", "basic_test", 0, example_code 192 | ) 193 | result = runner.run(test) 194 | 195 | if result.failed > 0: 196 | pytest.fail( 197 | f"Basic docstring example failed: {result.failed}/{result.attempted}" 198 | ) 199 | -------------------------------------------------------------------------------- /doc/tutorials/optimization.rst: -------------------------------------------------------------------------------- 1 | =========================================================== 2 | Optimization problems: graph TV vs. Tikhonov regularization 3 | =========================================================== 4 | 5 | Description 6 | ----------- 7 | 8 | Modern signal processing often involves solving an optimization problem. 
Graph signal processing (GSP) consists roughly of working with linear operators defined by a graph (e.g., the graph Laplacian). Setting up an optimization problem in the graph context is then often simply a matter of identifying which graph-defined linear operator is relevant for a regularization and/or fidelity term. 9 | 10 | This tutorial focuses on the problem of recovering a label signal on a graph from subsampled and noisy data, but most of the information applies fairly generally to other problems as well. 11 | 12 | .. plot:: 13 | :context: reset 14 | 15 | >>> import numpy as np 16 | >>> from pygsp import graphs, plotting 17 | >>> 18 | >>> # Create a random sensor graph 19 | >>> G = graphs.Sensor(N=256, distributed=True, seed=42) 20 | >>> G.compute_fourier_basis() 21 | >>> 22 | >>> # Create label signal 23 | >>> label_signal = np.copysign(np.ones(G.N), G.U[:, 3]) 24 | >>> 25 | >>> fig, ax = G.plot(label_signal) 26 | 27 | The first figure shows the original label signal, which we wish to recover, plotted on the graph. 28 | 29 | .. plot:: 30 | :context: close-figs 31 | 32 | >>> rng = np.random.default_rng(42) 33 | >>> 34 | >>> # Create the mask 35 | >>> M = rng.uniform(size=G.N) 36 | >>> M = (M > 0.6).astype(float) # Probability of having no label on a vertex. 37 | >>> 38 | >>> # Applying the mask to the data 39 | >>> sigma = 0.1 40 | >>> subsampled_noisy_label_signal = M * (label_signal + sigma * rng.standard_normal(G.N)) 41 | >>> 42 | >>> fig, ax = G.plot(subsampled_noisy_label_signal) 43 | 44 | This figure shows the label signal on the graph after the application of the subsampling mask and the addition of noise. The label of more than half of the vertices has been set to :math:`0`. 45 | 46 | Since the problem is ill-posed, we will use some regularization to reach a solution that is more in tune with what we expect a label signal to look like.
We will compare two approaches, but they are both based on measuring local differences on the label signal. Those differences are essentially an edge signal: to each edge we can associate the difference between the label signals of its associated nodes. The linear operator that does such a mapping is called the graph gradient :math:`\nabla_G`, and, fortunately for us, it is available under the :attr:`pygsp.graphs.Graph.D` (for differential) attribute of any :mod:`pygsp.graphs` graph. 47 | 48 | The reason for measuring local differences comes from prior knowledge: we assume label signals don't vary too much locally. The precise measure of such variation is what distinguishes the two regularization approaches we'll use. 49 | 50 | The first one, shown below, is called graph total variation (TV) regularization. The quadratic fidelity term is multiplied by a regularization constant :math:`\gamma` and its goal is to force the solution to stay close to the observed labels :math:`b`. The :math:`\ell_1` norm of the action of the graph gradient is what's called the graph TV. We will see that it promotes piecewise-constant solutions. 51 | 52 | .. math:: \operatorname*{arg\,min}_x \|\nabla_G x\|_1 + \gamma \|Mx-b\|_2^2 53 | 54 | The second approach, called graph Tikhonov regularization, is to use a smooth (differentiable) quadratic regularizer. A consequence of this choice is that the solution will tend to have smoother transitions. The quadratic fidelity term is still the same. 55 | 56 | .. math:: \operatorname*{arg\,min}_x \|\nabla_G x\|_2^2 + \gamma \|Mx-b\|_2^2 57 | 58 | Results and code 59 | ---------------- 60 | 61 | For solving the optimization problems we've assembled, you will need a numerical solver package. This part is implemented in this tutorial with the `pyunlocbox `_, which is based on proximal splitting algorithms. Check also its `documentation `_ for more information about the parameters used here. 62 | 63 | We start with the graph TV regularization. 
We will use the :class:`pyunlocbox.solvers.mlfbf` solver from :mod:`pyunlocbox`. It is a primal-dual solver, which means for our problem that the regularization term will be written in terms of the dual variable :math:`u = \nabla_G x`, and the graph gradient :math:`\nabla_G` will be passed to the solver as the primal-dual map. The value of :math:`3.0` for the regularization parameter :math:`\gamma` was chosen on the basis of the visual appeal of the returned solution. 64 | 65 | .. plot:: 66 | :context: close-figs 67 | 68 | >>> import pyunlocbox 69 | >>> 70 | >>> # Set the functions in the problem 71 | >>> gamma = 3.0 72 | >>> d = pyunlocbox.functions.dummy() 73 | >>> r = pyunlocbox.functions.norm_l1() 74 | >>> f = pyunlocbox.functions.norm_l2(w=M, y=subsampled_noisy_label_signal, 75 | ... lambda_=gamma) 76 | >>> 77 | >>> # Define the solver 78 | >>> G.compute_differential_operator() 79 | >>> L = G.D.T.toarray() 80 | >>> step = 0.999 / (1 + np.linalg.norm(L)) 81 | >>> solver = pyunlocbox.solvers.mlfbf(L=L, step=step) 82 | >>> 83 | >>> # Solve the problem 84 | >>> x0 = subsampled_noisy_label_signal.copy() 85 | >>> prob1 = pyunlocbox.solvers.solve([d, r, f], solver=solver, 86 | ... x0=x0, rtol=0, maxit=1000) 87 | Solution found after 1000 iterations: 88 | objective function f(sol) = 2.213139e+02 89 | stopping criterion: MAXIT 90 | >>> 91 | >>> fig, ax = G.plot(prob1['sol']) 92 | 93 | This figure shows the label signal recovered by graph total variation regularization. We can confirm here that this sort of regularization does indeed promote piecewise-constant solutions. 94 | 95 | .. 
plot:: 96 | :context: close-figs 97 | 98 | >>> # Set the functions in the problem 99 | >>> r = pyunlocbox.functions.norm_l2(A=L, tight=False) 100 | >>> 101 | >>> # Define the solver 102 | >>> step = 0.999 / np.linalg.norm(np.dot(L.T, L) + gamma * np.diag(M), 2) 103 | >>> solver = pyunlocbox.solvers.gradient_descent(step=step) 104 | >>> 105 | >>> # Solve the problem 106 | >>> x0 = subsampled_noisy_label_signal.copy() 107 | >>> prob2 = pyunlocbox.solvers.solve([r, f], solver=solver, 108 | ... x0=x0, rtol=0, maxit=1000) 109 | Solution found after 1000 iterations: 110 | objective function f(sol) = 6.422673e+01 111 | stopping criterion: MAXIT 112 | >>> 113 | >>> fig, ax = G.plot(prob2['sol']) 114 | 115 | This last figure shows the label signal recovered by Tikhonov regularization. As expected, the recovered label signal has smoother transitions than the one obtained by graph TV regularization. 116 | -------------------------------------------------------------------------------- /doc/references.bib: -------------------------------------------------------------------------------- 1 | @UNPUBLISHED{gleich, 2 | author={D. Gleich}, 3 | title={The {MatlabBGL} {Matlab} Library}, 4 | note={http://www.cs.purdue.edu/homes/dgleich/packages/matlab{\_}bgl/index.html}, 5 | } 6 | 7 | @BOOK{chung1997spectral, 8 | title = {Spectral Graph Theory}, 9 | publisher = {Vol. 92 of the {CBMS} Regional Conference Series in Mathematics, {American Mathematical Society}}, 10 | year = {1997}, 11 | author = {F. R. K. Chung}, 12 | } 13 | 14 | @BOOK{parlett1998symmetric, 15 | title = {The Symmetric Eigenvalue Problem}, 16 | publisher = {SIAM}, 17 | year = {1998}, 18 | author = {B. N. Parlett}, 19 | } 20 | 21 | @UNPUBLISHED{davis2011ldl, 22 | title={User Guide for {LDL}, a concise sparse {Cholesky} package}, 23 | author={Timothy A. 
Davis}, 24 | month={Jan.}, 25 | year={2011}, 26 | note={http://www.cise.ufl.edu/research/sparse/ldl/} 27 | } 28 | 29 | @ARTICLE{davis2005ldl, 30 | title={{Algorithm 849: A concise sparse Cholesky factorization package}}, 31 | author={T. A. Davis}, 32 | journal={ACM Trans. Math. Software}, 33 | volume={31}, 34 | number={4}, 35 | month={Dec.}, 36 | year={2005}, 37 | pages={587--591}, 38 | } 39 | 40 | @Article{hammond2011wavelets, 41 | author={D. K. Hammond and P. Vandergheynst and R. Gribonval}, 42 | journal={Appl. Comput. Harmon. Anal.}, 43 | title={Wavelets on graphs via spectral graph theory}, 44 | month={Mar.}, 45 | year=2011, 46 | volume={30}, 47 | number={2}, 48 | pages={129--150}, 49 | } 50 | 51 | @article{shuman2016vertexfrequency, 52 | title={Vertex-frequency analysis on graphs}, 53 | author={Shuman, David I and Ricaud, Benjamin and Vandergheynst, Pierre}, 54 | journal={Applied and Computational Harmonic Analysis}, 55 | year={2016}, 56 | } 57 | 58 | @ARTICLE{leonardi2013tight, 59 | author = {N. Leonardi and D. {Van De Ville}}, 60 | title = {Tight Wavelet Frames on Multislice Graphs}, 61 | journal = {IEEE Trans.
Signal Process.}, 62 | month={Jul.}, 63 | volume={61}, 64 | number={13}, 65 | year = {2013}, 66 | pages={3357--3367}, 67 | } 68 | 69 | @article{klein1993resistance, 70 | title={Resistance distance}, 71 | author={Klein, Douglas J and Randi{\'c}, M}, 72 | journal={Journal of Mathematical Chemistry}, 73 | volume={12}, 74 | number={1}, 75 | pages={81--95}, 76 | year={1993}, 77 | publisher={Springer} 78 | } 79 | 80 | @article{strang1999dct, 81 | title={The discrete cosine transform}, 82 | author={Strang, Gilbert}, 83 | journal={SIAM Review}, 84 | volume={41}, 85 | number={1}, 86 | pages={135--147}, 87 | year={1999}, 88 | publisher={SIAM}, 89 | url={https://doi.org/10.1137/S0036144598336745}, 90 | } 91 | 92 | @article{shuman2013spectrum, 93 | title={Spectrum-adapted tight graph wavelet and vertex-frequency frames}, 94 | author={Shuman, David I and Wiesmeyr, Christoph and Holighaus, Nicki and Vandergheynst, Pierre}, 95 | journal={arXiv preprint arXiv:1311.0897}, 96 | year={2013} 97 | } 98 | 99 | 100 | @article{spielman2011graph, 101 | title={Graph sparsification by effective resistances}, 102 | author={Spielman, Daniel A and Srivastava, Nikhil}, 103 | journal={SIAM Journal on Computing}, 104 | volume={40}, 105 | number={6}, 106 | pages={1913--1926}, 107 | year={2011}, 108 | publisher={SIAM} 109 | } 110 | 111 | @article{rudelson1999random, 112 | title={Random vectors in the isotropic position}, 113 | author={Rudelson, Mark}, 114 | journal={Journal of Functional Analysis}, 115 | volume={164}, 116 | number={1}, 117 | pages={60--72}, 118 | year={1999}, 119 | publisher={Elsevier} 120 | } 121 | 122 | @article{rudelson2007sampling, 123 | title={Sampling from large matrices: An approach through geometric functional analysis}, 124 | author={Rudelson, Mark and Vershynin, Roman}, 125 | journal={Journal of the ACM (JACM)}, 126 | volume={54}, 127 | number={4}, 128 | pages={21}, 129 | year={2007}, 130 | publisher={ACM} 131 | } 132 | 133 | 134 | @article{dorfler2013kron, 135 |
title={Kron reduction of graphs with applications to electrical networks}, 136 | author={Dorfler, Florian and Bullo, Francesco}, 137 | journal={Circuits and Systems I: Regular Papers, IEEE Transactions on}, 138 | volume={60}, 139 | number={1}, 140 | pages={150--163}, 141 | year={2013}, 142 | publisher={IEEE} 143 | } 144 | 145 | @article{pesenson2009variational, 146 | title={Variational splines and Paley--Wiener spaces on combinatorial graphs}, 147 | author={Pesenson, Isaac}, 148 | journal={Constructive Approximation}, 149 | volume={29}, 150 | number={1}, 151 | pages={1--21}, 152 | year={2009}, 153 | publisher={Springer} 154 | } 155 | 156 | 157 | @article{shuman2013framework, 158 | title={A framework for multiscale transforms on graphs}, 159 | author={Shuman, David I and Faraji, Mohammad Javad and Vandergheynst, Pierre}, 160 | journal={arXiv preprint arXiv:1308.4942}, 161 | year={2013} 162 | } 163 | 164 | @inproceedings{turk1994zippered, 165 | title={Zippered polygon meshes from range images}, 166 | author={Turk, Greg and Levoy, Marc}, 167 | booktitle={Proceedings of the 21st annual conference on Computer graphics and interactive techniques}, 168 | pages={311--318}, 169 | year={1994}, 170 | organization={ACM} 171 | } 172 | 173 | @ARTICLE{perraudin2014gspbox, 174 | author = {{Perraudin}, Nathana{\"e}l and {Paratte}, Johan and {Shuman}, David and {Kalofolias}, Vassilis and 175 | {Vandergheynst}, Pierre and {Hammond}, David K. 
}, 176 | title = "{GSPBOX: A toolbox for signal processing on graphs}", 177 | journal = {ArXiv e-prints}, 178 | archivePrefix = "arXiv", 179 | eprint = {1408.5781}, 180 | primaryClass = "cs.IT", 181 | keywords = {Computer Science - Information Theory}, 182 | year = 2014, 183 | month = aug, 184 | adsurl = {http://arxiv.org/abs/1408.5781}, 185 | } 186 | 187 | @article{tremblay2016compressive, 188 | title={Compressive spectral clustering}, 189 | author={Tremblay, Nicolas and Puy, Gilles and Gribonval, R{\'e}mi and Vandergheynst, Pierre}, 190 | journal={arXiv preprint arXiv:1602.02018}, 191 | year={2016} 192 | } 193 | 194 | @inproceedings{kim2003randomregulargraphs, 195 | title={Generating random regular graphs}, 196 | author={Kim, Jeong Han and Vu, Van H}, 197 | booktitle={Proceedings of the thirty-fifth annual ACM symposium on Theory of computing}, 198 | year={2003} 199 | } 200 | 201 | @inproceedings{leonardi2011wavelet, 202 | title={Wavelet frames on graphs defined by fMRI functional connectivity}, 203 | author={Leonardi, Nora and Van De Ville, Dimitri}, 204 | booktitle={Biomedical Imaging: From Nano to Macro, 2011 IEEE International Symposium on}, 205 | pages={2136--2139}, 206 | year={2011}, 207 | organization={IEEE} 208 | } 209 | 210 | @inproceedings{grassi2016timevertex, 211 | title={Tracking time-vertex propagation using dynamic graph wavelets}, 212 | author={Grassi, Francesco and Perraudin, Nathanael and Ricaud, Benjamin}, 213 | year={2016}, 214 | booktitle={Signal and Information Processing (GlobalSIP), 2016 IEEE Global Conference on}, 215 | } 216 | 217 | @article{grassi2018timevertex, 218 | title={A time-vertex signal processing framework: Scalable processing and meaningful representations for time-series on graphs}, 219 | author={Grassi, Francesco and Loukas, Andreas and Perraudin, Nathanael and Ricaud, Benjamin}, 220 | year={2018}, 221 | journal={IEEE Transactions on Signal Processing}, 222 | } 223 | 
# --------------------------------------------------------------------------------
# /pygsp/filters/modulation.py:
# --------------------------------------------------------------------------------
import numpy as np
from scipy import interpolate

from .filter import Filter  # prevent circular import in Python < 3.5


class Modulation(Filter):
    r"""Design a filter bank with a kernel centered at each frequency.

    Design a filter bank from translated versions of a mother filter.
    The mother filter is translated to each eigenvalue of the Laplacian via
    modulation. A signal is modulated by multiplying it with an eigenvector.
    Similarly to localization, it is an element-wise multiplication of a kernel
    with the columns of :attr:`pygsp.graphs.Graph.U`, i.e., the eigenvectors,
    in the vertex domain.

    This filter bank can be used to compute the frequency content of a signal
    at each vertex. After filtering, one obtains a vertex-frequency
    representation :math:`Sf(i,k)` of a signal :math:`f` as

    .. math:: Sf(i, k) = \langle g_{i,k}, f \rangle,

    where :math:`g_{i,k}` is the mother kernel modulated in the spectral domain
    by the eigenvector :math:`u_k`, and localized on vertex :math:`v_i`.

    While :math:`g_{i,k}` should ideally be localized in both the spectral and
    vertex domains, that is impossible for some graphs due to the localization
    of some eigenvectors. See :attr:`pygsp.graphs.Graph.coherence`.

    As modulation and localization don't commute, one can define the frame as
    :math:`g_{i,k} = T_i M_k g` (modulation first) or :math:`g_{i,k} = M_k T_i
    g` (localization first). Localization first usually gives better results.
    When localizing first, the obtained vertex-frequency representation is a
    generalization to graphs of the windowed graph Fourier transform. Indeed,

    .. math:: Sf(i, k) = \langle f^\text{win}_i, u_k \rangle

    is the graph Fourier transform of the windowed signal :math:`f^\text{win}`.
    The signal :math:`f` is windowed in the vertex domain by a point-wise
    multiplication with the localized kernel :math:`T_i g`.

    When localizing first, the spectral representation of the filter bank is
    different for every localization. As such, we always evaluate the filter in
    the spectral domain with modulation first. Moreover, the filter bank is
    only defined at the eigenvalues (as modulation is done with discrete
    eigenvectors). Evaluating it elsewhere returns NaNs.

    Parameters
    ----------
    graph : :class:`pygsp.graphs.Graph`
        Graph on which the filter bank is defined. It must be the very same
        object as ``kernel.G``.
    kernel : :class:`pygsp.filters.Filter`
        Kernel function to be modulated. It must contain exactly one filter.
    modulation_first : bool
        First modulate then localize the kernel if True, first localize then
        modulate if False. The two operators do not commute. This setting only
        applies to :meth:`filter`. :meth:`evaluate` only performs modulation,
        as the filter would otherwise have a different spectrum depending on
        where it is localized.

    Raises
    ------
    ValueError
        If ``kernel`` has more than one filter, or was not built on ``graph``.

    See Also
    --------
    Gabor : Another way to translate a filter in the spectral domain.

    Notes
    -----
    The eigendecomposition of the graph Laplacian (i.e., the Fourier basis:
    both the eigenvalues and the eigenvectors) is needed to modulate the
    kernels.

    References
    ----------
    See :cite:`shuman2016vertexfrequency` for details on this vertex-frequency
    representation of graph signals.

    Examples
    --------

    Vertex-frequency representations.
    Modulating first doesn't produce sufficiently localized filters.

    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Path(90)
    >>> G.compute_fourier_basis()
    >>>
    >>> # Design the filter banks.
    >>> g = filters.Heat(G, 500)
    >>> g1 = filters.Modulation(G, g, modulation_first=False)
    >>> g2 = filters.Modulation(G, g, modulation_first=True)
    >>> _ = g1.plot(sum=False, labels=False)
    >>>
    >>> # Signal.
    >>> s = np.empty(G.N)
    >>> s[:30] = G.U[:30, 10]
    >>> s[30:60] = G.U[30:60, 60]
    >>> s[60:] = G.U[60:, 30]
    >>> G.set_coordinates('line1D')
    >>> _ = G.plot(s)
    >>>
    >>> # Filter with both filter banks.
    >>> s1 = g1.filter(s)
    >>> s2 = g2.filter(s)
    >>>
    >>> # Visualize the vertex-frequency representation of the signal.
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].imshow(np.abs(s1.T)**2)
    >>> _ = axes[1].imshow(np.abs(s2.T)**2)
    >>> _ = axes[0].set_title('localization then modulation')
    >>> _ = axes[1].set_title('modulation then localization')
    >>> ticks = [0, G.N//2, G.N-1]
    >>> labels = ['{:.1f}'.format(e) for e in G.e[ticks]]
    >>> _ = axes[0].set_yticks(ticks)
    >>> _ = axes[1].set_yticks([])
    >>> _ = axes[0].set_yticklabels(labels)
    >>> _ = axes[0].set_ylabel('graph frequency')
    >>> _ = axes[0].set_xlabel('node')
    >>> _ = axes[1].set_xlabel('node')
    >>> _ = axes[0].set_xticks(ticks)
    >>> _ = axes[1].set_xticks(ticks)
    >>> fig.tight_layout()
    >>>
    >>> # Reconstruction.
    >>> s = g2.filter(s2)
    >>> _ = G.plot(s)

    """

    def __init__(self, graph, kernel, modulation_first=False):
        # Validate first, so that a rejected kernel never leaves a partially
        # initialized filter bank behind.
        if kernel.n_filters != 1:
            raise ValueError(
                "A kernel must be one filter. The passed "
                "filter bank {} has {}.".format(kernel, kernel.n_filters)
            )
        if kernel.G is not graph:
            raise ValueError(
                "The graph passed to this filter bank must "
                "be the one used to build the mother kernel."
            )

        self.G = graph
        self._kernels = kernel
        self._modulation_first = modulation_first

        # One input feature and one output feature (modulated kernel) per
        # vertex of the graph.
        self.n_features_in, self.n_features_out = (1, graph.n_vertices)
        self.n_filters = self.n_features_in * self.n_features_out
        self.Nf = self.n_filters  # TODO: kept for backward compatibility only.

    def evaluate(self, x):
        """Evaluate the modulated kernels at the frequencies ``x``.

        TODO: will become _evaluate once polynomial filtering is merged.

        The filter bank is only defined at the graph eigenvalues ``G.e``.
        Entries of ``x`` that are not exactly equal to an eigenvalue evaluate
        to NaN.

        Parameters
        ----------
        x : array_like
            Frequencies at which to evaluate the filter bank. Scalars, lists
            and arrays of any shape are accepted.

        Returns
        -------
        y : ndarray
            Array of shape ``(n_features_out,) + np.shape(x)``.
        """
        # Accept scalars and sequences, not only ndarrays (needs ``.shape``).
        x = np.asanyarray(x)

        if not hasattr(self, "_coefficients"):
            # Graph Fourier transform -> modulation -> inverse GFT.
            c = self.G.igft(self._kernels.evaluate(self.G.e).squeeze())
            c = np.sqrt(self.G.n_vertices) * self.G.U * c[:, np.newaxis]
            # Cached; axis 0 is indexed by eigenvalue in the lookup below.
            self._coefficients = self.G.gft(c)

        shape = x.shape
        x = x.flatten()
        y = np.full((self.n_features_out, x.size), np.nan)
        for i, value in enumerate(x):
            # Exact comparison on purpose: the bank is undefined between
            # eigenvalues. With repeated eigenvalues, the first match wins.
            matches = np.flatnonzero(self.G.e == value)
            if matches.size != 0:
                y[:, i] = self._coefficients[matches[0]]
        return y.reshape((self.n_features_out,) + shape)

    def filter(self, s, method="exact", order=None):
        """Filter the signal ``s`` with every modulated kernel.

        TODO: indirection will be removed when poly filtering is merged.
        TODO: with _filter and shape handled in Filter.filter, synthesis will work.

        Parameters
        ----------
        s : ndarray
            Signal on the vertices.
        method, order
            Accepted for interface compatibility with :meth:`Filter.filter`
            but ignored: filtering is always exact.

        Returns
        -------
        ndarray
            When modulating first, the output of :meth:`Filter.filter`.
            When localizing first, an array of shape
            ``(n_vertices, n_vertices)`` whose row ``i`` is the (scaled) graph
            Fourier transform of ``s`` windowed by the kernel localized at
            vertex ``i``.
        """
        if self._modulation_first:
            return super().filter(s, method="exact")

        # The dot product with each modulated kernel is equivalent to the
        # GFT, as for the localization and the IGFT.
        scale = np.sqrt(self.G.n_vertices)
        rows = [
            scale * self.G.gft(s * self._kernels.localize(i))
            for i in range(self.G.n_vertices)
        ]
        # Stacking (rather than filling a preallocated real array) keeps the
        # result's dtype, so complex signals are no longer silently truncated.
        return np.stack(rows)
# --------------------------------------------------------------------------------