├── .gitignore ├── .pre-commit-config.yaml ├── AUTHORS.md ├── LICENSE ├── Makefile ├── README.md ├── cpp ├── fdct2d_wrapper.cpp └── fdct3d_wrapper.cpp ├── curvelops ├── __init__.py ├── curvelops.py ├── plot │ ├── __init__.py │ ├── _curvelet.py │ └── _generic.py ├── typing │ ├── __init__.py │ └── _typing.py └── utils │ ├── __init__.py │ └── _utils.py ├── docs └── .nojekyll ├── docssrc ├── Makefile └── source │ ├── conf.py │ ├── contributing.rst │ ├── index.rst │ ├── installation.rst │ ├── modules.rst │ └── static │ ├── demo.png │ ├── logo.png │ └── reconstruction.png ├── examples ├── README.rst ├── plot_curvelets_in_fk.py ├── plot_seismic_regularization.py ├── plot_sigmoid.py ├── plot_sigmoid_coefficients.py ├── plot_sigmoid_disks.py └── plot_single_curvelet.py ├── notebooks ├── Desmystifying_Curvelets.ipynb └── Single_Curvelet_Interactive.ipynb ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── testdata ├── python.png ├── seismic.npz └── sigmoid.npz └── tests ├── __init__.py ├── test_fdct.py ├── test_fdct2d_wrapper.py ├── test_fdct3d_wrapper.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | curvelops/_version.py 2 | 3 | # Documentation 4 | docssrc/build 5 | docssrc/source/api/generated 6 | docssrc/source/gallery 7 | 8 | # Editors 9 | .*.sw[po] 10 | .vscode/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: "^docs/" 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.3.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-added-large-files 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 23.9.1 13 | hooks: 14 | - id: black 15 | args: # arguments to configure black 16 | - --line-length=88 17 | 18 | - repo: https://github.com/pycqa/isort 19 | rev: 5.12.0 20 | hooks: 21 | - id: isort 22 | name: isort (python) 23 | args: 24 | [ 25 | "--profile", 26 | "black", 27 | "--skip", 28 | "__init__.py", 29 | "--filter-files", 30 | "--line-length=88", 31 | ] 32 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | Carlos Alberto da Costa Filho, [@cako](https://github.com/cako) 2 | Matteo Ravasi, [@mrava87](https://github.com/mrava87) 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-2023 Carlos da Costa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PIP := $(shell command -v pip3 2> /dev/null || command which pip 2> /dev/null) 2 | PYTHON := $(shell command -v python3 2> /dev/null || command which python 2> /dev/null) 3 | PYTEST := $(shell command -v pytest 2> /dev/null) 4 | 5 | .PHONY: install dev-install tests doc watchdoc servedoc lint typeannot coverage 6 | 7 | pipcheck: 8 | ifndef PIP 9 | $(error "Ensure pip or pip3 are in your PATH") 10 | endif 11 | @echo Using pip: $(PIP) 12 | 13 | pythoncheck: 14 | ifndef PYTHON 15 | $(error "Ensure python or python3 are in your PATH") 16 | endif 17 | @echo Using python: $(PYTHON) 18 | 19 | pytestcheck: 20 | ifndef PYTEST 21 | $(error "Ensure pytest is in your PATH") 22 | endif 23 | @echo Using pytest: $(PYTEST) 24 | 25 | install: 26 | make pipcheck 27 | $(PIP) install -r requirements.txt && $(PIP) install . 28 | 29 | dev-install: 30 | make pipcheck 31 | $(PIP) install -r requirements-dev.txt && $(PIP) install -e . 32 | 33 | tests: 34 | make pytestcheck 35 | $(PYTEST) tests 36 | 37 | lint: 38 | flake8 examples/ docs/ curvelops/ tests/ 39 | 40 | typeannot: 41 | mypy curvelops/ examples/ 42 | 43 | coverage: 44 | coverage run -m pytest && coverage xml && coverage html && $(PYTHON) -m http.server --directory htmlcov/ 45 | 46 | watchdoc: 47 | make doc && while inotifywait -q -r curvelops/ examples/ docssrc/source/ -e create,delete,modify; do { make docupdate; }; done 48 | 49 | servedoc: 50 | $(PYTHON) -m http.server --directory docssrc/build/html/ 51 | 52 | doc: 53 | # Add after rm: sphinx-apidoc -f -M -o source/ ../curvelops 54 | # Use -O to include private files 55 | cd docssrc && rm -rf source/api/generated && rm -rf source/gallery &&\ 56 | rm -rf source/tutorials && rm -rf source/examples &&\ 57 | rm -rf build && make html && cd .. 58 | 59 | docupdate: 60 | cd docssrc && make html && cd .. 61 | 62 | docgithub: 63 | cd docssrc && make github && cd .. 64 | 65 | docpush: 66 | # Only run when main is at a release commit/tag 67 | python3 -m pip install git+https://github.com/PyLops/curvelops@`git describe --tags` 68 | git checkout gh-pages && git merge main && cd docssrc && make github &&\ 69 | cd ../docs && git add . && git commit -m "Updated documentation" &&\ 70 | git push origin gh-pages && git checkout main 71 | python3 -m pip install -e . 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Documentation](https://github.com/PyLops/curvelops/actions/workflows/pages/pages-build-deployment/badge.svg?branch=gh-pages)](https://pylops.github.io/curvelops/) 2 | [![Slack Status](https://img.shields.io/badge/chat-slack-green.svg)](https://pylops.slack.com) 3 | 4 | # curvelops 5 | 6 | Python wrapper for [CurveLab](http://www.curvelet.org)'s 2D and 3D curvelet 7 | transforms. It uses the [PyLops](https://pylops.readthedocs.io/) design 8 | framework to provide the forward and inverse curvelet transforms as matrix-free 9 | linear operations. If you are still confused, check out 10 | [some examples](https://github.com/PyLops/curvelops/tree/main/examples) below 11 | or the [PyLops website](https://pylops.readthedocs.io/)! 12 | 13 | ## Installation 14 | 15 | Installing `curvelops` requires the following external components: 16 | 17 | - [FFTW](http://www.fftw.org/download.html) 2.1.5 18 | - [CurveLab](http://curvelet.org/software.html) >= 2.0.2 19 | 20 | Both of these packages _must be installed manually_. See more information in 21 | the [Documentation](https://pylops.github.io/curvelops/installation.html#requirements). 22 | After these are installed, you may install `curvelops` with: 23 | 24 | ```bash 25 | export FFTW=/path/to/fftw-2.1.5 26 | export FDCT=/path/to/CurveLab-2.1.3 27 | python3 -m pip install git+https://github.com/PyLops/curvelops@0.23.4 28 | ``` 29 | 30 | as long as you are using a `pip>=10.0`. To check, run `python3 -m pip --version`. 31 | 32 | ## Getting Started 33 | 34 | For a 2D transform, you can get started with: 35 | 36 | ```python 37 | import numpy as np 38 | import curvelops as cl 39 | 40 | x = np.random.randn(100, 50) 41 | FDCT = cl.FDCT2D(dims=x.shape) 42 | c = FDCT @ x 43 | xinv = FDCT.H @ c 44 | np.testing.assert_allclose(x, xinv) 45 | ``` 46 | 47 | An excellent place to see how to use the library is the 48 | [Gallery](https://pylops.github.io/curvelops/gallery/index.html). You can also 49 | find more examples in the 50 | [`notebooks/`](https://github.com/PyLops/curvelops/tree/main/notebooks) folder. 51 | 52 | ![Demo](https://github.com/PyLops/curvelops/raw/main/docssrc/source/static/demo.png) 53 | ![Reconstruction](https://github.com/PyLops/curvelops/raw/main/docssrc/source/static/reconstruction.png) 54 | 55 | ## Useful links 56 | 57 | * [Paul Goyes](https://github.com/PAULGOYES) has kindly contributed a rundown of how to install curvelops: [link to YouTube video (in Spanish)](https://www.youtube.com/watch?v=LAFkknyOpGY). 58 | 59 | ## Note 60 | 61 | This package contains no CurveLab code apart from function calls. It is 62 | provided to simplify the use of CurveLab in a Python environment. Please ensure 63 | you own a CurveLab license as per required by the authors. See the 64 | [CurveLab website](http://curvelet.org/software.html) for more information. All 65 | CurveLab rights are reserved to Emmanuel Candes, Laurent Demanet, David Donoho 66 | and Lexing Ying. 67 | -------------------------------------------------------------------------------- /cpp/fdct2d_wrapper.cpp: -------------------------------------------------------------------------------- 1 | /* fdct2d_wrapper (Pybind11 wrapper for Fast 2D Curvelet Wrapping Transform) 2 | Copyright (C) 2020-2023 Carlos Alberto da Costa Filho 3 | 4 | ${CXX} -O3 -Wall -shared -std=c++11 -fPIC \ 5 | -I${FFTW}/fftw `python3 -m pybind11 --includes` \ 6 | fdct2d_wrapper.cpp ${FDCT}/fdct_wrapping_cpp/src/libfdct_wrapping.a \ 7 | -L${FFTW}/fftw/.libs -lfftw \ 8 | -o fdct2d_wrapper`python3-config --extension-suffix` 9 | */ 10 | 11 | #include "fdct_wrapping.hpp" 12 | #include "fdct_wrapping_inc.hpp" 13 | #include "fdct_wrapping_inline.hpp" 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | namespace py = pybind11; 21 | namespace fdct = fdct_wrapping_ns; 22 | using fdct_wrapping_ns::cpx; 23 | using fdct_wrapping_ns::CpxNumMat; 24 | using std::vector; 25 | 26 | py::tuple 27 | fdct2d_param_wrap(int m, int n, int nbscales, int nbangles_coarse, int ac) 28 | { 29 | // Almost sure this function creates a copy, but it's ok since the outputs 30 | // are small 31 | vector> sx, sy; 32 | vector> fx, fy; 33 | vector> nx, ny; 34 | fdct::fdct_wrapping_param(m, n, nbscales, nbangles_coarse, ac, sx, sy, fx, 35 | fy, nx, ny); 36 | return py::make_tuple(sx, sy, fx, fy, nx, ny); 37 | } 38 | 39 | vector>> 40 | fdct2d_forward_wrap(int nbscales, 41 | int nbangles_coarse, 42 | int ac, 43 | py::array_t x) 44 | { 45 | // Our wrapper takes a NumPy array, but ``fdct_wrapping`` requires a 46 | // CpxNumMat input (which will be accessed read-only). So we must create 47 | // CpxNumMat ``xmat`` which will "mirror" our input ``x`` in a no-copy 48 | // fashion. We also need to output ``c`` whose conversion to a Python list 49 | // of lists of CpxNumMat will be handled by pybind11. The vector -> list 50 | // casting is automatic in pybind11, whereas the CpxNumMat -> 51 | // py::array_t casting is inside our function. 52 | CpxNumMat xmat; 53 | vector> cmat; 54 | 55 | // Responsibly access py::array with possible casting to complex. See: 56 | // https://stackoverflow.com/questions/42645228/cast-numpy-array-to-from-custom-c-matrix-class-using-pybind11 57 | // Note: CurveLab uses Fortran-style indexing, so we must transpose the 58 | // input array. We do this 59 | // by simply reading it as a Fortran array 60 | auto buf = 61 | py::array_t::ensure(x); 62 | if (!buf) 63 | throw std::runtime_error("x array buffer is empty. If you're calling " 64 | "from Python this should not happen!"); 65 | if (buf.ndim() != 2) 66 | throw std::runtime_error("x.ndims != 2"); 67 | 68 | // We don't to initialize ``x(m, n)`` because this allocates an array on 69 | // the heap! 70 | xmat._m = buf.shape()[0]; 71 | xmat._n = buf.shape()[1]; 72 | // Put our Python array buffer pointer as the CpxNumMat data 73 | xmat._data = const_cast(buf.data()); 74 | 75 | // Call our forward function with all the right types 76 | fdct::fdct_wrapping(xmat._m, xmat._n, nbscales, nbangles_coarse, ac, xmat, 77 | cmat); 78 | 79 | // Clear the structure as if it had never existed... 80 | // xmat didn't allocate any data, so we make sure it doesn't deallocate any 81 | // on the way out 82 | xmat._m = xmat._n = 0; 83 | xmat._data = nullptr; 84 | 85 | vector>> c; 86 | // Expand ``c`` to fit the scales 87 | c.resize(cmat.size()); 88 | for (size_t i = 0; i < cmat.size(); i++) { 89 | // Now we expand each scale to fit the angles 90 | c[i].resize(cmat[i].size()); 91 | for (size_t j = 0; j < cmat[i].size(); j++) { 92 | // Create capsule linked to `cmat[i][j]._data` to track its 93 | // lifetime 94 | // https://stackoverflow.com/questions/44659924/returning-numpy-arrays-via-pybind11 95 | py::capsule free_when_done(cmat[i][j].data(), [](void* cpx_ptr) { 96 | cpx* cpx_arr = reinterpret_cast(cpx_ptr); 97 | delete[] cpx_arr; 98 | }); 99 | 100 | // Shape 101 | // Strides (in bytes) of the underlying data array 102 | // Data pointer 103 | // Capsule to be called when the array is deleted in Python 104 | py::array c_arr({cmat[i][j]._n, cmat[i][j]._m}, 105 | {sizeof(cpx) * cmat[i][j]._m, sizeof(cpx)}, 106 | cmat[i][j].data(), free_when_done); 107 | 108 | c[i][j] = c_arr; 109 | cmat[i][j]._m = cmat[i][j]._n = 0; 110 | cmat[i][j]._data = nullptr; 111 | } 112 | } 113 | return c; 114 | } 115 | 116 | py::array_t 117 | fdct2d_inverse_wrap(int m, 118 | int n, 119 | int nbscales, 120 | int nbangles_coarse, 121 | int ac, 122 | vector>> c) 123 | { 124 | // Similarly to the forward wrapper, we create ``cmat`` and ``xmat`` to use 125 | // as dummy input and output arrays. 126 | size_t i, j; 127 | CpxNumMat xmat; 128 | vector> cmat; 129 | 130 | if ((size_t)nbscales != c.size()) 131 | throw std::runtime_error("nbscales != len(c)"); 132 | 133 | // We copy the ``c`` "structure" onto a ``cmat`` "structure" 134 | // Start by expanding the first index of ``cmat`` to fit all scales 135 | cmat.resize(c.size()); 136 | for (i = 0; i < c.size(); i++) { 137 | // Now we expand each scale to fit all angles for that scale 138 | cmat[i].resize(c[i].size()); 139 | for (j = 0; j < c[i].size(); j++) { 140 | // Now we must copy the structure over to ``cmat`` 141 | py::buffer_info buf = c[i][j].request(); 142 | cmat[i][j]._m = buf.shape[1]; 143 | cmat[i][j]._n = buf.shape[0]; 144 | cmat[i][j]._data = static_cast(buf.ptr); 145 | } 146 | } 147 | // No bounds checking is made inside this, so if ``c`` (or equivalently 148 | // ``cmat``) are not compatible with the other parameters, this function 149 | // WILL segfault 150 | // TODO: Optionally sanitize this by calling ``fdct2d_param_wrap`` 151 | fdct::ifdct_wrapping(m, n, nbscales, nbangles_coarse, ac, cmat, xmat); 152 | 153 | // Clear input structure without deallocating 154 | for (i = 0; i < c.size(); i++) 155 | for (j = 0; j < c[i].size(); j++) { 156 | cmat[i][j]._m = cmat[i][j]._n = 0; 157 | cmat[i][j]._data = nullptr; 158 | } 159 | 160 | py::capsule free_when_done(xmat.data(), [](void* cpx_ptr) { 161 | cpx* cpx_arr = reinterpret_cast(cpx_ptr); 162 | delete[] cpx_arr; 163 | }); 164 | 165 | // Create output array 166 | // Shape 167 | // Strides (in bytes) of the underlying data array 168 | // Data pointer 169 | // Capsule to be called when the array is deleted in Python 170 | py::array x({m, n}, {sizeof(cpx), sizeof(cpx) * m}, xmat.data(), 171 | free_when_done); 172 | 173 | // Clear output structure without deallocating 174 | xmat._m = xmat._n = 0; 175 | xmat._data = nullptr; 176 | 177 | return x; 178 | } 179 | 180 | PYBIND11_MODULE(fdct2d_wrapper, m) 181 | { 182 | m.doc() = "FDCT2D pybind11 wrapper"; 183 | m.def("fdct2d_param_wrap", 184 | &fdct2d_param_wrap, 185 | "Parameters for 2D FDCT", 186 | py::return_value_policy::take_ownership); 187 | m.def("fdct2d_forward_wrap", 188 | &fdct2d_forward_wrap, 189 | "2D Forward FDCT", 190 | py::return_value_policy::take_ownership); 191 | m.def("fdct2d_inverse_wrap", 192 | &fdct2d_inverse_wrap, 193 | "2D Inverse FDCT", 194 | py::return_value_policy::take_ownership); 195 | } 196 | -------------------------------------------------------------------------------- /cpp/fdct3d_wrapper.cpp: -------------------------------------------------------------------------------- 1 | /* fdct3d_wrapper (Pybind11 wrapper for Fast 3D Curvelet Wrapping Transform) 2 | Copyright (C) 2020-2023 Carlos Alberto da Costa Filho 3 | 4 | ${CXX} -O3 -Wall -shared -std=c++11 -fPIC \ 5 | -I${FFTW}/fftw `python3 -m pybind11 --includes` \ 6 | fdct2d_wrapper.cpp ${FDCT}/fdct3d/src/libfdct3d.a \ 7 | -L${FFTW}/fftw/.libs -lfftw \ 8 | -o fdct2d_wrapper`python3-config --extension-suffix` 9 | */ 10 | 11 | #include "fdct3d.hpp" 12 | #include "fdct3dinline.hpp" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | namespace py = pybind11; 20 | 21 | py::tuple 22 | fdct3d_param_wrap(int m, 23 | int n, 24 | int p, 25 | int nbscales, 26 | int nbangles_coarse, 27 | int ac) 28 | { 29 | // Almost sure this function creates a copy, but it's ok since the outputs 30 | // are small 31 | vector> fxs, fys, fzs; 32 | vector> nxs, nys, nzs; 33 | fdct3d_param( 34 | m, n, p, nbscales, nbangles_coarse, ac, fxs, fys, fzs, nxs, nys, nzs); 35 | return py::make_tuple(fxs, fys, fzs, nxs, nys, nzs); 36 | } 37 | 38 | vector>> 39 | fdct3d_forward_wrap(int nbscales, 40 | int nbangles_coarse, 41 | int ac, 42 | py::array_t x) 43 | { 44 | // Our wrapper takes a NumPy array, but ``fdct3d_forward`` requires a 45 | // CpxNumTns input (which will be accessed read-only). So we must create 46 | // CpxNumTns ``xtns`` which will "mirror" our input ``x`` in a no-copy 47 | // fashion. We also need to output ``c`` whose conversion to a Python list 48 | // of lists of CpxNumTns will be handled by pybind11. The vector -> list 49 | // casting is automatic in pybind11, whereas the CpxNumTns -> 50 | // py::array_t casting is inside our function. 51 | CpxNumTns xtns; 52 | vector> ctns; 53 | 54 | // Responsibly access py::array with possible casting to complex. See: 55 | // https://stackoverflow.com/questions/42645228/cast-numpy-array-to-from-custom-c-matrix-class-using-pybind11 56 | // Note: CurveLab uses Fortran-style indexing, so we must transpose the 57 | // input array. We do this 58 | // by simply reading it as a Fortran array 59 | auto buf = 60 | py::array_t::ensure(x); 61 | if (!buf) 62 | throw std::runtime_error("x array buffer is empty. If you're calling " 63 | "from Python this should not happen!"); 64 | if (buf.ndim() != 3) 65 | throw std::runtime_error("x.ndims != 3"); 66 | 67 | // We don't to initialize ``x(m, n, p)`` because this allocates an array on 68 | // the heap! 69 | xtns._m = buf.shape()[0]; 70 | xtns._n = buf.shape()[1]; 71 | xtns._p = buf.shape()[2]; 72 | // Put our Python array buffer pointer as the CpxNumTns data 73 | xtns._data = const_cast(buf.data()); 74 | 75 | // Call our forward function with all the right types 76 | fdct3d_forward( 77 | xtns._m, xtns._n, xtns._p, nbscales, nbangles_coarse, ac, xtns, ctns); 78 | 79 | // Clear the structure as if it had never existed... 80 | // xtns didn't allocate any data, so we make sure it doesn't deallocate any 81 | // on the way out 82 | xtns._m = xtns._n = xtns._p = 0; 83 | xtns._data = nullptr; 84 | 85 | vector>> c; 86 | // Expand ``c`` to fit the scales 87 | c.resize(ctns.size()); 88 | for (size_t i = 0; i < ctns.size(); i++) { 89 | // Now we expand each scale to fit the angles 90 | c[i].resize(ctns[i].size()); 91 | for (size_t j = 0; j < ctns[i].size(); j++) { 92 | // Create capsule linked to `ctns[i][j]._data` to track its 93 | // lifetime 94 | // https://stackoverflow.com/questions/44659924/returning-numpy-arrays-via-pybind11 95 | py::capsule free_when_done(ctns[i][j].data(), [](void* cpx_ptr) { 96 | cpx* cpx_arr = reinterpret_cast(cpx_ptr); 97 | delete[] cpx_arr; 98 | }); 99 | 100 | // Shape 101 | // Strides (in bytes) of the underlying data array 102 | // Data pointer 103 | // Capsule to be called when the array is deleted in Python 104 | py::array c_arr({ctns[i][j]._p, ctns[i][j]._n, ctns[i][j]._m}, 105 | {sizeof(cpx) * ctns[i][j]._m * ctns[i][j]._n, 106 | sizeof(cpx) * ctns[i][j]._m, sizeof(cpx)}, 107 | ctns[i][j].data(), free_when_done); 108 | 109 | c[i][j] = c_arr; 110 | ctns[i][j]._m = ctns[i][j]._n = ctns[i][j]._p = 0; 111 | ctns[i][j]._data = nullptr; 112 | } 113 | } 114 | return c; 115 | } 116 | 117 | py::array_t 118 | fdct3d_inverse_wrap(int m, 119 | int n, 120 | int p, 121 | int nbscales, 122 | int nbangles_coarse, 123 | int ac, 124 | vector>> c) 125 | { 126 | // Similarly to the forward wrapper, we create ``ctns`` and ``xtns`` to use 127 | // as dummy input and output arrays. 128 | size_t i, j; 129 | CpxNumTns xtns; 130 | vector> ctns; 131 | 132 | if ((size_t)nbscales != c.size()) 133 | throw std::runtime_error("nbscales != len(c)"); 134 | 135 | // We copy the ``c`` "structure" onto a ``ctns`` "structure" 136 | // Start by expanding the first index of ``ctns`` to fit all scales 137 | ctns.resize(c.size()); 138 | for (i = 0; i < c.size(); i++) { 139 | // Now we expand each scale to fit all angles for that scale 140 | ctns[i].resize(c[i].size()); 141 | for (j = 0; j < c[i].size(); j++) { 142 | // Now we must copy the structure over to ``ctns`` 143 | py::buffer_info buf = c[i][j].request(); 144 | ctns[i][j]._m = buf.shape[2]; 145 | ctns[i][j]._n = buf.shape[1]; 146 | ctns[i][j]._p = buf.shape[0]; 147 | ctns[i][j]._data = static_cast(buf.ptr); 148 | } 149 | } 150 | // No bounds checking is made inside this, so if ``c`` (or equivalently 151 | // ``ctns``) are not compatible with the other parameters, this function 152 | // WILL segfault 153 | // TODO: Optionally sanitize this by calling ``fdct3d_param_wrap`` 154 | fdct3d_inverse(m, n, p, nbscales, nbangles_coarse, ac, ctns, xtns); 155 | 156 | // Clear input structure without deallocating 157 | for (i = 0; i < c.size(); i++) 158 | for (j = 0; j < c[i].size(); j++) { 159 | ctns[i][j]._m = ctns[i][j]._n = ctns[i][j]._p = 0; 160 | ctns[i][j]._data = nullptr; 161 | } 162 | 163 | py::capsule free_when_done(xtns.data(), [](void* cpx_ptr) { 164 | cpx* cpx_arr = reinterpret_cast(cpx_ptr); 165 | delete[] cpx_arr; 166 | }); 167 | 168 | // Create output array 169 | // Shape 170 | // Strides (in bytes) of the underlying data array 171 | // Data pointer 172 | // Capsule to be called when the array is deleted in Python 173 | py::array x({m, n, p}, {sizeof(cpx), sizeof(cpx) * m, sizeof(cpx) * m * n}, 174 | xtns.data(), free_when_done); 175 | 176 | // Clear output structure without deallocating 177 | xtns._m = xtns._n = xtns._p = 0; 178 | xtns._data = nullptr; 179 | 180 | return x; 181 | } 182 | 183 | PYBIND11_MODULE(fdct3d_wrapper, m) 184 | { 185 | m.doc() = "FDCT3D pybind11 wrapper"; 186 | m.def("fdct3d_param_wrap", 187 | &fdct3d_param_wrap, 188 | "Parameters for 3D FDCT", 189 | py::return_value_policy::take_ownership); 190 | m.def("fdct3d_forward_wrap", 191 | &fdct3d_forward_wrap, 192 | "3D Forward FDCT", 193 | py::return_value_policy::take_ownership); 194 | m.def("fdct3d_inverse_wrap", 195 | &fdct3d_inverse_wrap, 196 | "3D Inverse FDCT", 197 | py::return_value_policy::take_ownership); 198 | } 199 | -------------------------------------------------------------------------------- /curvelops/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``curvelops`` 3 | ============= 4 | 5 | Python wrapper for CurveLab's 2D and 3D curvelet transforms. 6 | """ 7 | from .curvelops import * 8 | from .utils import * 9 | from .plot import * 10 | from .typing import * 11 | 12 | 13 | try: 14 | from ._version import __version__ 15 | except ImportError: 16 | from datetime import datetime 17 | 18 | __version__ = "0.0.unknown+" + datetime.today().strftime("%Y%m%d") 19 | -------------------------------------------------------------------------------- /curvelops/curvelops.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``curvelops.curvelops`` 3 | ======================= 4 | 5 | Provides a LinearOperator for the 2D and 3D curvelet transforms. 6 | """ 7 | 8 | from itertools import product 9 | from math import prod 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | import numpy as np 13 | from numpy.core.multiarray import normalize_axis_index # type: ignore 14 | from numpy.typing import DTypeLike, NDArray 15 | from pylops import LinearOperator 16 | from pylops.utils.typing import InputDimsLike 17 | 18 | # pylint: disable=no-name-in-module 19 | from .fdct2d_wrapper import ( # noqa: F403 20 | fdct2d_forward_wrap, 21 | fdct2d_inverse_wrap, 22 | fdct2d_param_wrap, 23 | ) 24 | from .fdct3d_wrapper import ( # noqa: F403 25 | fdct3d_forward_wrap, 26 | fdct3d_inverse_wrap, 27 | fdct3d_param_wrap, 28 | ) 29 | 30 | # pylint: enable=no-name-in-module 31 | from .typing import FDCTStructLike 32 | 33 | 34 | def _fdct_docs(dimension: int) -> str: 35 | if dimension == 2: 36 | doc = "2D" 37 | elif dimension == 3: 38 | doc = "3D" 39 | else: 40 | doc = "2D/3D" 41 | return f"""{doc} dimensional Curvelet operator. 42 | Apply {doc} Curvelet Transform along two ``axes`` of a 43 | multi-dimensional array of size ``dims``. 44 | 45 | Parameters 46 | ---------- 47 | dims : :obj:`tuple` 48 | Number of samples for each dimension. 49 | axes : :obj:`tuple`, optional 50 | Axes along which FDCT is applied. 51 | nbscales : :obj:`int`, optional 52 | Number of scales (including the coarsest level); 53 | Defaults to ceil(log2(min(input_dims)) - 3). 54 | nbangles_coarse : :obj:`int`, optional 55 | Number of angles at 2nd coarsest scale. 56 | allcurvelets : :obj:`bool`, optional 57 | Use curvelets at the finest (last) scale. 58 | If ``False``, a wavelet transform will be used for the 59 | finest scale. The coarsest scale is always a wavelet transform; 60 | the ones between the coarsest and the finest are all curvelet 61 | transforms. This option only affects the finest scale. 62 | dtype : :obj:`DTypeLike `, optional 63 | ``dtype`` of the transform. 64 | """ 65 | 66 | 67 | class FDCT(LinearOperator): 68 | __doc__ = _fdct_docs(0) 69 | 70 | def __init__( 71 | self, 72 | dims: InputDimsLike, 73 | axes: Tuple[int, ...], 74 | nbscales: Optional[int] = None, 75 | nbangles_coarse: int = 16, 76 | allcurvelets: bool = True, 77 | dtype: DTypeLike = "complex128", 78 | ) -> None: 79 | ndim = len(dims) 80 | 81 | # Ensure axes are between 0, ndim-1 82 | axes = tuple(normalize_axis_index(d, ndim) for d in axes) 83 | 84 | # If input is shaped (100, 200, 300) and axes = (0, 2) 85 | # then input_shape will be (100, 300) 86 | self._input_shape = list(int(dims[d]) for d in axes) 87 | if nbscales is None: 88 | nbscales = int(np.ceil(np.log2(min(self._input_shape)) - 3)) 89 | 90 | # Check dimension 91 | sizes: Union[Tuple[NDArray, NDArray], Tuple[NDArray, NDArray, NDArray]] 92 | if len(axes) == 2: 93 | self.fdct: Callable = fdct2d_forward_wrap # type: ignore # noqa: F405 94 | self.ifdct: Callable = fdct2d_inverse_wrap # type: ignore # noqa: F405 95 | _, _, _, _, nxs, nys = fdct2d_param_wrap( # type: ignore # noqa: F405 96 | *self._input_shape, nbscales, nbangles_coarse, allcurvelets 97 | ) 98 | sizes = (nys, nxs) 99 | elif len(axes) == 3: 100 | self.fdct: Callable = fdct3d_forward_wrap # type: ignore # noqa: F405 101 | self.ifdct: Callable = fdct3d_inverse_wrap # type: ignore # noqa: F405 102 | _, _, _, nxs, nys, nzs = fdct3d_param_wrap( # type: ignore # noqa: F405 103 | *self._input_shape, nbscales, nbangles_coarse, allcurvelets 104 | ) 105 | sizes = (nzs, nys, nxs) 106 | else: 107 | raise NotImplementedError("FDCT is only implemented in 2D or 3D") 108 | 109 | # Complex operator is required to handle complex input 110 | dtype = np.dtype(dtype) 111 | if np.issubdtype(dtype, np.complexfloating): 112 | cpx = True 113 | else: 114 | cpx = False 115 | raise NotImplementedError("Only complex types supported") 116 | 117 | # Now we need to build the iterator which will only iterate along 118 | # the required axes. Following the example above, 119 | # iterable_axes = [ False, True, False ] 120 | iterable_axes = [i not in axes for i in range(ndim)] 121 | iterable_dims = np.array(dims)[iterable_axes] 122 | self._ndim_iterable = prod(iterable_dims) 123 | 124 | # Build the iterator itself. In our example, the slices 125 | # would be [:, i, :] for i in range(200) 126 | # We use slice(None) is the colon operator 127 | self._iterator = list( 128 | product( 129 | *( 130 | range(dims[ax]) if doiter else [slice(None)] 131 | for ax, doiter in enumerate(iterable_axes) 132 | ) 133 | ) 134 | ) 135 | 136 | # For a single 2d/3d input, the length of the vector will be given by 137 | # the shapes in FDCT.sizes 138 | self.shapes = [ 139 | [tuple(s[i][j] for s in sizes) for j in range(len(nx))] 140 | for i, nx in enumerate(nxs) 141 | ] 142 | self._output_len = sum(prod(j) for i in self.shapes for j in i) 143 | 144 | # Save some useful properties 145 | self.inpdims = dims 146 | self.axes = axes 147 | self.nbscales = nbscales 148 | self.nbangles_coarse = nbangles_coarse 149 | self.allcurvelets = allcurvelets 150 | self.cpx = cpx 151 | 152 | # Required by PyLops 153 | super().__init__( 154 | dtype=dtype, 155 | dims=self.inpdims, 156 | dimsd=(*iterable_dims, self._output_len), 157 | ) 158 | 159 | def _matvec(self, x: NDArray) -> NDArray: 160 | fwd_out = np.zeros((self._output_len, self._ndim_iterable), dtype=self.dtype) 161 | for i, index in enumerate(self._iterator): 162 | x_shaped = np.array(x.reshape(self.inpdims)[index]) 163 | c_struct: FDCTStructLike = self.fdct( 164 | self.nbscales, 165 | self.nbangles_coarse, 166 | self.allcurvelets, 167 | x_shaped, 168 | ) 169 | fwd_out[:, i] = self.vect(c_struct) 170 | return fwd_out.ravel() 171 | 172 | def _rmatvec(self, x: NDArray) -> NDArray: 173 | y_shaped = x.reshape(self._output_len, self._ndim_iterable) 174 | inv_out = np.zeros(self.inpdims, dtype=self.dtype) 175 | for i, index in enumerate(self._iterator): 176 | y_struct = self.struct(np.array(y_shaped[:, i])) 177 | xinv: NDArray = self.ifdct( 178 | *self._input_shape, 179 | self.nbscales, 180 | self.nbangles_coarse, 181 | self.allcurvelets, 182 | y_struct, 183 | ) 184 | inv_out[index] = xinv 185 | 186 | return inv_out.ravel() 187 | 188 | def inverse(self, x: NDArray) -> NDArray: 189 | """Inverse Curvelet Transform 190 | 191 | Parameters 192 | ---------- 193 | x : NDArray 194 | Input vector 195 | 196 | Returns 197 | ------- 198 | NDArray 199 | FDCT.H @ x 200 | """ 201 | return self._rmatvec(x) 202 | 203 | def struct(self, x: NDArray) -> FDCTStructLike: 204 | """Convert curvelet flattened vector to curvelet structure. 205 | 206 | The FDCT always returns a 1D vector that has all curvelet 207 | coefficients. These coefficients can be organized into 208 | scales, wedges and spatial positions. Applying this 209 | function to a 1D vector generates this structure. 210 | 211 | Parameters 212 | ---------- 213 | x : :obj:`NDArray ` 214 | Input flattened vector. 215 | 216 | Returns 217 | ------- 218 | :obj:`FDCTStructLike ` 219 | Curvelet structure, a list of lists of multidimensional arrays. 220 | The first index corresponds to scale, the second corresponds to 221 | angular wedge. 222 | """ 223 | c_struct: FDCTStructLike = [] 224 | k = 0 225 | for shapes_s in self.shapes: 226 | angles = [] 227 | for shape_w in shapes_s: 228 | size = prod(shape_w) 229 | angles.append(x[k : k + size].reshape(shape_w)) 230 | k += size 231 | c_struct.append(angles) 232 | return c_struct 233 | 234 | def vect(self, x: FDCTStructLike) -> NDArray: 235 | """Convert curvelet structure to curvelet flattened vector. 236 | 237 | The FDCT always returns a 1D vector that has all curvelet 238 | coefficients. These coefficients can be organized into 239 | scales, wedges and spatial positions. Applying this 240 | function to a curvelet structure returns the flattened 241 | vector. 242 | 243 | Parameters 244 | ---------- 245 | x : :obj:`FDCTStructLike ` 246 | Input curvelet structure. 247 | 248 | Returns 249 | ------- 250 | :obj:`NDArray ` 251 | Flattened vector. 252 | """ 253 | return np.concatenate([coef.ravel() for angle in x for coef in angle]) 254 | 255 | 256 | class FDCT2D(FDCT): 257 | __doc__ = _fdct_docs(2) 258 | 259 | def __init__( 260 | self, 261 | dims: InputDimsLike, 262 | axes: Tuple[int, ...] = (-2, -1), 263 | nbscales: Optional[int] = None, 264 | nbangles_coarse: int = 16, 265 | allcurvelets: bool = True, 266 | dtype: DTypeLike = "complex128", 267 | ) -> None: 268 | assert len(axes) == 2, ValueError("FDCT2D must be called with exactly two axes") 269 | super().__init__(dims, axes, nbscales, nbangles_coarse, allcurvelets, dtype) 270 | 271 | 272 | class FDCT3D(FDCT): 273 | __doc__ = _fdct_docs(3) 274 | 275 | def __init__( 276 | self, 277 | dims: InputDimsLike, 278 | axes: Tuple[int, ...] = (-3, -2, -1), 279 | nbscales: Optional[int] = None, 280 | nbangles_coarse: int = 16, 281 | allcurvelets: bool = True, 282 | dtype: DTypeLike = "complex128", 283 | ) -> None: 284 | assert len(axes) == 3, ValueError( 285 | "FDCT3D must be called with exactly three axes" 286 | ) 287 | super().__init__(dims, axes, nbscales, nbangles_coarse, allcurvelets, dtype) 288 | -------------------------------------------------------------------------------- /curvelops/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``curvelops.plot`` 3 | ================== 4 | 5 | Auxiliary functions for plotting. 6 | """ 7 | 8 | from . import _curvelet, _generic 9 | 10 | __all__ = _curvelet.__all__ + _generic.__all__ 11 | 12 | 13 | from ._curvelet import * 14 | from ._generic import * 15 | -------------------------------------------------------------------------------- /curvelops/plot/_curvelet.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "curveshow", 3 | "overlay_disks", 4 | ] 5 | import itertools 6 | from math import ceil, floor 7 | from typing import List, Optional, Union 8 | 9 | import matplotlib as mpl 10 | import matplotlib.cm as cm 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | from matplotlib.colors import Colormap 14 | from matplotlib.figure import Figure 15 | from numpy.typing import NDArray 16 | 17 | from ..typing import FDCTStructLike 18 | from ..utils import apply_along_wedges, energy_split 19 | 20 | 21 | def curveshow( 22 | c_struct: FDCTStructLike, 23 | k_space: bool = False, 24 | basesize: int = 5, 25 | showaxis: bool = False, 26 | real: bool = True, 27 | kwargs_imshow: Optional[dict] = None, 28 | ) -> List[Figure]: 29 | """Display curvelet coefficients in each wedge as images. 30 | 31 | For each curvelet scale, display a figure with each wedge 32 | plotted as an image in its own axis. 33 | 34 | Parameters 35 | ---------- 36 | c_struct : :obj:`FDCTStructLike ` 37 | Curvelet structure. 38 | k_space : :obj:`bool`, optional 39 | Show curvelet coefficient (False) or its 2D FFT transform (True), 40 | by default False. 41 | basesize : :obj:`int`, optional 42 | Base fize of figure, by default 5. Each figure will be sized 43 | ``(basesize * cols, basesize * rows)``, where 44 | ``rows = floor(sqrt(nangles))`` and ``cols = ceil(nangles / rows)`` 45 | showaxis : :obj:`bool`, optional 46 | Turn on axis lines and labels, by default False. 47 | real : :obj:`bool`, optional 48 | Plot real or imaginary part of curvelet coefficients. Only applicable 49 | when ``k_space`` is False. 50 | kwargs_imshow : ``Optional[dict]``, optional 51 | Arguments to be passed to :obj:`matplotlib.pyplot.imshow`. 52 | 53 | Examples 54 | -------- 55 | >>> import numpy as np 56 | >>> from curvelops import FDCT2D 57 | >>> from curvelops.utils import apply_along_wedges, energy 58 | >>> from curvelops.plot import curveshow 59 | >>> d = np.random.randn(101, 101) 60 | >>> C = FDCT2D(d.shape, nbscales=2, nbangles_coarse=8) 61 | >>> y = C.struct(C @ d) 62 | >>> y_norm = apply_along_wedges(y, lambda w, *_: w / energy(w)) 63 | >>> curveshow( 64 | >>> y_norm, 65 | >>> basesize=2, 66 | >>> kwargs_imshow=dict(aspect="auto", vmin=-1, vmax=1, cmap="RdBu") 67 | >>> ) 68 | 69 | Returns 70 | ------- 71 | List[:obj:`Figure `] 72 | One figure per scale. 73 | """ 74 | 75 | def fft(x): 76 | return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x), norm="ortho")) 77 | 78 | _kwargs_imshow_default = {} 79 | if k_space: 80 | _kwargs_imshow_default["vmax"] = np.abs(fft(c_struct[0][0])).max() 81 | _kwargs_imshow_default["vmin"] = 0.0 82 | _kwargs_imshow_default["cmap"] = "turbo" 83 | else: 84 | _kwargs_imshow_default["vmax"] = np.abs(c_struct[0][0]).max() 85 | _kwargs_imshow_default["vmin"] = -_kwargs_imshow_default["vmax"] 86 | _kwargs_imshow_default["cmap"] = "gray" 87 | if kwargs_imshow is None: 88 | kwargs_imshow = _kwargs_imshow_default 89 | else: 90 | kwargs_imshow = {**_kwargs_imshow_default, **kwargs_imshow} 91 | 92 | figsize_aspect = c_struct[0][0].shape[0] / c_struct[0][0].shape[1] 93 | figs_axes = [] 94 | for iscale, c_scale in enumerate(c_struct): 95 | nangles = len(c_scale) 96 | rows = floor(np.sqrt(nangles)) 97 | cols = ceil(nangles / rows) 98 | fig, axes = plt.subplots( 99 | rows, 100 | cols, 101 | figsize=(basesize * cols, figsize_aspect * basesize * rows), 102 | ) 103 | fig.suptitle(f"Scale {iscale} ({nangles} wedge{'s' if nangles > 1 else ''})") 104 | figs_axes.append((fig, axes)) 105 | axes = np.atleast_1d(axes).ravel() 106 | 107 | for iwedge, (c_wedge, ax) in enumerate(zip(c_scale, axes)): 108 | if k_space: 109 | ax.imshow(np.abs(fft(c_wedge)), **kwargs_imshow) 110 | else: 111 | if real: 112 | ax.imshow(c_wedge.real, **kwargs_imshow) 113 | else: 114 | ax.imshow(c_wedge.imag, **kwargs_imshow) 115 | if nangles > 1: 116 | ax.set(title=f"Wedge {iwedge}") 117 | if not showaxis: 118 | ax.axis("off") 119 | fig.tight_layout() 120 | return figs_axes 121 | 122 | 123 | def overlay_disks( 124 | c_struct: FDCTStructLike, 125 | axes: NDArray, 126 | linewidth: float = 0.5, 127 | linecolor: str = "r", 128 | map_cmap: bool = True, 129 | cmap: Union[str, Colormap] = "gray_r", 130 | alpha: float = 1.0, 131 | pclip: float = 1.0, 132 | map_alpha: bool = False, 133 | min_alpha: float = 0.05, 134 | normalize: str = "all", 135 | annotate: bool = False, 136 | ): 137 | """Overlay curvelet disks over a 2D grid of axes. 138 | 139 | Its intended usage is to display the strength of curvelet coefficients 140 | of a certain image with a disk display. Given an ``axes`` 2D array, 141 | each curvelet wedge will be split into ``rows, cols = axes.shape`` 142 | sub-wedges. The energy of each of these sub-wedges will be mapped 143 | to a colormap color and/or transparency. 144 | 145 | See Also 146 | -------- 147 | :obj:`energy_split `: Splits a wedge into ``(rows, cols)`` wedges and computes the energy of each of these subdivisions. 148 | 149 | :obj:`create_inset_axes_grid`: Create a grid of insets. 150 | 151 | :obj:`create_axes_grid`: Creates a grid of axes. 152 | 153 | :obj:`curveshow`: Display curvelet coefficients in each wedge as images. 154 | 155 | Parameters 156 | ---------- 157 | c_struct : :obj:`FDCTStructLike `: 158 | Curvelet coefficients of underlying image. 159 | axes : :obj:`NDArray ` 160 | 2D grid of axes for which disks will be computed. 161 | linewidth : :obj:`float`, optional 162 | Width of line separating scales, by default 0.5. 163 | Will be scaled by ``0.1 / nscales`` internally. 164 | Set to zero to disable. 165 | linecolor : :obj:`str`, optional 166 | Color of line separating scales, by default "r". 167 | map_cmap : :obj:`bool`, optional 168 | When enabled, energy will be mapped to the colormap, by default True. 169 | cmap : Union[:obj:`str`, :obj:`Colormap `], optional 170 | Colormap, by default ``"gray_r"``. 171 | alpha : :obj:`float`, optional 172 | When using ``map_cmap``, sets a transparecy for all wedges. 173 | Has no effect when ``map_alpha`` is enabled. By default 1.0. 174 | pclip : :obj:`float`, optional 175 | Clips the maximum amplitude by this percentage. By default 1.0. 176 | Should be between 0.0 and 1.0. 177 | map_alpha : :obj:`bool`, optional 178 | When enabled, energy will be mapped to the transparency, by default False. 179 | min_alpha : :obj:`float`, optional 180 | When using ``map_alpha``, sets a minimum transparency value. 181 | Has no effect when ``map_alpha`` is disabled. By default 0.05. 182 | normalize : :obj:`str`, optional 183 | Normalize wedges by: 184 | 185 | * ``"all"`` (default) 186 | Colormap/alpha value of 1.0 will correspond to the maximum 187 | energy found across all wedges 188 | 189 | * ``"scale"`` 190 | Colormap/alpha value of 1.0 will correspond to the maximum 191 | energy found across all wedges in the same scale. 192 | annotate : :obj:`bool`, optional 193 | When true, will display in the middle of the wedge a 194 | pair of numbers ``iscale, iwedge``, the index of that scale 195 | and that wedge, both starting from zero. This option is useful to 196 | understand which directions each wedge corresponds to. 197 | By default False. 198 | 199 | Examples 200 | -------- 201 | >>> import matplotlib.pyplot as plt 202 | >>> import numpy as np 203 | >>> from curvelops import FDCT2D 204 | >>> from curvelops.utils import apply_along_wedges 205 | >>> from curvelops.plot import create_axes_grid, overlay_disks 206 | >>> x = np.random.randn(50, 100) 207 | >>> C = FDCT2D(x.shape, nbscales=4, nbangles_coarse=8) 208 | >>> y = C.struct(C @ x) 209 | >>> y_ones = apply_along_wedges(y, lambda w, *_: np.ones_like(w)) 210 | >>> fig, axes = create_axes_grid( 211 | >>> 1, 212 | >>> 1, 213 | >>> kwargs_subplots=dict(projection="polar"), 214 | >>> kwargs_figure=dict(figsize=(8, 8)), 215 | >>> ) 216 | >>> overlay_disks(y_ones, axes, annotate=True, cmap="gray") 217 | 218 | >>> import matplotlib as mpl 219 | >>> import matplotlib.pyplot as plt 220 | >>> import numpy as np 221 | >>> from mpl_toolkits.axes_grid1 import make_axes_locatable 222 | >>> from curvelops import FDCT2D 223 | >>> from curvelops.plot import create_inset_axes_grid, overlay_disks 224 | >>> from curvelops.utils import apply_along_wedges 225 | >>> plt.rcParams.update({"image.interpolation": "blackman"}) 226 | >>> # Construct signal 227 | >>> xlim = [-1.0, 1.0] 228 | >>> ylim = [-0.5, 0.5] 229 | >>> x = np.linspace(*xlim, 201) 230 | >>> z = np.linspace(*ylim, 101) 231 | >>> xm, zm = np.meshgrid(x, z, indexing="ij") 232 | >>> freq = 5 233 | >>> d = np.cos(2 * np.pi * freq * (xm + np.cos(xm) * zm) ** 3) 234 | >>> # Compute curvelet coefficients 235 | >>> C = FDCT2D(d.shape, nbangles_coarse=8, allcurvelets=False) 236 | >>> d_c = C.struct(C @ d) 237 | >>> # Plot original signal 238 | >>> fig, ax = plt.subplots(figsize=(8, 4 )) 239 | >>> ax.imshow(d.T, extent=[*xlim, *(ylim[::-1])], cmap="RdYlBu", vmin=-1, vmax=1) 240 | >>> ax.axis("off") 241 | >>> # Overlay disks 242 | >>> rows, cols = 3, 6 243 | >>> axesin = create_inset_axes_grid( 244 | >>> ax, rows, cols, width=0.75, kwargs_inset_axes=dict(projection="polar") 245 | >>> ) 246 | >>> pclip = 0.2 247 | >>> cmap = "gray_r" 248 | >>> overlay_disks(d_c, axesin, linewidth=0.0, pclip=pclip, cmap=cmap) 249 | >>> # Display disk colorbar 250 | >>> divider = make_axes_locatable(ax) 251 | >>> cax = divider.append_axes("right", size="5%", pad=0.1) 252 | >>> mpl.colorbar.ColorbarBase( 253 | >>> cax, cmap=cmap, norm=mpl.colors.Normalize(vmin=0, vmax=pclip) 254 | >>> ) 255 | """ 256 | rows, cols = axes.shape 257 | e_split = apply_along_wedges(c_struct, lambda w, *_: energy_split(w, rows, cols)) 258 | max_e = max(a.max() for a in itertools.chain.from_iterable(e_split)) 259 | 260 | cmapper = cm.ScalarMappable( 261 | norm=mpl.colors.Normalize(0, pclip, clip=True), cmap=cmap 262 | ) 263 | 264 | nscales = len(c_struct) 265 | linewidth *= 0.1 / nscales 266 | 267 | for iscale in range(nscales): 268 | nangles = len(c_struct[iscale]) 269 | angles_per_wedge = 2 * np.pi / nangles 270 | 271 | if normalize == "scale": 272 | max_e = max(a.max() for a in itertools.chain.from_iterable(e_split[iscale])) 273 | 274 | # To start starting counterclockwise from the top middle, 275 | # we need to offset the wedge index by the following amount 276 | iwedge_offset = nangles - nangles // 8 277 | for iwedge in range(nangles): 278 | for irow in range(rows): 279 | for icol in range(cols): 280 | e = e_split[iscale][iwedge][irow, icol] 281 | if map_alpha: 282 | alpha = np.clip( 283 | min_alpha + (1 - min_alpha) * e / max_e, 284 | min_alpha, 285 | 1, 286 | ) 287 | if map_cmap: 288 | color = cmapper.to_rgba(np.clip(e / max_e, 0, 1)) 289 | else: 290 | color = cmapper.to_rgba(1) 291 | 292 | # Place the starting wedges in the correct place 293 | iwedge_shift = (nangles // 2 + iwedge + iwedge_offset) % nangles 294 | 295 | # Wedge coordinates in polar plot 296 | wedge_x = (iwedge_shift + 0.5) * angles_per_wedge 297 | wedge_width = angles_per_wedge 298 | wedge_height = 1 / (nscales - 1) 299 | wedge_bottom = iscale * wedge_height 300 | axes[irow][icol].bar( 301 | x=wedge_x, 302 | height=wedge_height, 303 | width=wedge_width, 304 | bottom=wedge_bottom, 305 | color=color, 306 | alpha=alpha, 307 | ) 308 | if nangles > 1: 309 | axes[irow][icol].bar( 310 | x=wedge_x - wedge_width / 2, 311 | height=wedge_height, 312 | width=linewidth, 313 | bottom=wedge_bottom, 314 | color=linecolor, 315 | ) 316 | if annotate: 317 | axes[irow][icol].text( 318 | wedge_x, 319 | wedge_bottom 320 | + (0 if wedge_bottom == 0 else wedge_height / 2), 321 | f"{iscale}, {iwedge}", 322 | backgroundcolor="w", 323 | color="k", 324 | horizontalalignment="center", 325 | verticalalignment="center", 326 | fontsize=6, 327 | ) 328 | 329 | # Plot line separating scales 330 | for irow in range(rows): 331 | for icol in range(cols): 332 | axes[irow][icol].axis("off") 333 | for iscale in range(nscales): 334 | if linewidth > 0.0: 335 | axes[irow][icol].bar( 336 | x=0, 337 | height=linewidth, 338 | width=2 * np.pi, 339 | bottom=(iscale + 1 - linewidth / 2) / (nscales - 1), 340 | color=linecolor, 341 | ) 342 | -------------------------------------------------------------------------------- /curvelops/plot/_generic.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "create_colorbar", 3 | "create_axes_grid", 4 | "create_inset_axes_grid", 5 | "overlay_arrows", 6 | ] 7 | from typing import Optional, Tuple 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from matplotlib.axes import Axes 12 | from matplotlib.colorbar import Colorbar 13 | from matplotlib.figure import Figure 14 | from matplotlib.image import AxesImage 15 | from mpl_toolkits.axes_grid1 import make_axes_locatable 16 | from numpy.typing import NDArray 17 | 18 | 19 | def _create_range(start, end, n): 20 | return start + (end - start) * (0.5 + np.arange(n)) / n 21 | 22 | 23 | def create_colorbar( 24 | im: AxesImage, 25 | ax: Axes, 26 | size: float = 0.05, 27 | pad: float = 0.1, 28 | orientation: str = "vertical", 29 | ) -> Tuple[Axes, Colorbar]: 30 | r"""Create a colorbar. 31 | 32 | Divides axis and attaches a colorbar to it. 33 | 34 | Parameters 35 | ---------- 36 | im : :obj:`AxesImage ` 37 | Image from which the colorbar will be created. 38 | Commonly the output of :obj:`matplotlib.pyplot.imshow`. 39 | ax : :obj:`Axes ` 40 | Axis which to split. 41 | size : :obj:`float`, optional 42 | Size of split, by default 0.05. Effectively sets the size of the colorbar. 43 | pad : :obj:`float`, optional` 44 | Padding between colorbar axis and input axis, by default 0.1. 45 | orientation : :obj:`str`, optional 46 | Orientation of the colorbar, by default "vertical". 47 | 48 | Returns 49 | ------- 50 | Tuple[:obj:`Axes `, :obj:`Colorbar `] 51 | **cax** : Colorbar axis. 52 | 53 | **cb** : Colorbar. 54 | 55 | Examples 56 | -------- 57 | >>> import matplotlib.pyplot as plt 58 | >>> from matplotlib.ticker import MultipleLocator 59 | >>> from curvelops.plot import create_colorbar 60 | >>> fig, ax = plt.subplots() 61 | >>> im = ax.imshow([[0]], vmin=-1, vmax=1, cmap="gray") 62 | >>> cax, cb = create_colorbar(im, ax) 63 | >>> cax.yaxis.set_major_locator(MultipleLocator(0.1)) 64 | >>> print(cb.vmin) 65 | -1.0 66 | """ 67 | divider = make_axes_locatable(ax) 68 | cax = divider.append_axes("right", size=f"{size:%}", pad=pad) 69 | cb = ax.get_figure().colorbar(im, cax=cax, orientation=orientation) 70 | return cax, cb 71 | 72 | 73 | def create_axes_grid( 74 | rows: int, 75 | cols: int, 76 | kwargs_figure: Optional[dict] = None, 77 | kwargs_gridspec: Optional[dict] = None, 78 | kwargs_subplots: Optional[dict] = None, 79 | ) -> Tuple[Figure, NDArray]: 80 | r"""Creates a grid of axes. 81 | 82 | Parameters 83 | ---------- 84 | rows : :obj:`int` 85 | Number of rows. 86 | cols : :obj:`int` 87 | Number of columns. 88 | kwargs_figure : ``Optional[dict]``, optional 89 | Arguments to be passed to :obj:`matplotlib.pyplot.figure`. 90 | kwargs_gridspec : ``Optional[dict]``, optional 91 | Arguments to be passed to :obj:`matplotlib.gridspec.GridSpec`. 92 | kwargs_subplots : ``Optional[dict]``, optional 93 | Arguments to be passed to :obj:`matplotlib.figure.Figure.add_subplot`. 94 | 95 | 96 | Returns 97 | ------- 98 | Tuple[:obj:`Figure `, :obj:`NDArray `] 99 | **fig** : Figure. 100 | 101 | **axs** : Array of :obj:`Axes ` shaped ``(rows, cols)``. 102 | 103 | Examples 104 | -------- 105 | >>> from curvelops.plot import create_axes_grid 106 | >>> rows, cols = 2, 3 107 | >>> fig, axs = create_axes_grid( 108 | >>> rows, 109 | >>> cols, 110 | >>> kwargs_figure=dict(figsize=(8, 8)), 111 | >>> kwargs_gridspec=dict(wspace=0.3, hspace=0.3), 112 | >>> ) 113 | >>> for irow in range(rows): 114 | >>> for icol in range(cols): 115 | >>> axs[irow][icol].plot(np.cos((2 + irow + icol**2) * np.linspace(0, 1))) 116 | >>> axs[irow][icol].set(title=f"Row, Col: ({irow}, {icol})") 117 | """ 118 | if kwargs_figure is None: 119 | kwargs_figure = {} 120 | if kwargs_gridspec is None: 121 | kwargs_gridspec = {} 122 | if kwargs_subplots is None: 123 | kwargs_subplots = {} 124 | fig = plt.figure(**kwargs_figure) 125 | grid = fig.add_gridspec(rows, cols, **kwargs_gridspec) 126 | axs = np.empty((rows, cols), dtype=Axes) 127 | for irow in range(rows): 128 | for icol in range(cols): 129 | axs[irow, icol] = fig.add_subplot(grid[irow, icol], **kwargs_subplots) 130 | return fig, axs 131 | 132 | 133 | def create_inset_axes_grid( 134 | ax: Axes, 135 | rows: int, 136 | cols: int, 137 | height: float = 0.5, 138 | width: float = 0.5, 139 | kwargs_inset_axes: Optional[dict] = None, 140 | ) -> NDArray: 141 | r"""Create a grid of insets. 142 | 143 | The input axis will be overlaid with a grid of insets. 144 | Numbering of the axes is top to bottom (rows) and 145 | left to right (cols). 146 | 147 | Parameters 148 | ---------- 149 | ax : :obj:`Axes ` 150 | Input axis. 151 | rows : :obj:`int` 152 | Number of rows. 153 | cols : :obj:`int` 154 | Number of columns. 155 | width : :obj:`float`, optional 156 | Width of each axis, as a percentage of ``cols``, by default 0.5. 157 | height : :obj:`float`, optional 158 | Height of each axis, as a percentage of ``rows``, by default 0.5. 159 | kwargs_inset_axes : ``Optional[dict]``, optional 160 | Arguments to be passed to :obj:`matplotlib.axes.Axes.inset_axes`. 161 | 162 | Returns 163 | ------- 164 | :obj:`NDArray ` 165 | Array of :obj:`Axes ` shaped ``(rows, cols)``. 166 | 167 | Examples 168 | -------- 169 | >>> import matplotlib.pyplot as plt 170 | >>> import numpy as np 171 | >>> from curvelops.plot import create_inset_axes_grid 172 | >>> fig, ax = plt.subplots(figsize=(6, 6)) 173 | >>> ax.imshow([[0]], extent=[-2, 2, 2, -2], vmin=-1, vmax=1, cmap="gray") 174 | >>> rows, cols = 2, 3 175 | >>> inset_axes = create_inset_axes_grid( 176 | >>> ax, 177 | >>> rows, 178 | >>> cols, 179 | >>> width=0.5, 180 | >>> height=0.5, 181 | >>> kwargs_inset_axes=dict(projection="polar"), 182 | >>> ) 183 | >>> nscales = 4 184 | >>> lw = 0.1 185 | >>> for irow in range(rows): 186 | >>> for icol in range(cols): 187 | >>> for iscale in range(1, nscales): 188 | >>> inset_axes[irow][icol].bar( 189 | >>> x=0, 190 | >>> height=lw, 191 | >>> width=2 * np.pi, 192 | >>> bottom=((iscale + 1) - 0.5 * lw) / (nscales - 1), 193 | >>> color="r", 194 | >>> ) 195 | >>> inset_axes[irow][icol].set(title=f"Row, Col: ({irow}, {icol})") 196 | >>> inset_axes[irow][icol].axis("off") 197 | """ 198 | if kwargs_inset_axes is None: 199 | kwargs_inset_axes = {} 200 | 201 | axes = np.empty((rows, cols), dtype=object) 202 | 203 | xmin, xmax = ax.get_xlim() 204 | ymin, ymax = ax.get_ylim() 205 | xmin, xmax = min(xmin, xmax), max(xmin, xmax) 206 | ymin, ymax = min(ymin, ymax), max(ymin, ymax) 207 | 208 | width *= (xmax - xmin) / cols 209 | height *= (ymax - ymin) / rows 210 | 211 | for irow, rowpos in enumerate(_create_range(ymin, ymax, rows)): 212 | for icol, colpos in enumerate(_create_range(xmin, xmax, cols)): 213 | axes[irow, icol] = ax.inset_axes( 214 | [colpos - 0.5 * width, rowpos - 0.5 * height, width, height], 215 | transform=ax.transData, 216 | **kwargs_inset_axes, 217 | ) 218 | return axes 219 | 220 | 221 | def overlay_arrows( 222 | vectors: NDArray, ax: Axes, arrowprops: Optional[dict] = None 223 | ) -> None: 224 | r"""Overlay arrows on an axis. 225 | 226 | Parameters 227 | ---------- 228 | vectors : :obj:`NDArray ` 229 | Array shaped ``(rows, cols, 2)``, corresponding to a 2D vector field. 230 | ax : :obj:`Axes ` 231 | Axis on which to overlay the arrows. 232 | arrowprops : ``Optional[dict]``, optional 233 | Arrow properties, to be passed to :obj:`matplotlib.pyplot.annotate`. 234 | By default will be set to ``dict(facecolor="black", shrink=0.05)``. 235 | 236 | Examples 237 | -------- 238 | >>> import matplotlib.pyplot as plt 239 | >>> import numpy as np 240 | >>> from curvelops.plot import overlay_arrows 241 | >>> fig, ax = plt.subplots(figsize=(8, 10)) 242 | >>> ax.imshow([[0]], vmin=-1, vmax=1, extent=[0, 1, 1, 0], cmap="gray") 243 | >>> rows, cols = 3, 4 244 | >>> kvecs = np.array( 245 | >>> [ 246 | >>> [(1 + x, x * y) for x in (0.5 + np.arange(cols)) / cols] 247 | >>> for y in (0.5 + np.arange(rows)) / rows 248 | >>> ] 249 | >>> ) 250 | >>> overlay_arrows( 251 | >>> 0.05 * kvecs, 252 | >>> ax, 253 | >>> arrowprops=dict( 254 | >>> facecolor="r", 255 | >>> shrink=0.05, 256 | >>> width=10 / cols, 257 | >>> headwidth=10, 258 | >>> headlength=10, 259 | >>> ), 260 | >>> ) 261 | """ 262 | rows, cols, _ = vectors.shape 263 | 264 | xmin, xmax = ax.get_xlim() 265 | ymin, ymax = ax.get_ylim() 266 | xmin, xmax = min(xmin, xmax), max(xmin, xmax) 267 | ymin, ymax = min(ymin, ymax), max(ymin, ymax) 268 | 269 | if arrowprops is None: 270 | arrowprops = dict(facecolor="black", shrink=0.05) 271 | 272 | for irow, rowpos in enumerate(_create_range(ymin, ymax, rows)): 273 | for icol, colpos in enumerate(_create_range(xmin, xmax, cols)): 274 | ax.annotate( 275 | "", 276 | xy=( 277 | colpos + vectors[irow, icol, 0], 278 | rowpos + vectors[irow, icol, 1], 279 | ), 280 | xytext=(colpos, rowpos), 281 | xycoords="data", 282 | arrowprops=arrowprops, 283 | annotation_clip=False, 284 | ) 285 | -------------------------------------------------------------------------------- /curvelops/typing/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``curvelops.typing`` 3 | ==================== 4 | 5 | Typing submodule. 6 | """ 7 | 8 | from . import _typing 9 | 10 | __all__ = _typing.__all__.copy() 11 | 12 | 13 | from ._typing import * 14 | -------------------------------------------------------------------------------- /curvelops/typing/_typing.py: -------------------------------------------------------------------------------- 1 | __all__ = ["FDCTStructLike", "RecursiveListNDArray"] 2 | from typing import List, Union 3 | 4 | from numpy.typing import NDArray 5 | 6 | FDCTStructLike = List[List[NDArray]] 7 | RecursiveListNDArray = Union[List[NDArray], List["RecursiveListNDArray"]] 8 | -------------------------------------------------------------------------------- /curvelops/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``curvelops.utils`` 3 | =================== 4 | 5 | Utility functions for processing curvelets. 6 | """ 7 | 8 | from . import _utils 9 | 10 | __all__ = _utils.__all__.copy() 11 | 12 | 13 | from ._utils import * 14 | -------------------------------------------------------------------------------- /curvelops/utils/_utils.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "array_split_nd", 3 | "split_nd", 4 | "apply_along_wedges", 5 | "energy", 6 | "energy_split", 7 | "ndargmax", 8 | ] 9 | from typing import Callable, List, TypeVar 10 | 11 | import numpy as np 12 | from numpy.typing import NDArray 13 | 14 | from ..typing._typing import FDCTStructLike, RecursiveListNDArray 15 | 16 | 17 | def array_split_nd(ary: NDArray, *args: int) -> RecursiveListNDArray: 18 | r"""Split an array into multiple sub-arrays recursively, possibly unevenly. 19 | 20 | See Also 21 | -------- 22 | :obj:`numpy.array_split` : Split an array into multiple sub-arrays. 23 | 24 | :obj:`split_nd`: Evenly split an array into multiple sub-arrays recursively. 25 | 26 | Parameters 27 | ---------- 28 | ary : :obj:`NDArray ` 29 | Input array. 30 | 31 | args : :obj:`int`, optional 32 | Number of splits for each axis of `ary`. 33 | Axis 0 will be split into `args[0]` subarrays, axis 1 will be 34 | into `args[1]` subarrays, etc. An axis of length 35 | `l = ary.shape[axis]` that should be split into `n = args[axis]` 36 | sections, will return `l % n` sub-arrays of size `l//n + 1` 37 | and the rest of size `l//n`. 38 | 39 | Returns 40 | ------- 41 | :obj:`RecursiveListNDArray ` 42 | Recursive lists of lists of :obj:`NDArray `. 43 | The number of recursions is equivalent to the number arguments in args. 44 | 45 | Examples 46 | -------- 47 | >>> from curvelops.utils import array_split_nd 48 | >>> ary = np.outer(1 + np.arange(2), 2 + np.arange(3)) 49 | array([[2, 3, 4], 50 | [4, 6, 8]]) 51 | >>> array_split_nd(ary, 2, 3) 52 | [[array([[2]]), array([[3]]), array([[4]])], 53 | [array([[4]]), array([[6]]), array([[8]])]] 54 | 55 | >>> from curvelops.utils import array_split_nd 56 | >>> ary = np.outer(np.arange(3), np.arange(5)) 57 | >>> array_split_nd(ary, 2, 3) 58 | [[array([[0, 0], 59 | [0, 1]]), 60 | array([[0, 0], 61 | [2, 3]]), 62 | array([[0], 63 | [4]])], 64 | [array([[0, 2]]), array([[4, 6]]), array([[8]])]] 65 | """ 66 | axis = ary.ndim - len(args) 67 | split = np.array_split(ary, args[0], axis=axis) 68 | if len(args) == 1: 69 | return split 70 | return [array_split_nd(s, *args[1:]) for s in split] 71 | 72 | 73 | def split_nd(ary: NDArray, *args: int) -> RecursiveListNDArray: 74 | r"""Evenly split an array into multiple sub-arrays recursively. 75 | 76 | See Also 77 | -------- 78 | :obj:`numpy.split` : Split an array into multiple sub-arrays. 79 | 80 | :obj:`array_split_nd`: Split an array into multiple sub-arrays recursively, possibly unevenly. 81 | 82 | 83 | Parameters 84 | ---------- 85 | ary : :obj:`NDArray ` 86 | Input array. 87 | 88 | args : :obj:`int`, optional 89 | Number of splits for each axis of `ary`. 90 | Axis 0 will be split into `args[0]` subarrays, axis 1 will be 91 | into `args[1]` subarrays, etc. If the split cannot be made even 92 | for all dimensions, raises an error. 93 | 94 | Returns 95 | ------- 96 | :obj:`RecursiveListNDArray ` 97 | Recursive lists of lists of :obj:`NDArray `. 98 | The number of recursions is equivalent to the number arguments in args. 99 | 100 | Examples 101 | -------- 102 | >>> from curvelops.utils import split_nd 103 | >>> ary = np.outer(1 + np.arange(2), 2 + np.arange(3)) 104 | array([[2, 3, 4], 105 | [4, 6, 8]]) 106 | >>> split_nd(ary, 2, 3) 107 | [[array([[2]]), array([[3]]), array([[4]])], 108 | [array([[4]]), array([[6]]), array([[8]])]] 109 | 110 | >>> from curvelops.utils import split_nd 111 | >>> ary = np.outer(np.arange(3), np.arange(5)) 112 | >>> split_nd(ary, 2, 3) 113 | ValueError: array split does not result in an equal division 114 | """ 115 | axis = ary.ndim - len(args) 116 | split = np.split(ary, args[0], axis=axis) 117 | if len(args) == 1: 118 | return split 119 | return [split_nd(s, *args[1:]) for s in split] 120 | 121 | 122 | T = TypeVar("T") 123 | 124 | 125 | def apply_along_wedges( 126 | c_struct: FDCTStructLike, fun: Callable[[NDArray, int, int, int, int], T] 127 | ) -> List[List[T]]: 128 | """Applies a function to each individual wedge. 129 | 130 | Parameters 131 | ---------- 132 | c_struct : :obj:`FDCTStructLike ` 133 | Input curvelet coefficients in struct format. 134 | fun : Callable[[:obj:`NDArray `, :obj:`int`, :obj:`int`, :obj:`int`, :obj:`int`], T] 135 | Function to apply to each individual wedge. The function's arguments 136 | are respectively: `wedge`, `wedge index in scale`, `scale index`, `number of 137 | wedges in scale`, `number of scales`. 138 | 139 | Returns 140 | ------- 141 | List[List[T]] 142 | Struct containing the result of applying `fun` to each wedge. 143 | 144 | Examples 145 | -------- 146 | >>> import numpy as np 147 | >>> from curvelops import FDCT2D 148 | >>> from curvelops.utils import apply_along_wedges 149 | >>> x = np.zeros((32, 32)) 150 | >>> C = FDCT2D(x.shape, nbscales=3, nbangles_coarse=8, allcurvelets=False) 151 | >>> y = C.struct(C @ x) 152 | >>> apply_along_wedges(y, lambda w, *_: w.shape) 153 | [[(11, 11)], 154 | [(23, 11), 155 | (23, 11), 156 | (11, 23), 157 | (11, 23), 158 | (23, 11), 159 | (23, 11), 160 | (11, 23), 161 | (11, 23)], 162 | [(32, 32)]] 163 | """ 164 | mapped_struct: List[List[T]] = [[] for _ in c_struct] 165 | for iscale, c_angles in enumerate(c_struct): 166 | mapped_struct[iscale] = [] 167 | for iwedge, c_wedge in enumerate(c_angles): 168 | out = fun(c_wedge, iwedge, iscale, len(c_angles), len(c_struct)) 169 | mapped_struct[iscale].append(out) 170 | return mapped_struct 171 | 172 | 173 | def energy(ary: NDArray) -> float: 174 | r"""Computes the energy of an n-dimensional wedge. 175 | 176 | The energy of a vector (flattened n-dimensional array) 177 | :math:`(a_0,\ldots,a_{N-1})` is defined as 178 | 179 | .. math:: 180 | 181 | \sqrt{\frac{1}{N}\sum\limits_{i=0}^{N-1} |a_i|^2}. 182 | 183 | Parameters 184 | ---------- 185 | ary : :obj:`NDArray ` 186 | Input wedge. 187 | 188 | Returns 189 | ------- 190 | :obj:`float` 191 | Energy. 192 | """ 193 | return np.sqrt((ary.real**2 + ary.imag**2).sum() / ary.size) 194 | 195 | 196 | def energy_split(ary: NDArray, rows: int, cols: int) -> NDArray: 197 | """Splits a wedge into ``(rows, cols)`` wedges and computes the energy 198 | of each of these subdivisions. 199 | 200 | See Also 201 | -------- 202 | 203 | :obj:`energy` : Computes the energy of a wedge. 204 | 205 | Parameters 206 | ---------- 207 | ary : :obj:`NDArray ` 208 | Input wedge. 209 | rows : :obj:`int` 210 | Split axis 0 into `rows` subdivisions. 211 | cols : :obj:`int` 212 | Split axis 1 into `cols` subdivisions. 213 | 214 | Returns 215 | ------- 216 | :obj:`NDArray ` 217 | Matrix of shape ``(rows, cols)`` containing the energy of each 218 | subdivision of the input wedge. 219 | """ 220 | norm_local = np.empty((rows, cols), dtype=float) 221 | split = array_split_nd(ary, rows, cols) 222 | for irow in range(rows): 223 | for icol in range(cols): 224 | norm_local[irow, icol] = energy(split[irow][icol]) 225 | return norm_local 226 | 227 | 228 | def ndargmax(ary: NDArray) -> tuple: 229 | """N-dimensional argmax of array. 230 | 231 | Parameters 232 | ---------- 233 | ary : :obj:`NDArray ` 234 | Input array 235 | 236 | Examples 237 | -------- 238 | >>> import numpy as np 239 | >>> from curvelops.utils import ndargmax 240 | >>> x = np.zeros((10, 10, 10)) 241 | >>> x[1, 1, 1] = 1.0 242 | >>> ndargmax(x) 243 | (1, 1, 1) 244 | 245 | Returns 246 | ------- 247 | tuple 248 | N-dimensional index of the maximum of ``ary``. 249 | """ 250 | return np.unravel_index(ary.argmax(), ary.shape) 251 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/docs/.nojekyll -------------------------------------------------------------------------------- /docssrc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # Disable numba 5 | # export NUMBA_DISABLE_JIT=1 6 | 7 | # You can set these variables from the command line. 8 | SPHINXOPTS = 9 | SPHINXBUILD = sphinx-build 10 | SPHINXPROJ = curvelops 11 | SOURCEDIR = source 12 | BUILDDIR = build 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | .PHONY: help Makefile 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 22 | %: Makefile 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | 25 | # Make for github pages 26 | github: 27 | @make html 28 | @cp -a build/html/. ../docs 29 | -------------------------------------------------------------------------------- /docssrc/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | from sphinx_gallery.sorting import ExampleTitleSortKey 20 | 21 | from curvelops import __version__ as version 22 | 23 | release = version 24 | 25 | # -- Project information ----------------------------------------------------- 26 | 27 | project = "curvelops" 28 | copyright = "2020-2023, Carlos Alberto da Costa Filho" 29 | author = "Carlos Alberto da Costa Filho" 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | "sphinx.ext.autodoc", 43 | "sphinx.ext.doctest", 44 | "sphinx.ext.mathjax", 45 | "sphinx.ext.ifconfig", 46 | "sphinx.ext.viewcode", 47 | "sphinx.ext.intersphinx", 48 | "sphinx_gallery.gen_gallery", 49 | "sphinx_copybutton", 50 | ] 51 | 52 | # intersphinx configuration 53 | intersphinx_mapping = { 54 | "python": ("https://docs.python.org/3/", None), 55 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 56 | "matplotlib": ("https://matplotlib.org/", None), 57 | } 58 | 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ["_templates"] 61 | 62 | # The suffix(es) of source filenames. 63 | # You can specify multiple suffix as a list of string: 64 | # 65 | # source_suffix = ['.rst', '.md'] 66 | source_suffix = ".rst" 67 | 68 | # The master toctree document. 69 | master_doc = "index" 70 | 71 | # The language for content autogenerated by Sphinx. Refer to documentation 72 | # for a list of supported languages. 73 | # 74 | # This is also used if you do content translation via gettext catalogs. 75 | # Usually you set "language" from the command line for these cases. 76 | language = "en" 77 | 78 | # List of patterns, relative to source directory, that match files and 79 | # directories to ignore when looking for source files. 80 | # This pattern also affects html_static_path and html_extra_path. 81 | exclude_patterns = [] 82 | 83 | # The name of the Pygments (syntax highlighting) style to use. 84 | pygments_style = None 85 | 86 | 87 | # -- Options for HTML output ------------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = "pydata_sphinx_theme" 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 103 | html_static_path = ["_static"] 104 | 105 | # Custom sidebar templates, must be a dictionary that maps document names 106 | # to template names. 107 | # 108 | # The default sidebars (for documents that don't match any pattern) are 109 | # defined by theme itself. Builtin themes are using these templates by 110 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 111 | # 'searchbox.html']``. 112 | # 113 | # html_sidebars = {} 114 | 115 | 116 | # -- Options for HTMLHelp output --------------------------------------------- 117 | 118 | # Output file base name for HTML help builder. 119 | htmlhelp_basename = "curvelopsdoc" 120 | 121 | 122 | # -- Options for LaTeX output ------------------------------------------------ 123 | 124 | latex_elements = { 125 | # The paper size ('letterpaper' or 'a4paper'). 126 | # 127 | # 'papersize': 'letterpaper', 128 | # The font size ('10pt', '11pt' or '12pt'). 129 | # 130 | # 'pointsize': '10pt', 131 | # Additional stuff for the LaTeX preamble. 132 | # 133 | # 'preamble': '', 134 | # Latex figure (float) alignment 135 | # 136 | # 'figure_align': 'htbp', 137 | } 138 | 139 | # Grouping the document tree into LaTeX files. List of tuples 140 | # (source start file, target name, title, 141 | # author, documentclass [howto, manual, or own class]). 142 | latex_documents = [ 143 | ( 144 | master_doc, 145 | "curvelops.tex", 146 | "curvelops Documentation", 147 | "Carlos Alberto da Costa Filho", 148 | "manual", 149 | ), 150 | ] 151 | 152 | 153 | # -- Options for manual page output ------------------------------------------ 154 | 155 | # One entry per manual page. List of tuples 156 | # (source start file, name, description, authors, manual section). 157 | man_pages = [(master_doc, "curvelops", "curvelops Documentation", [author], 1)] 158 | 159 | 160 | # -- Options for Texinfo output ---------------------------------------------- 161 | 162 | # Grouping the document tree into Texinfo files. List of tuples 163 | # (source start file, target name, title, author, 164 | # dir menu entry, description, category) 165 | texinfo_documents = [ 166 | ( 167 | master_doc, 168 | "curvelops", 169 | "curvelops Documentation", 170 | author, 171 | "curvelops", 172 | "One line description of project.", 173 | "Miscellaneous", 174 | ), 175 | ] 176 | 177 | 178 | # -- Options for Epub output ------------------------------------------------- 179 | 180 | # Bibliographic Dublin Core info. 181 | epub_title = project 182 | 183 | # The unique identifier of the text. This can be a ISBN number 184 | # or the project homepage. 185 | # 186 | # epub_identifier = '' 187 | 188 | # A unique identification for the text. 189 | # 190 | # epub_uid = '' 191 | 192 | # A list of files that should not be packed into the epub file. 193 | epub_exclude_files = ["search.html"] 194 | 195 | 196 | # -- Extension configuration ------------------------------------------------- 197 | autodoc_typehints = "none" 198 | 199 | sphinx_gallery_conf = { 200 | "examples_dirs": "../../examples", # path to your example scripts 201 | "gallery_dirs": "gallery", # path to where to save gallery generated output 202 | "filename_pattern": r"\.py", 203 | # Remove the "Download all examples" button from the top level gallery 204 | "download_all_examples": False, 205 | # Sort gallery example by file name instead of number of lines (default) 206 | "within_subsection_order": ExampleTitleSortKey, 207 | # directory where function granular galleries are stored 208 | "backreferences_dir": "api/generated/backreferences", 209 | # Modules for which function level galleries are created. 210 | "doc_module": "curvelops", 211 | # Insert links to documentation of objects in the examples 212 | "reference_url": {"curvelops": None}, 213 | } 214 | # Always show the source code that generates a plot 215 | plot_include_source = True 216 | plot_formats = ["png"] 217 | # Sphinx project configuration 218 | templates_path = ["_templates"] 219 | exclude_patterns = ["_build", "**.ipynb_checkpoints", "**.ipynb", "**.md5"] 220 | source_suffix = ".rst" 221 | 222 | # Copybutton config 223 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 224 | copybutton_prompt_is_regexp = True 225 | 226 | # Pydata config 227 | html_theme_options = { 228 | "github_url": "https://github.com/PyLops/curvelops", 229 | "external_links": [{"url": "https://github.com/PyLops/pylops", "name": "PyLops"}], 230 | "header_links_before_dropdown": 10, 231 | "show_toc_level": 2, 232 | } 233 | html_context = { 234 | "github_user": "PyLops", 235 | "github_repo": "curvelops", 236 | "github_version": "main", 237 | "doc_path": "docssrc", 238 | } 239 | -------------------------------------------------------------------------------- /docssrc/source/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | Contributions are welcome! Please submit your pull-request, issue or comment 5 | in the `GitHub repo `__. You are also 6 | welcome to join the `PyLops slack channel `__. 7 | 8 | Installation for developers 9 | --------------------------- 10 | 11 | Developers should clone the 12 | `main `__ branch of the 13 | repository and install the dev requiments: 14 | 15 | .. code-block:: console 16 | 17 | $ git clone https://github.com/PyLops/curvelops 18 | $ git remote add upstream https://github.com/PyLops/curvelops 19 | $ make dev-install 20 | 21 | They should then follow the same instructions in the :ref:`Requirements` 22 | section. We recommend installing dependencies into a separate environment. 23 | Finally, they can install Curvelops with 24 | 25 | .. code-block:: console 26 | 27 | $ python3 -m pip install -e . 28 | 29 | Developers should also install `pre-commit `__ hooks with 30 | 31 | .. code-block:: console 32 | 33 | $ pre-commit install 34 | 35 | 36 | Developer workflow 37 | ------------------ 38 | 39 | Developers should start from a fresh copy of main with 40 | 41 | .. code-block:: console 42 | 43 | $ git checkout main 44 | $ git pull upstream main 45 | 46 | Before you start making changes, create a new branch with 47 | 48 | .. code-block:: console 49 | 50 | $ git checkout -b patch-some-cool-feature 51 | 52 | After implementing your cool feature (including tests 🤩), commit your changes 53 | to kick-off the pre-commit hooks. These will reject and "fix" your code by 54 | running the proper hooks. At this point, the user must check the changes and 55 | then stage them before trying to commit again. 56 | 57 | Once changes are committed, we encourage developers to lint, check types and 58 | build/check documentation with: 59 | 60 | .. code-block:: console 61 | 62 | $ make tests 63 | $ make lint 64 | $ make typeannot 65 | $ make coverage 66 | $ make doc 67 | $ make servedoc 68 | 69 | Once everything is in order, and your code has been pushed to GitHub, 70 | navigate to https://github.com/PyLops/curvelops and submit your PR! 71 | -------------------------------------------------------------------------------- /docssrc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. curvelops documentation master file, created by 2 | sphinx-quickstart on Sun Nov 15 14:04:06 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Overview 7 | ======== 8 | 9 | Curvelops is part of the PyLops ecossystem, an open-source Python library 10 | focused on providing a backend-agnostic, idiomatic, matrix-free library of 11 | linear operators and related computations. Curvelops provides 2D and 3D 12 | Curvelet transforms via `CurveLab `__. 13 | 14 | Visit :ref:`Installation` and then get started with the 15 | `Gallery `__ or browse the 16 | :ref:`API`. 17 | 18 | 19 | .. attention:: 20 | `CurveLab `__ is a proprietary library which must be 21 | sourced independently by the user. It is free for academic use. Curvelops 22 | contains no CurveLab code apart from function calls. 23 | 24 | .. note:: 25 | All CurveLab rights are reserved to Emmanuel Candes, Laurent Demanet, David 26 | Donoho and Lexing Ying. PyLops and Curvelops are not affiliated with 27 | CurveLab or its authors in any way. 28 | 29 | .. toctree:: 30 | :maxdepth: 1 31 | :hidden: 32 | 33 | self 34 | installation.rst 35 | modules.rst 36 | gallery/index.rst 37 | contributing.rst 38 | -------------------------------------------------------------------------------- /docssrc/source/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | ============ 5 | 6 | .. _requirements: 7 | 8 | Requirements 9 | ------------ 10 | 11 | Installing Curvelops requires the following external components: 12 | 13 | * `FFTW `_ 2.1.5 14 | * `CurveLab `_ >= 2.0.2 15 | 16 | Both of these packages must be installed manually. 17 | 18 | Installing FFTW 19 | ~~~~~~~~~~~~~~~ 20 | Download and install with: 21 | 22 | 23 | .. code-block:: console 24 | 25 | $ wget https://www.fftw.org/fftw-2.1.5.tar.gz 26 | $ tar xvzf fftw-2.1.5.tar.gz 27 | $ mkdir -p /home/$USER/opt/ 28 | $ mv fftw-2.1.5/ /home/$USER/opt/ 29 | $ cd /home/$USER/opt/fftw-2.1.5/ 30 | $ ./configure --with-pic --prefix=/home/$USER/opt/fftw-2.1.5 --with-gcc=$(which gcc) 31 | $ make 32 | $ make install 33 | 34 | The ``--prefix`` and ``--with-gcc`` are optional and determine where it will 35 | install FFTW and where to find the GCC compiler, respectively. We recommend 36 | using the same compiler for FFTW and CurveLab. To ensure that FFTW has been 37 | installed correctly, run 38 | 39 | .. code-block:: console 40 | 41 | $ make check 42 | 43 | 44 | Installing CurveLab 45 | ~~~~~~~~~~~~~~~~~~~ 46 | After downloading the latest version of CurveLab, run 47 | 48 | .. code-block:: console 49 | 50 | $ tar xvzf CurveLab-2.1.3.tar.gz 51 | $ mkdir -p /home/$USER/opt/ 52 | $ mv CurveLab-2.1.3/ /home/$USER/opt/ 53 | $ cd /home/$USER/opt/CurveLab-2.1.3/ 54 | $ cp makefile.opt makefile.opt.bak 55 | 56 | In the file ``makefile.opt`` set ``FFTW_DIR``, ``CC`` and ``CXX`` variables. 57 | We recommend setting ``FFTW_DIR=/home/$USER/opt/fftw-2.1.5`` 58 | (or whatever directory was used in the ``--prefix`` option above), the output 59 | of ``which gcc`` in CC (or whatever compiler was used in ``--with-gcc``), and 60 | the ouput of ``which g++`` (or whatever C++ compiler is the equivalent of 61 | the selected ``CC`` compiler). Once the variables are set in `makefile.opt`, 62 | compile the library with 63 | 64 | .. code-block:: console 65 | 66 | $ cd /home/$USER/opt/CurveLab-2.1.3/ 67 | $ make clean 68 | $ make lib 69 | 70 | To ensure that CurveLab is installed correctly, run 71 | 72 | .. code-block:: console 73 | 74 | $ make test 75 | 76 | Installing Curvelops 77 | -------------------- 78 | 79 | Once FFTW and CurveLab are installed, install Curvelops with: 80 | 81 | .. code-block:: console 82 | 83 | $ export FFTW=/path/to/fftw-2.1.5 84 | $ export FDCT=/path/to/CurveLab-2.1.3 85 | $ python3 -m pip install git+https://github.com/PyLops/curvelops@0.23.4 86 | 87 | The ``FFTW`` variable is the same as ``FFTW_DIR`` as provided in the CurveLab 88 | installation. The ``FDCT`` variable points to the root of the CurveLab 89 | installation. 90 | -------------------------------------------------------------------------------- /docssrc/source/modules.rst: -------------------------------------------------------------------------------- 1 | .. _API: 2 | 3 | API 4 | === 5 | 6 | curvelops 7 | --------- 8 | 9 | .. automodule:: curvelops.curvelops 10 | :members: 11 | :undoc-members: 12 | :show-inheritance: 13 | 14 | plot 15 | ---- 16 | 17 | .. automodule:: curvelops.plot 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | typing 23 | ------ 24 | 25 | .. automodule:: curvelops.typing 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | utils 31 | ----- 32 | 33 | .. automodule:: curvelops.utils 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | -------------------------------------------------------------------------------- /docssrc/source/static/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/docssrc/source/static/demo.png -------------------------------------------------------------------------------- /docssrc/source/static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/docssrc/source/static/logo.png -------------------------------------------------------------------------------- /docssrc/source/static/reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/docssrc/source/static/reconstruction.png -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Gallery 2 | ------- 3 | 4 | Below is a gallery of examples using curvelops. 5 | -------------------------------------------------------------------------------- /examples/plot_curvelets_in_fk.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 4. Curvelet Coefficients in the FK domain 3 | ========================================= 4 | This example shows the regions in the FK domain where each 5 | curvelet coefficient occupies. 6 | """ 7 | # sphinx_gallery_thumbnail_number = 5 8 | 9 | # %% 10 | import matplotlib as mpl 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | from curvelops import FDCT2D 15 | 16 | # %% 17 | # Setup 18 | # ===== 19 | 20 | # %% 21 | nx, nz = 301, 201 22 | data_empty = np.zeros((nx, nz)) 23 | 24 | # %% 25 | nbscales = 4 26 | nbangles_coarse = 8 27 | allcurvelets = False 28 | 29 | # %% 30 | Cop = FDCT2D( 31 | data_empty.shape, 32 | nbscales=nbscales, 33 | nbangles_coarse=nbangles_coarse, 34 | allcurvelets=allcurvelets, 35 | ) 36 | 37 | # %% 38 | empty_fdct = Cop @ data_empty 39 | 40 | # Convert to a curvelet struct indexed by 41 | # [scale, wedge (angle), z, x] 42 | empty_fdct_struct = Cop.struct(empty_fdct) 43 | 44 | # %% 45 | 46 | 47 | def create_dirac_wedge(Cop, scale, wedge): 48 | d = np.zeros(Cop.dims) 49 | wedge_only_fdct = Cop @ d 50 | 51 | wedge_only_fdct_struct = Cop.struct(wedge_only_fdct) 52 | normalization = np.sqrt(wedge_only_fdct_struct[scale][wedge].size) 53 | iz, ix = wedge_only_fdct_struct[scale][wedge].shape 54 | 55 | wedge_only_fdct_struct[scale][wedge][iz // 2, ix // 2] = normalization 56 | wedge_only_fdct = Cop.vect(wedge_only_fdct_struct) 57 | wedge_only = Cop.H @ wedge_only_fdct 58 | return wedge_only 59 | 60 | 61 | # %% 62 | # Plot Wedges of each Scale 63 | # ========================= 64 | 65 | # %% 66 | # Colormap to be used in all plots below 67 | fig, ax = plt.subplots(figsize=(6, 1)) 68 | col_map = plt.get_cmap("turbo") 69 | mpl.colorbar.ColorbarBase( 70 | ax, 71 | cmap=col_map, 72 | orientation="horizontal", 73 | norm=mpl.colors.Normalize(vmin=0, vmax=1), 74 | ) 75 | fig.tight_layout() 76 | 77 | # %% 78 | wedge_fk_abs = np.zeros_like(data_empty) 79 | for j, fdct_scale in enumerate(empty_fdct_struct, start=1): 80 | rows = int(np.floor(np.sqrt(len(fdct_scale)))) 81 | fig, axes = plt.subplots( 82 | int(np.ceil(len(fdct_scale) / rows)), 83 | rows, 84 | figsize=(5 * rows, 3 * rows), 85 | ) 86 | fig.suptitle( 87 | f"Scale {j} ({len(fdct_scale)} wedge{'s' if len(fdct_scale) > 1 else ''})" 88 | ) 89 | axes = np.atleast_1d(axes).ravel() 90 | wedge_scale_fk_abs = np.zeros_like(data_empty) 91 | for iw, (fdct_wedge, ax) in enumerate(zip(fdct_scale, axes), start=1): 92 | dirac_wedge = create_dirac_wedge(Cop, j - 1, iw - 1) 93 | dirac_wedge_fk = np.fft.fftshift( 94 | np.fft.fft2(np.fft.ifftshift(dirac_wedge), norm="ortho") 95 | ) 96 | wedge_scale_fk_abs += np.abs(dirac_wedge_fk) 97 | 98 | ax.imshow(np.abs(dirac_wedge_fk).T, cmap="turbo", vmin=0, vmax=1) 99 | if len(fdct_scale) > 1: 100 | ax.set(title=f"Wedge {iw}") 101 | ax.axis("off") 102 | fig.tight_layout() 103 | wedge_fk_abs += wedge_scale_fk_abs 104 | if len(fdct_scale) > 1: 105 | fig, ax = plt.subplots(figsize=(5, 3)) 106 | fig.suptitle(f"Scale {j} (sum of all wedges)") 107 | ax.imshow(wedge_scale_fk_abs.T, cmap="turbo", vmin=0, vmax=1) 108 | ax.axis("off") 109 | fig.tight_layout() 110 | 111 | fig, ax = plt.subplots(figsize=(5, 3)) 112 | fig.suptitle("Sum of all wedges of all scales)") 113 | ax.imshow(wedge_fk_abs.T, cmap="turbo", vmin=0, vmax=1) 114 | ax.axis("off") 115 | fig.tight_layout() 116 | 117 | # %% 118 | # Plot Dirac in Space domain 119 | # ========================== 120 | 121 | # %% 122 | dirac_all_fdct_struct = Cop.struct(empty_fdct.copy()) 123 | for fdct_scale in dirac_all_fdct_struct: 124 | for fdct_wedge in fdct_scale: 125 | normalization = np.sqrt(fdct_wedge.size) 126 | iz, ix = fdct_wedge.shape 127 | fdct_wedge[iz // 2, ix // 2] = normalization * (1 + 1j) 128 | fdct_wedge[iz // 2 + 1, ix // 2] = normalization * (1 + 1j) 129 | fdct_wedge[iz // 2, ix // 2 + 1] = normalization * (1 + 1j) 130 | fdct_wedge[iz // 2 + 1, ix // 2 + 1] = normalization * (1 + 1j) 131 | 132 | data_dirac = Cop.H @ Cop.vect(dirac_all_fdct_struct) 133 | data_dirac = (data_dirac.real + data_dirac.imag) / np.sqrt(2) 134 | vmax = 0.5 * np.sqrt(data_dirac.size) 135 | 136 | fig, ax = plt.subplots(figsize=(5, 3)) 137 | ax.imshow(data_dirac.T, cmap="gray", vmin=-vmax, vmax=vmax) 138 | ax.set( 139 | xlim=(nx // 2 - 30, nx // 2 + 30), 140 | ylim=(nz // 2 + 30, nz // 2 - 30), 141 | title="Space domain magnified", 142 | ) 143 | fig.tight_layout() 144 | -------------------------------------------------------------------------------- /examples/plot_seismic_regularization.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 5. Seismic Regularization 3 | ========================= 4 | This example shows how to use the Curvelet transform to 5 | condition a missing-data seismic regularization problem. 6 | """ 7 | # sphinx_gallery_thumbnail_number = 2 8 | 9 | # %% 10 | import warnings 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pylops 17 | from pylops.optimization.sparsity import fista 18 | from scipy.signal import convolve 19 | 20 | from curvelops import FDCT2D 21 | 22 | np.random.seed(0) 23 | warnings.filterwarnings("ignore") 24 | 25 | # %% 26 | # Setup 27 | # ===== 28 | inputfile = "../testdata/seismic.npz" 29 | inputdata = np.load(inputfile) 30 | 31 | x = inputdata["R"][50, :, ::2] 32 | x = x / np.abs(x).max() 33 | taxis, xaxis = inputdata["t"][::2], inputdata["r"][0] 34 | 35 | par = {} 36 | par["nx"], par["nt"] = x.shape 37 | par["dx"] = inputdata["r"][0, 1] - inputdata["r"][0, 0] 38 | par["dt"] = inputdata["t"][1] - inputdata["t"][0] 39 | 40 | # Add wavelet 41 | wav = inputdata["wav"][::2] 42 | wav_c = np.argmax(wav) 43 | x = np.apply_along_axis(convolve, 1, x, wav, mode="full") 44 | x = x[:, wav_c:][:, : par["nt"]] 45 | 46 | # Gain 47 | gain = np.tile((taxis**2)[:, np.newaxis], (1, par["nx"])).T 48 | x *= gain 49 | 50 | # Subsampling locations 51 | perc_subsampling = 0.5 52 | Nsub = int(np.round(par["nx"] * perc_subsampling)) 53 | iava = np.sort(np.random.permutation(np.arange(par["nx"]))[:Nsub]) 54 | 55 | # Restriction operator 56 | Rop = pylops.Restriction((par["nx"], par["nt"]), iava, axis=0, dtype="float64") 57 | 58 | y = Rop @ x 59 | xadj = Rop.H @ y 60 | 61 | # Apply mask 62 | ymask = Rop.mask(x) 63 | 64 | # %% 65 | # Curvelet transform 66 | # ================== 67 | 68 | # %% 69 | DCTOp = FDCT2D((par["nx"], par["nt"]), nbscales=4) 70 | 71 | yc = DCTOp @ x 72 | xcadj = DCTOp.H @ yc 73 | 74 | # %% 75 | opts_plot = dict( 76 | cmap="gray", 77 | vmin=-0.1, 78 | vmax=0.1, 79 | extent=(xaxis[0], xaxis[-1], taxis[-1], taxis[0]), 80 | ) 81 | 82 | fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 7)) 83 | axs[0].imshow(x.T, **opts_plot) 84 | axs[0].set_title("Data") 85 | axs[0].axis("tight") 86 | axs[1].imshow(np.real(xcadj).T, **opts_plot) 87 | axs[1].set_title("Adjoint curvelet") 88 | axs[1].axis("tight") 89 | 90 | # %% 91 | # Reconstruction based on Curvelet transform 92 | # ########################################## 93 | 94 | # %% 95 | # Combined modelling operator 96 | RCop = Rop @ DCTOp.H 97 | RCop.dims = (RCop.shape[1],) # flatten 98 | RCop.dimsd = (RCop.shape[0],) 99 | 100 | # Inverse 101 | pl1, _, cost = fista(RCop, y.ravel(), niter=100, eps=1e-3, show=True) 102 | xl1 = (DCTOp.H @ pl1).real.reshape(x.shape) 103 | 104 | # %% 105 | fig, axs = plt.subplots(1, 4, sharey=True, figsize=(16, 7)) 106 | axs[0].imshow(x.T, **opts_plot) 107 | axs[0].set_title("Data") 108 | axs[0].axis("tight") 109 | axs[1].imshow(ymask.T, **opts_plot) 110 | axs[1].set_title("Masked data") 111 | axs[1].axis("tight") 112 | axs[2].imshow(xl1.T, **opts_plot) 113 | axs[2].set_title("Reconstructed data") 114 | axs[2].axis("tight") 115 | axs[3].imshow((x - xl1).T, **opts_plot) 116 | axs[3].set_title("Reconstruction error") 117 | axs[3].axis("tight") 118 | 119 | # %% 120 | fig, ax = plt.subplots(figsize=(16, 2)) 121 | ax.plot(range(1, len(cost) + 1), cost, "k") 122 | ax.set(xlim=[1, len(cost)]) 123 | fig.suptitle("FISTA convergence") 124 | -------------------------------------------------------------------------------- /examples/plot_sigmoid.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 2. Sigmoid Example 3 | ================== 4 | This example shows the effectiveness of curvelets in describing a typical 5 | subsurface structure. It compares the Curvelet transform with the Wavelet 6 | and Seislet transforms. 7 | """ 8 | # sphinx_gallery_thumbnail_number = 3 9 | 10 | # %% 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pylops 14 | 15 | from curvelops import FDCT2D 16 | 17 | try: 18 | # Progress bars 19 | from tqdm.notebook import tqdm 20 | except ImportError: 21 | 22 | def tqdm(x): 23 | return x 24 | 25 | print("Try out tqdm for progress bars!") 26 | 27 | # %% 28 | # Input data 29 | # ========== 30 | 31 | # %% 32 | inputfile = "../testdata/sigmoid.npz" 33 | 34 | d = np.load(inputfile) 35 | d = d["sigmoid"] 36 | nx, nt = d.shape 37 | dx, dt = 8, 0.004 38 | x, t = np.arange(nx) * dx, np.arange(nt) * dt 39 | 40 | # %% 41 | clip = 0.5 * np.max(np.abs(d)) 42 | opts = dict( 43 | aspect="auto", 44 | extent=(x[0], x[-1], t[-1], t[0]), 45 | vmin=-clip, 46 | vmax=clip, 47 | cmap="gray", 48 | interpolation="nearest", 49 | ) 50 | 51 | fig, ax = plt.subplots(figsize=(8, 6), sharey=True, sharex=True) 52 | ax.imshow(d.T, **opts) 53 | ax.set(xlabel="Position [m]", ylabel="Time [s]", title="Data") 54 | fig.tight_layout() 55 | 56 | # %% 57 | # Sparsifying Transforms 58 | # ====================== 59 | # * Seislet 60 | # * Wavelet 61 | # * Curvelet 62 | 63 | # %% 64 | 65 | # Seislet 66 | slope = -pylops.utils.signalprocessing.slope_estimate(d.T, dt, dx, smooth=6)[0] 67 | Sop = pylops.signalprocessing.Seislet(slope.T, sampling=(dx, dt)) 68 | Sop.shape 69 | 70 | # %% 71 | 72 | # Wavelet 73 | Wop = pylops.signalprocessing.Seislet(np.zeros_like(slope.T), sampling=(dx, dt)) 74 | Wop.shape 75 | 76 | # %% 77 | 78 | # Curvelet 79 | Cop = FDCT2D(d.shape) 80 | Cop.shape 81 | 82 | # %% 83 | 84 | 85 | def reconstruct(data, op, perc=0.1): 86 | """ 87 | Convenience function to calculate reconstruction using top 88 | `perc` percent of coefficients of a given operator `op`. 89 | """ 90 | y = op * data.ravel() 91 | denoise = np.zeros_like(y) 92 | 93 | # Order coefficients by strength 94 | strong_idx = np.argsort(-np.abs(y)) 95 | strong = np.abs(y)[strong_idx] 96 | 97 | # Select only top `perc`% coefficients 98 | strong_idx = strong_idx[: int(np.rint(len(strong_idx) * perc))] 99 | denoise[strong_idx] = y[strong_idx] 100 | 101 | data_denoise = op.inverse(denoise).reshape(data.shape) 102 | return data_denoise.real, strong 103 | 104 | 105 | # %% 106 | 107 | # Reconstruct data with only 10% of the strongest coefficients in sparse domain 108 | perc = 0.1 109 | d_seis, seis_strong = reconstruct(d, Sop, perc=perc) 110 | d_dwt, dwt_strong = reconstruct(d, Wop, perc=perc) 111 | d_dct, dct_strong = reconstruct(d, Cop, perc=perc) 112 | 113 | # %% 114 | fig, ax = plt.subplots() 115 | ax.semilogy( 116 | np.linspace(0, 100, len(seis_strong), endpoint=True), 117 | seis_strong / seis_strong[0], 118 | label="Seislet", 119 | ) 120 | ax.semilogy( 121 | np.linspace(0, 100, len(dwt_strong), endpoint=True), 122 | dwt_strong / dwt_strong[0], 123 | label="Wavelet", 124 | ) 125 | ax.semilogy( 126 | np.linspace(0, 100, len(dct_strong), endpoint=True), 127 | dct_strong / dct_strong[0], 128 | label="Curvelet", 129 | ) 130 | ax.set( 131 | xlim=(0, 100), 132 | ylim=(1e-4, 1), 133 | xlabel="Coefficients [%]", 134 | ylabel="Coefficient strength [dB]", 135 | title="Transform Coefficients", 136 | ) 137 | ax.axvline(100 * perc, color="k", label=f"{100*perc:.0f}%") 138 | ax.legend() 139 | fig.tight_layout() 140 | 141 | # %% 142 | gain = 4 143 | fig, ax = plt.subplots(2, 3, figsize=(14, 8), sharey=True, sharex=True) 144 | for i, (d_trans, title) in enumerate( 145 | zip([d_seis, d_dwt, d_dct], ["Seislet", "Wavelet", "Curvelet"]) 146 | ): 147 | ax[0, i].imshow(d_trans.T, **opts) 148 | im = ax[1, i].imshow((d - d_trans).T, **opts) 149 | im.set_clim(vmin=-clip / gain, vmax=clip / gain) 150 | ax[0, i].set(title=f"{title} ({100*perc:.0f}% of components)") 151 | ax[1, i].set(title=f"{title} Error x {gain}", xlabel="Position [m]") 152 | ax[0, 0].set(ylabel="Time [s]") 153 | ax[1, 0].set(ylabel="Time [s]") 154 | fig.tight_layout() 155 | 156 | # %% 157 | 158 | # Calculate error in reconstruction by number of coefficients used 159 | error_seis = [] 160 | error_dwt = [] 161 | error_dct = [] 162 | for perc in tqdm(2 ** np.arange(7) / 100.0): 163 | d_seis = reconstruct(d, Sop, perc=perc)[0] 164 | d_dwt = reconstruct(d, Wop, perc=perc)[0] 165 | d_dct = reconstruct(d, Cop, perc=perc)[0] 166 | error_seis.append(np.linalg.norm(d_seis - d)) 167 | error_dwt.append(np.linalg.norm(d_dwt - d)) 168 | error_dct.append(np.linalg.norm(d_dct - d)) 169 | 170 | # %% 171 | fig, ax = plt.subplots() 172 | ax.semilogy(2 ** np.arange(7), error_seis, "o-", label="Seislet") 173 | ax.semilogy(2 ** np.arange(7), error_dwt, "o-", label="Wavelet") 174 | ax.semilogy(2 ** np.arange(7), error_dct, "o-", label="Curvelet") 175 | ax.set(xlabel="Percentage of coefficients", ylabel=r"Error ($L_2$ norm)") 176 | ax.legend() 177 | fig.tight_layout() 178 | -------------------------------------------------------------------------------- /examples/plot_sigmoid_coefficients.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3. Visualizing Curvelet Coefficients 3 | ==================================== 4 | This example shows the how to visualize curvelet coefficients of an image, 5 | using as example a typical subsurface structure. 6 | """ 7 | # sphinx_gallery_thumbnail_number = 3 8 | 9 | # %% 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | from curvelops import FDCT2D, apply_along_wedges, curveshow 14 | 15 | # %% 16 | # Input data 17 | # ========== 18 | 19 | # %% 20 | inputfile = "../testdata/sigmoid.npz" 21 | 22 | d = np.load(inputfile) 23 | d = d["sigmoid"] 24 | nx, nt = d.shape 25 | dx, dt = 0.005, 0.004 26 | x, t = np.arange(nx) * dx, np.arange(nt) * dt 27 | 28 | # %% 29 | aspect = dt / dx 30 | opts_plot = dict( 31 | extent=(x[0], x[-1], t[-1], t[0]), 32 | cmap="gray", 33 | interpolation="lanczos", 34 | aspect=aspect, 35 | ) 36 | vmax = 0.5 * np.max(np.abs(d)) 37 | figsize_aspect = aspect * nt / nx 38 | fig, ax = plt.subplots(figsize=(8, figsize_aspect * 8), sharey=True, sharex=True) 39 | ax.imshow(d.T, vmin=-vmax, vmax=vmax, **opts_plot) 40 | ax.set(xlabel="Position [m]", ylabel="Time [s]", title=f"Data shape {d.shape}") 41 | fig.tight_layout() 42 | 43 | # %% 44 | # Create Curvelet Transform 45 | # ========================= 46 | nbscales = 4 47 | nbangles_coarse = 8 48 | allcurvelets = False # Last scale will be a wavelet transform 49 | 50 | # %% 51 | Cop = FDCT2D( 52 | d.shape, 53 | nbscales=nbscales, 54 | nbangles_coarse=nbangles_coarse, 55 | allcurvelets=allcurvelets, 56 | ) 57 | 58 | # %% 59 | # Convert to a list of lists of ndarrays. 60 | d_fdct_struct = Cop.struct(Cop @ d) 61 | 62 | # %% 63 | # Real part of FDCT coefficients 64 | # ============================== 65 | # Curvelet coefficients are essentially directionally-filtered, shrunk versions 66 | # of the original signal. Note that the "shrinking" does not preserve aspect 67 | # ratio. 68 | 69 | # %% 70 | for j, c_scale in enumerate(d_fdct_struct, start=1): 71 | nangles = len(c_scale) 72 | rows = int(np.floor(np.sqrt(nangles))) 73 | cols = int(np.ceil(nangles / rows)) 74 | fig, axes = plt.subplots( 75 | rows, 76 | cols, 77 | figsize=(5 * rows, figsize_aspect * 5 * rows), 78 | ) 79 | fig.suptitle(f"Scale {j} ({len(c_scale)} wedge{'s' if len(c_scale) > 1 else ''})") 80 | axes = np.atleast_1d(axes).ravel() 81 | vmax = 0.5 * max(np.abs(Cweg).max() for Cweg in c_scale) 82 | for iw, (fdct_wedge, ax) in enumerate(zip(c_scale, axes), start=1): 83 | # Note that wedges are transposed in comparison to the input vector. 84 | # This is due to the underlying implementation of the transform. In 85 | # order to plot in the same manner as the data, we must first 86 | # transpose the wedge. We will using the transpose of the wedge for 87 | # visualization. 88 | c = fdct_wedge.real.T 89 | ax.imshow(c.T, vmin=-vmax, vmax=vmax, **opts_plot) 90 | ax.set(title=f"Wedge {iw} shape {c.shape}") 91 | ax.axis("off") 92 | fig.tight_layout() 93 | 94 | # %% 95 | # Imaginagy part of FDCT coefficients 96 | # =================================== 97 | # Curvelops includes much of the above logic wrapped in the following 98 | # :py:class:`curvelops.plot.cuveshow`. Since we 99 | 100 | # Normalize each coefficient by max abs 101 | y_norm = apply_along_wedges(d_fdct_struct, lambda w, *_: w / np.abs(w).max()) 102 | 103 | # %% 104 | figs = curveshow( 105 | y_norm, 106 | real=False, 107 | kwargs_imshow={**opts_plot, "vmin": -0.5, "vmax": 0.5}, 108 | ) 109 | -------------------------------------------------------------------------------- /examples/plot_sigmoid_disks.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 6. Multiscale Local Directions 3 | ============================== 4 | This example shows how to use the Curvelet transform to 5 | visualize local, multiscale preferrential directions in 6 | an image. Inspired by `Kymatio's Scattering disks `__. 7 | """ 8 | # sphinx_gallery_thumbnail_number = 3 9 | 10 | # %% 11 | import matplotlib as mpl 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import numpy.typing as npt 15 | from mpl_toolkits.axes_grid1 import make_axes_locatable 16 | from pylops.signalprocessing import FFT2D 17 | 18 | from curvelops import FDCT2D 19 | from curvelops.plot import ( 20 | create_axes_grid, 21 | create_inset_axes_grid, 22 | overlay_arrows, 23 | overlay_disks, 24 | ) 25 | from curvelops.utils import array_split_nd, ndargmax 26 | 27 | # %% 28 | # Input 29 | # ===== 30 | 31 | # %% 32 | inputfile = "../testdata/sigmoid.npz" 33 | 34 | data = np.load(inputfile) 35 | data = data["sigmoid"] 36 | nx, nz = data.shape 37 | dx, dz = 0.005, 0.004 38 | x, z = np.arange(nx) * dx, np.arange(nz) * dz 39 | 40 | 41 | # %% 42 | aspect = dz / dx 43 | figsize_aspect = aspect * nz / nx 44 | opts_space = dict( 45 | extent=(x[0], x[-1], z[-1], z[0]), 46 | cmap="gray", 47 | interpolation="lanczos", 48 | aspect=aspect, 49 | ) 50 | vmax = 0.5 * np.max(np.abs(data)) 51 | fig, ax = plt.subplots(figsize=(8, figsize_aspect * 8)) 52 | ax.imshow(data.T, vmin=-vmax, vmax=vmax, **opts_space) 53 | ax.set(xlabel="Position [km]", ylabel="Depth [km]", title="Data") 54 | fig.tight_layout() 55 | 56 | 57 | # %% 58 | # Understanding Curvelet Disks 59 | # ============================ 60 | 61 | # %% 62 | # First we create and apply curvelet transform. 63 | Cop = FDCT2D(data.shape, nbscales=4, nbangles_coarse=8, allcurvelets=False) 64 | d_c = Cop.struct(Cop @ data) 65 | 66 | # %% 67 | # Each wedge is mapped to a region of the scattering disk. 68 | # The first number refers to the scale, the second to the wedge index, 69 | # zero-indexed. 70 | # 71 | # The disks have the most energy in the direction perpendicular to the 72 | # directions of minimum change. The following disk is computed with the entire 73 | # image. We observe that with energy mostly along the top-bottom direction, 74 | # the directions in the image will be mostly along the left-right direction, 75 | # which matches the input data. 76 | rows, cols = 1, 1 77 | fig, axes = create_axes_grid( 78 | rows, 79 | cols, 80 | kwargs_subplots=dict(projection="polar"), 81 | kwargs_figure=dict(figsize=(4, 4)), 82 | ) 83 | overlay_disks(d_c, axes, annotate=True) 84 | 85 | 86 | # %% 87 | # Multiscale Local Directions 88 | # ============================ 89 | # The power of the curvelet transform is to provide dip information varying 90 | # with location and scale. 91 | # Below we will compute preferrential local directions using an approach 92 | # based on the 2D FFT that does not differentiate between scales. 93 | 94 | # %% 95 | rows, cols = 5, 6 96 | 97 | 98 | def local_single_scale_dips(data: npt.NDArray, rows: int, cols: int) -> npt.NDArray: 99 | kvecs = np.empty((rows, cols, 2)) 100 | d_split = array_split_nd(data.T, rows, cols) 101 | 102 | for irow in range(kvecs.shape[0]): 103 | for icol in range(kvecs.shape[1]): 104 | d_loc = d_split[irow][icol].T 105 | Fop_loc = FFT2D( 106 | d_loc.shape, 107 | sampling=[dx, dz], 108 | norm="ortho", 109 | real=False, 110 | ifftshift_before=True, 111 | fftshift_after=True, 112 | engine="scipy", 113 | ) 114 | d_k_loc = Fop_loc @ d_loc 115 | 116 | kx_loc = Fop_loc.f1 117 | kz_loc = Fop_loc.f2 118 | 119 | kx_locmax, kz_locmax = ndargmax(np.abs(d_k_loc[:, kz_loc > 0])) 120 | 121 | k = np.array([kx_loc[kx_locmax], kz_loc[kz_loc > 0][kz_locmax]]) 122 | kvecs[irow, icol, :] = k / np.linalg.norm(k) 123 | return kvecs 124 | 125 | 126 | # %% 127 | diskcmap = "turbo" 128 | rows, cols = 5, 6 129 | kvecs = local_single_scale_dips(data, rows, cols) 130 | kvecs *= 0.15 * min(x[-1] - x[0], z[-1] - z[0]) 131 | 132 | fig, ax = plt.subplots(figsize=(8, figsize_aspect * 8)) 133 | ax.imshow(data.T, vmin=-vmax, vmax=vmax, **opts_space) 134 | ax.set(xlabel="Position [km]", ylabel="Depth [km]") 135 | divider = make_axes_locatable(ax) 136 | cax = divider.append_axes("right", size="5%", pad=0.1) 137 | mpl.colorbar.ColorbarBase( 138 | cax, 139 | cmap=plt.get_cmap(diskcmap), 140 | norm=mpl.colors.Normalize(vmin=0, vmax=1), 141 | alpha=0.8, 142 | ) 143 | 144 | # Local single-scale directions 145 | overlay_arrows(kvecs, ax) 146 | 147 | # Local multsicale directions 148 | axesin = create_inset_axes_grid( 149 | ax, 150 | rows, 151 | cols, 152 | height=0.6, 153 | width=0.6, 154 | kwargs_inset_axes=dict(projection="polar"), 155 | ) 156 | overlay_disks(d_c, axesin, linewidth=0.0, cmap=diskcmap) 157 | fig.tight_layout() 158 | -------------------------------------------------------------------------------- /examples/plot_single_curvelet.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 1. Visualize a Single Curvelet 3 | ============================== 4 | This example shows a single curvelet coefficient in 5 | spatial and frequency domains. 6 | """ 7 | 8 | # %% 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from matplotlib.ticker import MultipleLocator 12 | 13 | from curvelops import FDCT2D 14 | from curvelops.plot import create_colorbar 15 | 16 | # %% 17 | # Setup 18 | # ===== 19 | m = 512 20 | n = 512 21 | x = np.zeros((m, n)) 22 | DCT = FDCT2D(x.shape) 23 | 24 | # %% 25 | # Curvelet Domain 26 | # =============== 27 | 28 | # %% 29 | y = DCT * x 30 | 31 | # Convert to a curvelet struct indexed by 32 | # [scale, wedge (angle), x, y] 33 | y_reshape = DCT.struct(y) 34 | 35 | # %% 36 | # Select single curvelet 37 | # ====================== 38 | s = 4 39 | w = 0 40 | a, b = y_reshape[s][w].shape 41 | normalization = np.sqrt(y_reshape[s][w].size) 42 | y_reshape[s][w][a // 2, b // 2] = 1 * normalization 43 | y_reshape[s][w + len(y_reshape[s]) // 2][a // 2, b // 2] = -1j * normalization 44 | 45 | y = DCT.vect(y_reshape) 46 | 47 | # %% 48 | # Perform adjoint transform and reshape 49 | # ===================================== 50 | x = DCT.H @ y 51 | 52 | # %% 53 | # F-K domain 54 | # ========== 55 | x_fk = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x), norm="ortho")) 56 | 57 | # %% 58 | # Visualize 59 | # ========= 60 | vmin, vmax = 0.8 * np.array([-1, 1]) * np.abs(np.max(x)) 61 | fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex="row", sharey="row") 62 | 63 | im = ax[0, 0].imshow(x.real.T, cmap="gray", vmin=vmin, vmax=vmax) 64 | create_colorbar(im, ax[0, 0]) 65 | 66 | im = ax[0, 1].imshow(x.imag.T, cmap="gray", vmin=vmin, vmax=vmax) 67 | create_colorbar(im, ax[0, 1]) 68 | 69 | im = ax[1, 0].imshow(np.abs(x_fk).T, cmap="turbo", vmin=0) 70 | create_colorbar(im, ax[1, 0]) 71 | 72 | mask = np.abs(x_fk) > 0.01 * np.abs(x_fk).max() 73 | im = ax[1, 1].imshow( 74 | (mask * np.angle(x_fk, deg=True)).T, 75 | cmap="twilight_shifted", 76 | vmin=-180, 77 | vmax=180, 78 | ) 79 | cax, cb = create_colorbar(im, ax[1, 1]) 80 | cax.get_yaxis().set_major_locator(MultipleLocator(45)) 81 | 82 | 83 | ax[0, 0].set( 84 | xlim=(m // 2 - 50, m // 2 + 50), 85 | ylim=(n // 2 - 50, n // 2 + 50), 86 | title="Space domain (Real) magnified", 87 | ) 88 | ax[0, 1].set(title="Space domain (Imag) magnified") 89 | ax[1, 0].set(title="Frequency domain (Abs)") 90 | ax[1, 1].set(title="Frequency domain (Phase)") 91 | fig.tight_layout() 92 | -------------------------------------------------------------------------------- /notebooks/Single_Curvelet_Interactive.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Single Curvelet (Interactive)\n", 9 | "This interactive example shows a single curvelet coefficient in\n", 10 | "spatial and frequency domains.\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import numpy as np\n", 21 | "from IPython.display import display\n", 22 | "from ipywidgets import HBox, IntSlider, VBox, interactive_output\n", 23 | "\n", 24 | "from curvelops import FDCT2D" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### Setup" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "nx = 300\n", 41 | "nz = 350\n", 42 | "\n", 43 | "# Create operator\n", 44 | "DCT = FDCT2D((nx, nz), nbangles_coarse=8)\n", 45 | "\n", 46 | "# Create empty structure for curvelet\n", 47 | "y_struct = DCT.struct(np.zeros(DCT.shape[0]))" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "### Plotting" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "def display_curvelet(scale=1, wedge=1, ix=1, iy=1):\n", 64 | " s = scale - 1\n", 65 | " w = wedge - 1\n", 66 | "\n", 67 | " # Populate curvelet\n", 68 | " y_new = DCT.struct(np.zeros(DCT.shape[0]))\n", 69 | " A, B = y_new[s][w].shape\n", 70 | " iy = max(1, min(iy, A))\n", 71 | " ix = max(1, min(ix, B))\n", 72 | " y_new[s][w][iy - 1, ix - 1] = 1.0\n", 73 | "\n", 74 | " x = DCT.H @ DCT.vect(y_new)\n", 75 | "\n", 76 | " x_fk = np.fft.fft2(x)\n", 77 | " x_fk = np.fft.fftshift(x_fk)\n", 78 | "\n", 79 | " vmin, vmax = 0.8 * np.array([-1, 1]) * np.abs(np.max(x))\n", 80 | " fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=\"row\", sharey=\"row\")\n", 81 | " ax[0, 0].imshow(np.real(x.T), cmap=\"gray\", vmin=vmin, vmax=vmax)\n", 82 | " ax[0, 1].imshow(np.imag(x.T), cmap=\"gray\", vmin=vmin, vmax=vmax)\n", 83 | " ax[1, 0].imshow(np.abs(x_fk.T), cmap=\"turbo\", vmin=0)\n", 84 | " mask = np.abs(x_fk) > 0.01 * np.abs(x_fk).max()\n", 85 | " ax[1, 1].imshow(\n", 86 | " (mask * np.angle(x_fk, deg=True)).T,\n", 87 | " cmap=\"twilight_shifted\",\n", 88 | " vmin=-180,\n", 89 | " vmax=180,\n", 90 | " )\n", 91 | " ax[0, 0].set(title=\"Space domain (Real)\")\n", 92 | " ax[0, 1].set(title=\"Space domain (Imag)\")\n", 93 | " ax[1, 0].set(title=\"Frequency domain (Abs)\")\n", 94 | " ax[1, 1].set(title=\"Frequency domain (Phase)\")\n", 95 | " ax[0, 0].axvline(nx / 2, color=\"y\", alpha=0.5)\n", 96 | " ax[0, 0].axhline(nz / 2, color=\"y\", alpha=0.5)\n", 97 | " ax[0, 1].axvline(nx / 2, color=\"y\", alpha=0.5)\n", 98 | " ax[0, 1].axhline(nz / 2, color=\"y\", alpha=0.5)\n", 99 | " fig.tight_layout()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "image/png": "", 110 | "text/plain": [ 111 | "
" 112 | ] 113 | }, 114 | "metadata": {}, 115 | "output_type": "display_data" 116 | } 117 | ], 118 | "source": [ 119 | "display_curvelet(\n", 120 | " scale=3,\n", 121 | " wedge=3,\n", 122 | " ix=y_struct[2][2].shape[1] // 2 + 1,\n", 123 | " iy=y_struct[2][2].shape[0] // 2 + 1,\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### Interactive" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": { 138 | "scrolled": false 139 | }, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "application/vnd.jupyter.widget-view+json": { 144 | "model_id": "0683c8c7a7664e43a36bff56dcad8383", 145 | "version_major": 2, 146 | "version_minor": 0 147 | }, 148 | "text/plain": [ 149 | "HBox(children=(VBox(children=(IntSlider(value=1, description='Scales', max=6, min=1), IntSlider(value=1, descr…" 150 | ] 151 | }, 152 | "metadata": {}, 153 | "output_type": "display_data" 154 | }, 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "74fb7044b37c446b87f31be817ad2b4d", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "Output()" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "max_scale = DCT.nbscales\n", 172 | "max_wedge = len(y_struct[0])\n", 173 | "max_iy, max_ix = y_struct[0][0].shape\n", 174 | "curr_scale = 1\n", 175 | "curr_wedge = 1\n", 176 | "\n", 177 | "slider_scale = IntSlider(\n", 178 | " min=1, max=max_scale, value=curr_scale, step=1, description=\"Scales\"\n", 179 | ")\n", 180 | "slider_wedge = IntSlider(\n", 181 | " min=1, max=max_wedge, value=curr_wedge, step=1, description=\"Wedge\"\n", 182 | ")\n", 183 | "slider_ix = IntSlider(\n", 184 | " min=1, max=max_ix, value=max_ix // 2 + 1, step=1, description=\"X Index\"\n", 185 | ")\n", 186 | "slider_iy = IntSlider(\n", 187 | " min=1, max=max_iy, value=max_iy // 2 + 1, step=1, description=\"Y Index\"\n", 188 | ")\n", 189 | "\n", 190 | "\n", 191 | "def handle_scale_change(change):\n", 192 | " global curr_scale\n", 193 | " curr_scale = change.new\n", 194 | " slider_wedge.max = len(y_struct[curr_scale - 1])\n", 195 | " global curr_wedge\n", 196 | " curr_wedge = slider_wedge.value\n", 197 | " A, B = y_struct[curr_scale - 1][curr_wedge - 1].shape\n", 198 | " slider_ix.max = B\n", 199 | " slider_iy.max = A\n", 200 | "\n", 201 | "\n", 202 | "def handle_wedge_change(change):\n", 203 | " global curr_wedge\n", 204 | " curr_wedge = change.new\n", 205 | " A, B = y_struct[curr_scale - 1][curr_wedge - 1].shape\n", 206 | " slider_ix.max = B\n", 207 | " slider_iy.max = A\n", 208 | "\n", 209 | "\n", 210 | "slider_scale.observe(handle_scale_change, names=\"value\")\n", 211 | "slider_wedge.observe(handle_wedge_change, names=\"value\")\n", 212 | "\n", 213 | "out = interactive_output(\n", 214 | " display_curvelet,\n", 215 | " {\n", 216 | " \"scale\": slider_scale,\n", 217 | " \"wedge\": slider_wedge,\n", 218 | " \"ix\": slider_ix,\n", 219 | " \"iy\": slider_iy,\n", 220 | " },\n", 221 | ")\n", 222 | "vbox1 = VBox([slider_scale, slider_wedge])\n", 223 | "vbox2 = VBox([slider_ix, slider_iy])\n", 224 | "ui = HBox([vbox1, vbox2])\n", 225 | "display(ui, out)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [] 234 | } 235 | ], 236 | "metadata": { 237 | "@webio": { 238 | "lastCommId": null, 239 | "lastKernelId": null 240 | }, 241 | "kernelspec": { 242 | "display_name": "Python 3 (ipykernel)", 243 | "language": "python", 244 | "name": "python3" 245 | }, 246 | "language_info": { 247 | "codemirror_mode": { 248 | "name": "ipython", 249 | "version": 3 250 | }, 251 | "file_extension": ".py", 252 | "mimetype": "text/x-python", 253 | "name": "python", 254 | "nbconvert_exporter": "python", 255 | "pygments_lexer": "ipython3", 256 | "version": "3.10.6" 257 | }, 258 | "vscode": { 259 | "interpreter": { 260 | "hash": "2a97748138635e07aa5d1e3bf79826285d6ad6d370bfec5306f9f799b8f29eef" 261 | } 262 | } 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 4 266 | } 267 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | "pybind11>=2.6.0; python_version < '3.11'", 6 | "pybind11>=2.10.0; python_version >= '3.11'", 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | 10 | [tool.black] 11 | line-length = 88 12 | 13 | [tool.isort] 14 | profile = "black" 15 | 16 | [[tool.mypy.overrides]] 17 | module = [ 18 | "curvelops._version", 19 | "curvelops.fdct2d_wrapper", 20 | "curvelops.fdct3d_wrapper", 21 | "pylops", 22 | "pylops.*", 23 | "matplotlib", 24 | "matplotlib.*", 25 | "mpl_toolkits.axes_grid1", 26 | "scipy.*", 27 | "tqdm.*", 28 | ] 29 | ignore_missing_imports = true 30 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Install requires 2 | numpy>=1.21.0 3 | scipy>=1.9.1; python_version >= '3.9' 4 | pylops>=2.0 5 | # Setup requires 6 | pybind11>=2.6.0; python_version < '3.10' 7 | pybind11>=2.10.0; python_version >= '3.11' 8 | setuptools_scm 9 | # Tests require 10 | pytest 11 | # Docs require 12 | Sphinx 13 | pydata-sphinx-theme 14 | sphinx-gallery 15 | sphinx-copybutton 16 | # Dev requires 17 | pre-commit 18 | # Lint 19 | flake8 20 | mypy 21 | coverage 22 | # Examples require 23 | matplotlib 24 | tqdm 25 | ipywidgets 26 | ipython 27 | ipympl 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | addopts = --verbose 3 | 4 | [flake8] 5 | ignore = E203, E501, W503, E402 6 | per-file-ignores = 7 | __init__.py: F401, F403, F405 8 | max-line-length = 88 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from setuptools import find_packages, setup 5 | 6 | if "clean" in sys.argv: 7 | from pathlib import Path 8 | 9 | # Delete any previously compiled files in pygeos 10 | p = Path("curvelops") 11 | for filename in p.glob("*.so"): 12 | print("removing '{}'".format(filename)) 13 | filename.unlink() 14 | 15 | from pybind11.setup_helpers import Pybind11Extension, build_ext 16 | 17 | NAME = "curvelops" 18 | AUTHOR = "Carlos Alberto da Costa Filho" 19 | AUTHOR_EMAIL = "c.dacostaf@gmail.com" 20 | URL = "https://github.com/PyLops/curvelops" 21 | DESCRIPTION = "Python wrapper for CurveLab's 2D and 3D curvelet transforms" 22 | LICENSE = "MIT" 23 | 24 | with open("README.md", encoding="utf-8") as f: 25 | LONG_DESCRIPTION = f.read() 26 | 27 | try: 28 | FFTW = os.environ["FFTW"] 29 | except KeyError: 30 | print( 31 | """ 32 | ============================================================== 33 | 34 | Please ensure the FFTW environment variable is set to the root 35 | of the FFTW 2.1.5 installation directory. 36 | 37 | ============================================================== 38 | """ 39 | ) 40 | try: 41 | FDCT = os.environ["FDCT"] 42 | except KeyError: 43 | print( 44 | """ 45 | ============================================================== 46 | 47 | Please ensure the FDCT environment variable is set to the root 48 | of the CurveLab installation directory. 49 | 50 | ============================================================== 51 | """ 52 | ) 53 | 54 | 55 | ext_modules = [ 56 | Pybind11Extension( 57 | "fdct2d_wrapper", 58 | [os.path.join("cpp", "fdct2d_wrapper.cpp")], 59 | include_dirs=[ 60 | os.path.join(FFTW, "fftw"), 61 | os.path.join(FDCT, "fdct_wrapping_cpp", "src"), 62 | ], 63 | libraries=["fftw"], 64 | library_dirs=[os.path.join(FFTW, "fftw", ".libs")], 65 | extra_objects=[ 66 | os.path.join(FDCT, "fdct_wrapping_cpp", "src", "libfdct_wrapping.a") 67 | ], 68 | language="c++", 69 | ), 70 | Pybind11Extension( 71 | "fdct3d_wrapper", 72 | [os.path.join("cpp", "fdct3d_wrapper.cpp")], 73 | include_dirs=[ 74 | os.path.join(FFTW, "fftw"), 75 | os.path.join(FDCT, "fdct3d", "src"), 76 | ], 77 | libraries=["fftw"], 78 | library_dirs=[os.path.join(FFTW, "fftw", ".libs")], 79 | extra_objects=[os.path.join(FDCT, "fdct3d", "src", "libfdct3d.a")], 80 | language="c++", 81 | ), 82 | ] 83 | 84 | # Remove -stdlib=libc++ from MACOS flags if MACOS_GCC flag is equal to 1 85 | # (This is required because pybind11 assumes OSX will use clang compiler but 86 | # FFTW and FDCT may require switching to a gcc compiler in some OSX versions. 87 | MACOS = sys.platform.startswith("darwin") 88 | if MACOS and int(os.getenv("MACOS_GCC", 0)) == 1: 89 | for ext in ext_modules: 90 | new_flags = [] 91 | for flag in ext.extra_compile_args: 92 | if flag != "-stdlib=libc++": 93 | new_flags.append(flag) 94 | ext.extra_compile_args = new_flags 95 | 96 | new_flags = [] 97 | for flag in ext.extra_link_args: 98 | if flag != "-stdlib=libc++": 99 | new_flags.append(flag) 100 | ext.extra_link_args = new_flags 101 | 102 | setup( 103 | name=NAME, 104 | author=AUTHOR, 105 | author_email=AUTHOR_EMAIL, 106 | url=URL, 107 | description=DESCRIPTION, 108 | long_description=LONG_DESCRIPTION, 109 | long_description_content_type="text/markdown", 110 | zip_safe=False, 111 | include_package_data=True, 112 | cmdclass={"build_ext": build_ext}, 113 | ext_package="curvelops", 114 | ext_modules=ext_modules, 115 | packages=find_packages(exclude=["pytests"]), 116 | install_requires=[ 117 | "numpy>=1.21.0", 118 | "scipy>=1.9.1; python_version >= '3.9'", 119 | "pylops>=2.0", 120 | "matplotlib", 121 | ], 122 | setup_requires=[ 123 | "pybind11>=2.6.0; python_version < '3.10'", 124 | "pybind11>=2.10.0; python_version >= '3.11'", 125 | "setuptools_scm", 126 | ], 127 | use_scm_version=dict( 128 | root=".", relative_to=__file__, write_to=f"{NAME}/_version.py" 129 | ), 130 | license=LICENSE, 131 | test_suite="pytests", 132 | tests_require=["pytest"], 133 | extras_require={"dev": ["pytest"]}, 134 | python_requires=">=3.7", 135 | classifiers=[ 136 | "Development Status :: 3 - Beta", 137 | "Intended Audience :: Science/Research", 138 | "License :: OSI Approved :: MIT License", 139 | "Natural Language :: English", 140 | "Programming Language :: Python :: 3.7", 141 | "Programming Language :: Python :: 3.8", 142 | "Programming Language :: Python :: 3.9", 143 | "Programming Language :: Python :: 3.10", 144 | "Programming Language :: Python :: 3.11", 145 | "Topic :: Scientific/Engineering :: Mathematics", 146 | ], 147 | keywords="curvelet curvelab pylops", 148 | ) 149 | -------------------------------------------------------------------------------- /testdata/python.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/testdata/python.png -------------------------------------------------------------------------------- /testdata/seismic.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/testdata/seismic.npz -------------------------------------------------------------------------------- /testdata/sigmoid.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/testdata/sigmoid.npz -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PyLops/curvelops/d6dc5fde8bcf399e57f81c5beb38449eb8863e45/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_fdct.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pylops.utils import dottest 4 | 5 | from curvelops import FDCT2D, FDCT3D 6 | 7 | PYCT = False 8 | try: 9 | import pyct as ct 10 | 11 | PYCT = True 12 | print( 13 | """ 14 | Imported `pyct` 15 | """ 16 | ) 17 | 18 | except ImportError: 19 | print( 20 | """ 21 | Could not import `pyct` (PyCurvelab), will proceed without 22 | checking if both libraries match 23 | """ 24 | ) 25 | 26 | pars = [ 27 | # {'nx': 32, 'ny': 32, 'nz': 32, 'imag': 0, 'dtype': 'float64'}, 28 | {"nx": 32, "ny": 32, "nz": 32, "imag": 1j, "dtype": "complex128"}, 29 | # {'nx': 32, 'ny': 32, 'nz': 64, 'imag': 0, 'dtype': 'float64'}, 30 | {"nx": 32, "ny": 32, "nz": 64, "imag": 1j, "dtype": "complex128"}, 31 | # {'nx': 100, 'ny': 50, 'nz': 20, 'imag': 0, 'dtype': 'complex128'}, 32 | {"nx": 100, "ny": 50, "nz": 20, "imag": 1j, "dtype": "complex128"}, 33 | ] 34 | 35 | 36 | @pytest.mark.parametrize("par", pars) 37 | def test_FDCT2D_2dsignal(par): 38 | """ 39 | Tests for FDCT2D operator for 2d signal. 40 | """ 41 | x = ( 42 | np.random.normal(0.0, 1.0, (par["nx"], par["ny"])) 43 | + np.random.normal(0.0, 1.0, (par["nx"], par["ny"])) * par["imag"] 44 | ) 45 | 46 | FDCTop = FDCT2D(dims=(par["nx"], par["ny"]), dtype=par["dtype"]) 47 | 48 | assert dottest( 49 | FDCTop, *FDCTop.shape, rtol=1e-12, complexflag=0 if par["imag"] == 0 else 3 50 | ) 51 | 52 | y = FDCTop * x.ravel() 53 | xinv = FDCTop.H * y 54 | np.testing.assert_array_almost_equal(xinv.reshape(*x.shape), x, decimal=14) 55 | 56 | if PYCT: 57 | FDCTct = ct.fdct2( 58 | x.shape, 59 | FDCTop.nbscales, 60 | FDCTop.nbangles_coarse, 61 | FDCTop.allcurvelets, 62 | cpx=False if par["imag"] == 0 else True, 63 | ) 64 | y_ct = np.array(FDCTct.fwd(x)).ravel() 65 | 66 | np.testing.assert_array_almost_equal(y, y_ct, decimal=64) 67 | assert y.dtype == y_ct.dtype 68 | 69 | 70 | @pytest.mark.parametrize("par", pars) 71 | def test_FDCT2D_3dsignal(par): 72 | """ 73 | Tests for FDCT2D operator for 3d signal. 74 | """ 75 | x = ( 76 | np.random.normal(0.0, 1.0, (par["nx"], par["ny"], par["nz"])) 77 | + np.random.normal(0.0, 1.0, (par["nx"], par["ny"], par["nz"])) * par["imag"] 78 | ) 79 | axes = [0, -1] 80 | FDCTop = FDCT2D( 81 | dims=(par["nx"], par["ny"], par["nz"]), axes=axes, dtype=par["dtype"] 82 | ) 83 | 84 | assert dottest( 85 | FDCTop, *FDCTop.shape, rtol=1e-12, complexflag=0 if par["imag"] == 0 else 3 86 | ) 87 | 88 | y = FDCTop * x.ravel() 89 | xinv = FDCTop.H * y 90 | np.testing.assert_array_almost_equal(xinv.reshape(*x.shape), x, decimal=14) 91 | 92 | 93 | @pytest.mark.parametrize("par", pars) 94 | def test_FDCT3D_3dsignal(par): 95 | """ 96 | Tests for FDCT3D operator for 3d signal. 97 | """ 98 | x = ( 99 | np.random.normal(0.0, 1.0, (par["nx"], par["ny"], par["nz"])) 100 | + np.random.normal(0.0, 1.0, (par["nx"], par["ny"], par["nz"])) * par["imag"] 101 | ) 102 | 103 | FDCTop = FDCT3D(dims=(par["nx"], par["ny"], par["nz"]), dtype=par["dtype"]) 104 | 105 | assert dottest( 106 | FDCTop, *FDCTop.shape, rtol=1e-12, complexflag=0 if par["imag"] == 0 else 3 107 | ) 108 | 109 | y = FDCTop * x.ravel() 110 | xinv = FDCTop.H * y 111 | np.testing.assert_array_almost_equal(xinv.reshape(*x.shape), x, decimal=14) 112 | 113 | if PYCT: 114 | FDCTct = ct.fdct3( 115 | x.shape, 116 | FDCTop.nbscales, 117 | FDCTop.nbangles_coarse, 118 | FDCTop.allcurvelets, 119 | cpx=False if par["imag"] == 0 else True, 120 | ) 121 | 122 | y_ct = np.array(FDCTct.fwd(x)).ravel() 123 | 124 | np.testing.assert_array_almost_equal(y, y_ct, decimal=64) 125 | assert y.dtype == y_ct.dtype 126 | 127 | 128 | @pytest.mark.parametrize("par", pars) 129 | def test_FDCT3D_4dsignal(par): 130 | """ 131 | Tests for FDCT3D operator for 4d signal. 132 | """ 133 | x = ( 134 | np.random.normal(0.0, 1.0, (par["nx"], 4, par["ny"], par["nz"])) 135 | + np.random.normal(0.0, 1.0, (par["nx"], 4, par["ny"], par["nz"])) * par["imag"] 136 | ) 137 | axes = [0, -2, -1] 138 | FDCTop = FDCT3D( 139 | dims=(par["nx"], 4, par["ny"], par["nz"]), 140 | axes=axes, 141 | dtype=par["dtype"], 142 | ) 143 | 144 | assert dottest( 145 | FDCTop, *FDCTop.shape, rtol=1e-12, complexflag=0 if par["imag"] == 0 else 3 146 | ) 147 | 148 | x = x.ravel() 149 | y = FDCTop * x 150 | xinv = FDCTop.H * y 151 | np.testing.assert_array_almost_equal(xinv, x, decimal=14) 152 | -------------------------------------------------------------------------------- /tests/test_fdct2d_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import curvelops.fdct2d_wrapper as ct 5 | 6 | pars = [ 7 | {"nx": 100, "ny": 50, "imag": 0, "dtype": "float64"}, 8 | {"nx": 100, "ny": 50, "imag": 1j, "dtype": "float64"}, 9 | {"nx": 256, "ny": 256, "imag": 0, "dtype": "float64"}, 10 | {"nx": 256, "ny": 256, "imag": 1j, "dtype": "float64"}, 11 | {"nx": 512, "ny": 256, "imag": 0, "dtype": "float64"}, 12 | {"nx": 512, "ny": 256, "imag": 1j, "dtype": "float64"}, 13 | {"nx": 512, "ny": 512, "imag": 0, "dtype": "float64"}, 14 | {"nx": 512, "ny": 512, "imag": 1j, "dtype": "complex128"}, 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("par", pars) 19 | def test_FDCT2D_wrapper_2dsignal(par): 20 | x = ( 21 | np.random.normal(0, 1, (par["nx"], par["ny"])) 22 | + np.random.normal(0, 1, (par["nx"], par["ny"])) * par["imag"] 23 | ) 24 | 25 | for nbscales in [4, 6, 8, 16]: 26 | for nbangles_coarse in [8, 16]: 27 | for ac in [True, False]: 28 | c = ct.fdct2d_forward_wrap(nbscales, nbangles_coarse, ac, x) 29 | xinv = ct.fdct2d_inverse_wrap( 30 | *x.shape, nbscales, nbangles_coarse, ac, c 31 | ) 32 | np.testing.assert_array_almost_equal(x, xinv, decimal=12) 33 | np.testing.assert_array_almost_equal( 34 | 2.0 * np.sum(np.abs(x - xinv)) / np.sum(np.abs(x + xinv)), 35 | 0.0, 36 | decimal=12, 37 | ) 38 | -------------------------------------------------------------------------------- /tests/test_fdct3d_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import curvelops.fdct3d_wrapper as ct 5 | 6 | pars = [ 7 | {"nx": 32, "ny": 32, "nz": 32, "imag": 0, "dtype": "float64"}, 8 | {"nx": 32, "ny": 32, "nz": 32, "imag": 1j, "dtype": "complex128"}, 9 | {"nx": 32, "ny": 32, "nz": 64, "imag": 0, "dtype": "float64"}, 10 | {"nx": 32, "ny": 32, "nz": 64, "imag": 1j, "dtype": "complex128"}, 11 | {"nx": 100, "ny": 50, "nz": 20, "imag": 0, "dtype": "float64"}, 12 | {"nx": 100, "ny": 50, "nz": 20, "imag": 1j, "dtype": "complex128"}, 13 | ] 14 | 15 | 16 | @pytest.mark.parametrize("par", pars) 17 | def test_FDCT3D_wrapper_3dsignal(par): 18 | x = ( 19 | np.random.normal(0, 1, (par["nx"], par["ny"], par["nz"])) 20 | + np.random.normal(0, 1, (par["nx"], par["ny"], par["nz"])) * par["imag"] 21 | ) 22 | for nbscales in [4, 6, 8]: 23 | for nbangles_coarse in [8, 16]: 24 | for ac in [True, False]: 25 | c = ct.fdct3d_forward_wrap(nbscales, nbangles_coarse, ac, x) 26 | xinv = ct.fdct3d_inverse_wrap( 27 | *x.shape, nbscales, nbangles_coarse, ac, c 28 | ) 29 | np.testing.assert_array_almost_equal(x, xinv, decimal=12) 30 | np.testing.assert_array_almost_equal( 31 | 2.0 * np.sum(np.abs(x - xinv)) / np.sum(np.abs(x + xinv)), 32 | 0.0, 33 | decimal=12, 34 | ) 35 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.random import randint 4 | 5 | from curvelops import FDCT 6 | from curvelops.utils import ( 7 | apply_along_wedges, 8 | array_split_nd, 9 | energy, 10 | energy_split, 11 | ndargmax, 12 | split_nd, 13 | ) 14 | 15 | pars = [ 16 | {"shape": (randint(1, 99),), "splits": (randint(1, 10),)}, 17 | { 18 | "shape": (randint(1, 99), randint(1, 99)), 19 | "splits": (randint(1, 10), randint(1, 10)), 20 | }, 21 | { 22 | "shape": (randint(1, 99), randint(1, 99), randint(1, 99)), 23 | "splits": (randint(1, 10), randint(1, 10), randint(1, 10)), 24 | }, 25 | ] 26 | 27 | pars_cl = [ 28 | {"shape": (randint(32, 129), randint(32, 129))}, 29 | {"shape": (randint(32, 129), randint(32, 129), randint(32, 129))}, 30 | ] 31 | 32 | 33 | def test_array_split_nd_simple(): 34 | x = np.outer(1 + np.arange(2), 2 + np.arange(3)) 35 | y = array_split_nd(x, 2, 3) 36 | assert len(x) == 2 37 | for subx in x: 38 | assert len(subx) == 3 39 | assert y[0][0] == 2 40 | assert y[0][1] == 3 41 | assert y[0][2] == 4 42 | assert y[1][0] == 4 43 | assert y[1][1] == 6 44 | assert y[1][2] == 8 45 | 46 | 47 | @pytest.mark.parametrize("par", pars) 48 | def test_array_split_nd_sizes(par): 49 | shape = par["shape"] 50 | splits = par["splits"] 51 | x = np.zeros(tuple(a * b for (a, b) in zip(shape, splits))) 52 | y = array_split_nd(x, *splits) 53 | for split in splits: 54 | assert split == len(y) 55 | y = y[0] 56 | assert y.shape == shape 57 | 58 | 59 | @pytest.mark.parametrize("par", pars) 60 | def test_split_nd_sizes(par): 61 | shape = par["shape"] 62 | splits = par["splits"] 63 | x = np.zeros(tuple(a * b for (a, b) in zip(shape, splits))) 64 | y = split_nd(x, *splits) 65 | for split in splits: 66 | assert split == len(y) 67 | y = y[0] 68 | assert y.shape == shape 69 | 70 | 71 | @pytest.mark.parametrize("par", pars_cl) 72 | def test_apply_along_wedges(par): 73 | shape = par["shape"] 74 | Cop = FDCT(shape, axes=list(range(len(shape)))) 75 | x = np.random.normal(0.0, 1.0, shape) + np.random.normal(0.0, 1.0, shape) * 1j 76 | # Create a vector of curvelet coeffs 77 | y = Cop @ x 78 | # Convert to structure 79 | y_struct = Cop.struct(Cop @ x) 80 | # Add 1 to each wedge 81 | y_struct_one = apply_along_wedges( 82 | y_struct, 83 | lambda c, w, s, na, ns: c + 1.0, 84 | ) 85 | # Convert back to vector 86 | y_one = Cop.vect(y_struct_one) 87 | 88 | # Ensure that each wedge of the modified wedge - original is 89 | # equal to 2d array of ones 90 | apply_along_wedges( 91 | Cop.struct(y_one - y), 92 | lambda c, w, s, na, ns: np.testing.assert_allclose(c, np.ones_like(c)), 93 | ) 94 | 95 | 96 | def test_energy(): 97 | ndim = np.random.randint(1, 10) 98 | shape = [np.random.randint(1, 10) for _ in range(ndim)] 99 | ones = np.ones(shape) 100 | e = energy(ones) 101 | np.testing.assert_allclose(1.0, e) 102 | 103 | 104 | def test_energy_split(): 105 | shape = [np.random.randint(1, 100), np.random.randint(1, 100)] 106 | rows, cols = np.random.randint(1, shape[0]), np.random.randint(1, shape[1]) 107 | ones = np.ones(shape) 108 | e = energy_split(ones, rows, cols) 109 | for row in range(rows): 110 | for col in range(cols): 111 | np.testing.assert_allclose(1.0, e[row][col]) 112 | 113 | 114 | def test_ndargmax(): 115 | ndim = np.random.randint(1, 10) 116 | shape = [np.random.randint(1, 10) for _ in range(ndim)] 117 | ary = np.zeros(shape) 118 | index = tuple([np.random.randint(0, shape[i]) for i in range(ndim)]) 119 | ary[index] = 1.0 120 | assert index == ndargmax(ary) 121 | --------------------------------------------------------------------------------