├── notebooks └── .gitkeep ├── docs ├── source │ ├── z_ref.rst │ ├── modopt_logo.png │ ├── output_17_0.png │ ├── output_21_0.png │ ├── output_5_0.png │ ├── output_9_0.png │ ├── cosmostat_logo.jpg │ ├── neurospin_logo.png │ ├── notebooks.rst │ ├── citing.rst │ ├── quickstart.rst │ ├── toc.rst │ ├── contributing.rst │ ├── index.rst │ ├── my_ref.bib │ ├── installation.rst │ ├── about.rst │ ├── dependencies.rst │ ├── plugin_example.rst │ ├── refs.bib │ └── conf.py ├── _templates │ ├── toc.rst_t │ ├── module.rst_t │ └── package.rst_t └── requirements.txt ├── examples ├── README.rst ├── __init__.py ├── conftest.py └── example_lasso_forward_backward.py ├── .coveragerc ├── tests ├── test_helpers │ ├── __init__.py │ └── utils.py ├── test_base.py └── test_algorithms.py ├── src └── modopt │ ├── plot │ ├── __init__.py │ └── cost_plot.py │ ├── interface │ ├── __init__.py │ ├── log.py │ └── errors.py │ ├── math │ ├── __init__.py │ ├── convolve.py │ ├── metrics.py │ └── stats.py │ ├── signal │ ├── __init__.py │ ├── validation.py │ ├── positivity.py │ ├── filter.py │ ├── noise.py │ └── svd.py │ ├── opt │ ├── __init__.py │ ├── linear │ │ ├── __init__.py │ │ └── base.py │ ├── algorithms │ │ ├── __init__.py │ │ └── primal_dual.py │ ├── reweight.py │ └── gradient.py │ ├── base │ ├── __init__.py │ ├── types.py │ ├── backend.py │ ├── np_adjust.py │ ├── observable.py │ └── transform.py │ └── __init__.py ├── .github ├── ISSUE_TEMPLATE │ ├── help-.md │ ├── installation-issue.md │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── style.yml │ ├── ci-build.yml │ └── cd-build.yml ├── LICENCE.txt ├── pyproject.toml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── README.md └── CONTRIBUTING.md /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/z_ref.rst: -------------------------------------------------------------------------------- 1 | 
References 2 | ========== 3 | 4 | .. bibliography:: refs.bib 5 | -------------------------------------------------------------------------------- /docs/source/modopt_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/modopt_logo.png -------------------------------------------------------------------------------- /docs/source/output_17_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_17_0.png -------------------------------------------------------------------------------- /docs/source/output_21_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_21_0.png -------------------------------------------------------------------------------- /docs/source/output_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_5_0.png -------------------------------------------------------------------------------- /docs/source/output_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_9_0.png -------------------------------------------------------------------------------- /docs/source/cosmostat_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/cosmostat_logo.jpg -------------------------------------------------------------------------------- /docs/source/neurospin_logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/neurospin_logo.png -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Examples 3 | ======== 4 | 5 | This is a collection of Python scripts demonstrating the use of ModOpt. 6 | -------------------------------------------------------------------------------- /docs/source/notebooks.rst: -------------------------------------------------------------------------------- 1 | Notebooks 2 | ========= 3 | 4 | List of Notebooks 5 | ----------------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit= 3 | *interface* 4 | *plot* 5 | 6 | [report] 7 | show_missing=True 8 | exclude_lines = 9 | pragma: no cover 10 | -------------------------------------------------------------------------------- /tests/test_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for tests.""" 2 | 3 | from .utils import Dummy, failparam, skipparam 4 | 5 | __all__ = ["Dummy", "failparam", "skipparam"] 6 | -------------------------------------------------------------------------------- /docs/_templates/toc.rst_t: -------------------------------------------------------------------------------- 1 | {{ header | heading }} 2 | 3 | .. toctree:: 4 | :maxdepth: {{ maxdepth }} 5 | {% for docname in docnames %} 6 | {{ docname }} 7 | {%- endfor %} 8 | 9 | -------------------------------------------------------------------------------- /docs/_templates/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- basename | e | heading(2) }} 3 | 4 | {% endif -%} 5 | .. 
automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} 9 | -------------------------------------------------------------------------------- /docs/source/citing.rst: -------------------------------------------------------------------------------- 1 | Citing this Package 2 | =================== 3 | 4 | We kindly request that any academic work making use of this package cite :cite:`farrens:2020`. 5 | 6 | .. bibliography:: my_ref.bib 7 | :style: alpha 8 | -------------------------------------------------------------------------------- /src/modopt/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """PLOTTING ROUTINES. 2 | 3 | This module contains submodules for plotting applications. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | __all__ = ["cost_plot"] 10 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter==1.0.0 2 | myst-parser==0.16.1 3 | nbsphinx==0.8.7 4 | nbsphinx-link==1.3.0 5 | numpydoc==1.1.0 6 | sphinx==4.3.1 7 | sphinxcontrib-bibtex==2.4.1 8 | sphinxawesome-theme==3.2.1 9 | sphinx-gallery==0.11.1 10 | -------------------------------------------------------------------------------- /docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart Tutorial 2 | =================== 3 | 4 | You can import the package as follows: 5 | 6 | .. code-block:: python 7 | 8 | import modopt 9 | 10 | .. note:: 11 | 12 | Examples coming soon! 13 | -------------------------------------------------------------------------------- /src/modopt/interface/__init__.py: -------------------------------------------------------------------------------- 1 | """INTERFACE ROUTINES. 2 | 3 | This module contains submodules for error handling, logging and IO interaction. 
4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | __all__ = ["errors", "log"] 10 | -------------------------------------------------------------------------------- /src/modopt/math/__init__.py: -------------------------------------------------------------------------------- 1 | """MATHEMATICS ROUTINES. 2 | 3 | This module contains submodules for mathematical applications. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | __all__ = ["convolve", "matrix", "stats", "metrics"] 10 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | """EXAMPLES. 2 | 3 | This module contains documented examples that demonstrate the usage of various 4 | ModOpt tools. 5 | 6 | These examples also serve as integration tests for various methods. 7 | 8 | :Author: Pierre-Antoine Comby 9 | 10 | """ 11 | -------------------------------------------------------------------------------- /src/modopt/signal/__init__.py: -------------------------------------------------------------------------------- 1 | """SIGNAL PROCESSING ROUTINES. 2 | 3 | This module contains submodules for signal processing. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | __all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"] 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help-.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Help! 3 | about: Users are welcome to ask any question relating to ModOpt and we will endeavour 4 | to reply as soon as possible. 5 | title: "[HELP]" 6 | labels: help wanted 7 | assignees: '' 8 | 9 | --- 10 | 11 | Let us know how we can help. 
12 | -------------------------------------------------------------------------------- /src/modopt/opt/__init__.py: -------------------------------------------------------------------------------- 1 | """OPTIMISATION PROBLEM MODULES. 2 | 3 | This module contains submodules for solving optimisation problems. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | __all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"] 10 | -------------------------------------------------------------------------------- /src/modopt/base/__init__.py: -------------------------------------------------------------------------------- 1 | """BASE ROUTINES. 2 | 3 | This module contains submodules for basic operations such as type 4 | transformations and adjustments to the default output of Numpy functions. 5 | 6 | :Author: Samuel Farrens 7 | 8 | """ 9 | 10 | __all__ = ["np_adjust", "transform", "types", "observable"] 11 | -------------------------------------------------------------------------------- /src/modopt/opt/linear/__init__.py: -------------------------------------------------------------------------------- 1 | """LINEAR OPERATORS. 2 | 3 | This module contains linear operator classes. 4 | 5 | :Author: Samuel Farrens 6 | :Author: Pierre-Antoine Comby 7 | """ 8 | 9 | from .base import LinearParent, Identity, MatrixOperator, LinearCombo 10 | 11 | from .wavelet import WaveletConvolve, WaveletTransform 12 | 13 | 14 | __all__ = [ 15 | "LinearParent", 16 | "Identity", 17 | "MatrixOperator", 18 | "LinearCombo", 19 | "WaveletConvolve", 20 | "WaveletTransform", 21 | ] 22 | -------------------------------------------------------------------------------- /docs/source/toc.rst: -------------------------------------------------------------------------------- 1 | .. Toctrees define sidebar contents 2 | 3 | .. toctree:: 4 | :hidden: 5 | :titlesonly: 6 | :caption: Getting Started 7 | 8 | about 9 | installation 10 | dependencies 11 | quickstart 12 | 13 | .. 
toctree:: 14 | :hidden: 15 | :titlesonly: 16 | :caption: API Documentation 17 | 18 | modopt 19 | z_ref 20 | 21 | .. toctree:: 22 | :hidden: 23 | :titlesonly: 24 | :caption: Examples 25 | 26 | plugin_example 27 | notebooks 28 | auto_examples/index 29 | 30 | .. toctree:: 31 | :hidden: 32 | :titlesonly: 33 | :caption: Guidelines 34 | 35 | contributing 36 | citing 37 | -------------------------------------------------------------------------------- /docs/source/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | Read our |link-to-contrib| 5 | for details on how to contribute to the development of this package. 6 | 7 | All contributors are kindly asked to adhere to the |link-to-conduct| 8 | at all times to ensure a safe and inclusive environment for everyone. 9 | 10 | .. |link-to-contrib| raw:: html 11 | 12 | Contribution Guidelines 13 | 14 | .. |link-to-conduct| raw:: html 15 | 16 | Code of Conduct 17 | -------------------------------------------------------------------------------- /src/modopt/__init__.py: -------------------------------------------------------------------------------- 1 | """MODOPT PACKAGE. 2 | 3 | ModOpt is a series of Modular Optimisation tools for solving inverse problems. 4 | 5 | """ 6 | 7 | from warnings import warn 8 | 9 | from importlib_metadata import version 10 | 11 | from modopt.base import np_adjust, transform, types, observable 12 | 13 | __all__ = ["np_adjust", "transform", "types", "observable"] 14 | 15 | try: 16 | _version = version("modopt") 17 | except Exception: # pragma: no cover 18 | _version = "Unkown" 19 | warn( 20 | "Could not extract package metadata. 
Make sure the package is " 21 | + "correctly installed.", 22 | stacklevel=1, 23 | ) 24 | 25 | __version__ = _version 26 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. modopt documentation master file, created by 2 | sphinx-quickstart on Mon Oct 24 16:46:22 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | ModOpt Documentation 7 | ====================== 8 | 9 | .. image:: modopt_logo.png 10 | :width: 100% 11 | :alt: ModOpt logo 12 | 13 | .. Include table of contents 14 | .. include:: toc.rst 15 | 16 | :Author: Samuel Farrens `(samuel.farrens@cea.fr) `_ 17 | :Version: 1.6.0 18 | :Release Date: 17/12/2021 19 | :Repository: |link-to-repo| 20 | 21 | .. |link-to-repo| raw:: html 22 | 23 | https://github.com/CEA-COSMIC/ModOpt 25 | 26 | ModOpt is a series of **Modular Optimisation** tools for solving inverse problems. 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/installation-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Installation issue 3 | about: If you encounter difficulties installing ModOpt be sure to re-read the installation 4 | instructions provided before submitting an issue. 5 | title: "[INSTALLATION ERROR]" 6 | labels: installation 7 | assignees: '' 8 | 9 | --- 10 | 11 | **System setup** 12 | OS: [e.g] macOS v10.14.1 13 | Python version: [e.g.] v3.6.7 14 | Python environment (if any): [e.g.] conda v4.5.11 15 | 16 | **Describe the bug** 17 | A clear and concise description of what the problem is. 18 | 19 | **To Reproduce** 20 | List the exact steps you followed that lead to the problem you encountered so that we can attempt to recreate the conditions. 
21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Are you planning to submit a Pull Request?** 26 | - [ ] Yes 27 | - [X] No 28 | -------------------------------------------------------------------------------- /tests/test_helpers/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some helper functions for the test parametrization. 3 | 4 | They should be used inside ``@pytest.mark.parametrize`` call. 5 | 6 | :Author: Pierre-Antoine Comby 7 | """ 8 | 9 | import pytest 10 | 11 | 12 | def failparam(*args, raises=None): 13 | """Return a pytest parameterization that should raise an error.""" 14 | if not issubclass(raises, Exception): 15 | raise ValueError("raises should be an expected Exception.") 16 | return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)]) 17 | 18 | 19 | def skipparam(*args, cond=True, reason=""): 20 | """Return a pytest parameterization that should be skipped if cond is true.""" 21 | return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)]) 22 | 23 | 24 | class Dummy: 25 | """Dummy Class.""" 26 | 27 | pass 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: If you believe ModOpt could be improved with the addition of extra functionality 4 | or features feel free to let us know. 5 | title: "[NEW FEATURE]" 6 | labels: enhancement 7 | assignees: '' 8 | 9 | --- 10 | 11 | **Is your feature request related to a problem? Please describe.** 12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 
13 | 14 | **Describe the solution you'd like** 15 | In order to increase your chances of having a feature included, be sure to be as clear and specific as possible as to the properties this feature should have. 16 | 17 | **Describe alternatives you've considered** 18 | A clear and concise description of any alternative solutions or features you've considered. 19 | 20 | **Are you planning to submit a Pull Request?** 21 | - [ ] Yes 22 | - [X] No 23 | -------------------------------------------------------------------------------- /docs/source/my_ref.bib: -------------------------------------------------------------------------------- 1 | @ARTICLE{farrens:2020, 2 | author = {{Farrens}, S. and {Grigis}, A. and {El Gueddari}, L. and {Ramzi}, Z. and {Chaithya}, G.~R. and {Starck}, S. and {Sarthou}, B. and {Cherkaoui}, H. and {Ciuciu}, P. and {Starck}, J. -L.}, 3 | title = "{PySAP: Python Sparse Data Analysis Package for multidisciplinary image processing}", 4 | journal = {Astronomy and Computing}, 5 | keywords = {Image processing, Convex optimisation, Reconstruction, Open-source software, Astrophysics - Instrumentation and Methods for Astrophysics}, 6 | year = 2020, 7 | month = jul, 8 | volume = {32}, 9 | eid = {100402}, 10 | pages = {100402}, 11 | doi = {10.1016/j.ascom.2020.100402}, 12 | archivePrefix = {arXiv}, 13 | eprint = {1910.08465}, 14 | primaryClass = {astro-ph.IM}, 15 | adsurl = {https://ui.adsabs.harvard.edu/abs/2020A&C....3200402F}, 16 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 17 | } 18 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: Style checking 2 | 3 | on: 4 | push: 5 | branches: [ "master", "main", "develop" ] 6 | pull_request: 7 | branches: [ "master", "main", "develop" ] 8 | 9 | workflow_dispatch: 10 | 11 | env: 12 | PYTHON_VERSION: "3.10" 13 | 14 | jobs: 15 | linter-check: 16 
| runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | - name: Set up Python ${{ env.PYTHON_VERSION }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ env.PYTHON_VERSION }} 24 | cache: pip 25 | 26 | - name: Install Python deps 27 | shell: bash 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install -e .[test,dev] 31 | 32 | - name: Black Check 33 | shell: bash 34 | run: black . --diff --color --check 35 | 36 | - name: ruff Check 37 | shell: bash 38 | run: ruff check 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: If you discover a bug while using ModOpt please provide the following information 4 | to help us resolve the issue. 5 | title: "[BUG]" 6 | labels: bug 7 | assignees: '' 8 | 9 | --- 10 | 11 | **System setup** 12 | OS: [e.g] macOS v10.14.1 13 | Python version: [e.g.] v3.6.7 14 | Python environment (if any): [e.g.] conda v4.5.11 15 | 16 | **Describe the bug** 17 | A clear and concise description of what the bug is. 18 | 19 | **To Reproduce** 20 | List the exact steps you followed that lead to the bug you encountered so that we can attempt to recreate the conditions. 21 | 22 | **Expected behavior** 23 | A clear and concise description of what you expected to happen. 24 | 25 | **Screenshots** 26 | If applicable, add screenshots to help explain your problem. 27 | 28 | **Module and lines involved** 29 | If you are aware of the source of the bug we would very much appreciate if you could provide the module(s) and line number(s) affected. This will enable us to more rapidly fix the problem. 
30 | 31 | **Are you planning to submit a Pull Request?** 32 | - [ ] Yes 33 | - [X] No 34 | -------------------------------------------------------------------------------- /LICENCE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Samuel Farrens 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/_templates/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. 
toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- [pkgname, "namespace"] | join(" ") | e | heading }} 18 | {% else %} 19 | {{- pkgname | e | heading }} 20 | 21 | {% endif %} 22 | 23 | {%- if modulefirst and not is_namespace %} 24 | {{ automodule(pkgname, automodule_options) }} 25 | {% endif %} 26 | 27 | {%- if subpackages %} 28 | Subpackages 29 | ----------- 30 | 31 | {{ toctree(subpackages) }} 32 | {% endif %} 33 | 34 | {%- if submodules %} 35 | Submodules 36 | ---------- 37 | {% if separatemodules %} 38 | {{ toctree(submodules) }} 39 | {% else %} 40 | {%- for submodule in submodules %} 41 | {% if show_headings %} 42 | {{- submodule | e | heading(2) }} 43 | {% endif %} 44 | {{ automodule(submodule, automodule_options) }} 45 | {% endfor %} 46 | {%- endif %} 47 | {%- endif %} 48 | 49 | {%- if not modulefirst and not is_namespace %} 50 | Module contents 51 | --------------- 52 | 53 | {{ automodule(pkgname, automodule_options) }} 54 | {% endif %} 55 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | .. note:: 5 | 6 | ModOpt is automatically installed by |link-to-pysap|. The following steps are 7 | intended for those that wish to use ModOpt independently of PySAP or those 8 | aiming to contribute to the development of the package. 9 | 10 | .. |link-to-pysap| raw:: html 11 | 12 | PySAP 13 | 14 | Users 15 | ----- 16 | 17 | You can install the latest release of ModOpt from `PyPi `_ 18 | as follows: 19 | 20 | .. code-block:: bash 21 | 22 | pip install modopt 23 | 24 | 25 | Alternatively clone the repository and build the package locally as follows: 26 | 27 | .. code-block:: bash 28 | 29 | pip install . 
30 | 31 | 32 | Developers 33 | ---------- 34 | 35 | Developers are recommended to clone the repository and build the package locally 36 | in development mode with testing and documentation packages as follows: 37 | 38 | .. code-block:: bash 39 | 40 | pip install -e ".[dev,test,doc]" 41 | 42 | Troubleshooting 43 | --------------- 44 | If you encounter any difficulties installing ModOpt we recommend that you 45 | open a |link-to-issue| and we will do our best to help you. 46 | 47 | .. |link-to-issue| raw:: html 48 | 49 | new issue 51 | -------------------------------------------------------------------------------- /src/modopt/plot/cost_plot.py: -------------------------------------------------------------------------------- 1 | """PLOTTING ROUTINES. 2 | 3 | This module contains methods for making plots. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | import numpy as np 10 | 11 | try: 12 | import matplotlib.pyplot as plt 13 | except ImportError: # pragma: no cover 14 | import_fail = True 15 | else: 16 | import_fail = False 17 | 18 | 19 | def plotCost(cost_list, output=None): 20 | """Plot cost function. 21 | 22 | Plot the final cost function. 
23 | 24 | Parameters 25 | ---------- 26 | cost_list : list 27 | List of cost function values 28 | output : str, optional 29 | Output file name (default is ``None``) 30 | 31 | Raises 32 | ------ 33 | ImportError 34 | If Matplotlib package not found 35 | 36 | """ 37 | if import_fail: 38 | raise ImportError("Matplotlib package not found") 39 | 40 | else: 41 | if isinstance(output, type(None)): 42 | file_name = "cost_function.png" 43 | else: 44 | file_name = f"{output}_cost_function.png" 45 | 46 | plt.figure() 47 | plt.plot(np.log10(cost_list), "r-") 48 | plt.title("Cost Function") 49 | plt.xlabel("Iteration") 50 | plt.ylabel(r"$\log_{10}$ Cost") 51 | plt.savefig(file_name) 52 | plt.close() 53 | 54 | print(" - Saving cost function data to:", file_name) 55 | -------------------------------------------------------------------------------- /examples/conftest.py: -------------------------------------------------------------------------------- 1 | """TEST CONFIGURATION. 2 | 3 | This module contains methods for configuring the testing of the example 4 | scripts. 5 | 6 | :Author: Pierre-Antoine Comby 7 | 8 | Notes 9 | ----- 10 | Based on: 11 | https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test 12 | 13 | """ 14 | 15 | from pathlib import Path 16 | import runpy 17 | import pytest 18 | 19 | 20 | def pytest_collect_file(path, parent): 21 | """Pytest hook. 22 | 23 | Create a collector for the given path, or None if not relevant. 24 | The new node needs to have the specified parent as parent. 
25 | """ 26 | p = Path(path) 27 | if p.suffix == ".py" and "example" in p.name: 28 | return Script.from_parent(parent, path=p, name=p.name) 29 | 30 | 31 | class Script(pytest.File): 32 | """Script files collected by pytest.""" 33 | 34 | def collect(self): 35 | """Collect the script as its own item.""" 36 | yield ScriptItem.from_parent(self, name=self.name) 37 | 38 | 39 | class ScriptItem(pytest.Item): 40 | """Item script collected by pytest.""" 41 | 42 | def runtest(self): 43 | """Run the script as a test.""" 44 | runpy.run_path(str(self.path)) 45 | 46 | def repr_failure(self, excinfo): 47 | """Return only the error traceback of the script.""" 48 | excinfo.traceback = excinfo.traceback.cut(path=self.path) 49 | return super().repr_failure(excinfo) 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name="modopt" 3 | description = 'Modular Optimisation tools for solving inverse problems.' 
4 | version = "1.7.2" 5 | requires-python= ">=3.8" 6 | 7 | authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"}, 8 | {name="Chaithya GR", email="chaithyagr@gmail.com"}, 9 | {name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"}, 10 | {name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"} 11 | ] 12 | readme="README.md" 13 | license={file="LICENCE.txt"} 14 | 15 | dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"] 16 | 17 | [project.optional-dependencies] 18 | gpu=["torch", "ptwt"] 19 | doc=["myst-parser", 20 | "nbsphinx", 21 | "nbsphinx-link", 22 | "sphinx-gallery", 23 | "numpydoc", 24 | "sphinxawesome-theme", 25 | "sphinxcontrib-bibtex"] 26 | dev=["black", "ruff"] 27 | test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"] 28 | 29 | [build-system] 30 | requires=["setuptools", "setuptools-scm[toml]", "wheel"] 31 | 32 | [tool.coverage.run] 33 | omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"] 34 | 35 | [tool.coverage.report] 36 | precision = 2 37 | exclude_lines = ["pragma: no cover", "raise NotImplementedError"] 38 | 39 | [tool.black] 40 | 41 | 42 | [tool.ruff] 43 | exclude = ["examples", "docs"] 44 | [tool.ruff.lint] 45 | select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"] 46 | 47 | ignore = ["F401"] # we like the try: import ... except: ... 
48 | 49 | [tool.ruff.lint.pydocstyle] 50 | convention="numpy" 51 | 52 | [tool.isort] 53 | profile="black" 54 | 55 | [tool.pytest.ini_options] 56 | minversion = "6.0" 57 | norecursedirs = ["tests/test_helpers"] 58 | addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"] 59 | -------------------------------------------------------------------------------- /docs/source/about.rst: -------------------------------------------------------------------------------- 1 | About 2 | ===== 3 | 4 | ModOpt was developed as part of |link-to-cosmic|, a multi-disciplinary collaboration 5 | between |link-to-neurospin|, experts in biomedical imaging, and |link-to-cosmostat|, 6 | experts in astrophysical image processing. The package was 7 | designed to provide the backend optimisation algorithms for 8 | |link-to-pysap| :cite:`farrens:2020`, but also serves as a stand-alone library 9 | of inverse problem solving tools. While PySAP aims to provide 10 | application-specific tools for solving complex imaging problems, ModOpt can in 11 | principle be applied to any linear inverse problem. 12 | 13 | Contributors 14 | ------------ 15 | 16 | You can find a |link-to-contributors|. 17 | 18 | |CS_LOGO| |NS_LOGO| 19 | 20 | .. |link-to-cosmic| raw:: html 21 | 22 | COSMIC 23 | 24 | .. |link-to-neurospin| raw:: html 25 | 26 | NeuroSpin 28 | 29 | .. |link-to-cosmostat| raw:: html 30 | 31 | CosmoStat 33 | 34 | .. |link-to-pysap| raw:: html 35 | 36 | PySAP 37 | 38 | .. |link-to-contributors| raw:: html 39 | 40 | list of ModOpt contributors here 42 | 43 | .. |CS_LOGO| image:: cosmostat_logo.jpg 44 | :width: 45% 45 | :alt: CosmoStat Logo 46 | :target: http://www.cosmostat.org/ 47 | 48 | .. 
|NS_LOGO| image:: neurospin_logo.png 49 | :width: 45% 50 | :alt: NeuroSpin Logo 51 | :target: https://joliot.cea.fr/drf/joliot/en/Pages/research_entities/NeuroSpin.aspx 52 | -------------------------------------------------------------------------------- /.github/workflows/ci-build.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - develop 7 | - master 8 | - main 9 | 10 | jobs: 11 | test-full: 12 | name: Full Test Suite 13 | runs-on: ${{ matrix.os }} 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest, macos-latest] 19 | python-version: ["3.8", "3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: pip 27 | 28 | - name: Install Dependencies 29 | shell: bash -l {0} 30 | run: | 31 | python --version 32 | python -m pip install --upgrade pip 33 | python -m pip install .[test] 34 | python -m pip install astropy scikit-image scikit-learn matplotlib 35 | python -m pip install tensorflow>=2.4.1 torch 36 | 37 | - name: Run Tests 38 | shell: bash -l {0} 39 | run: | 40 | pytest -n 2 41 | 42 | - name: Save Test Results 43 | if: always() 44 | uses: actions/upload-artifact@v4 45 | with: 46 | name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }} 47 | path: coverage.xml 48 | 49 | - name: Check API Documentation build 50 | shell: bash -l {0} 51 | run: | 52 | apt install pandoc 53 | pip install .[doc] ipykernel 54 | sphinx-apidoc -t docs/_templates -feTMo docs/source modopt 55 | sphinx-build -b doctest -E docs/source docs/_build 56 | 57 | - name: Upload Coverage to Codecov 58 | uses: codecov/codecov-action@v1 59 | with: 60 | token: ${{ secrets.CODECOV_TOKEN }} 61 | file: coverage.xml 62 | flags: unittests 63 | 64 | -------------------------------------------------------------------------------- 
/src/modopt/opt/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | r"""OPTIMISATION ALGORITHMS. 2 | 3 | This module contains class implementations of various optimisation algorithms. 4 | 5 | :Authors: 6 | 7 | * Samuel Farrens , 8 | * Zaccharie Ramzi , 9 | * Pierre-Antoine Comby 10 | 11 | :Notes: 12 | 13 | Input classes must have the following properties: 14 | 15 | * **Gradient Operators** 16 | 17 | Must have the following methods: 18 | 19 | * ``get_grad()`` - calculate the gradient 20 | 21 | Must have the following variables: 22 | 23 | * ``grad`` - the gradient 24 | 25 | * **Linear Operators** 26 | 27 | Must have the following methods: 28 | 29 | * ``op()`` - operator 30 | * ``adj_op()`` - adjoint operator 31 | 32 | * **Proximity Operators** 33 | 34 | Must have the following methods: 35 | 36 | * ``op()`` - operator 37 | 38 | The following notation is used to implement the algorithms: 39 | 40 | * ``x_old`` is used in place of :math:`x_{n}`. 41 | * ``x_new`` is used in place of :math:`x_{n+1}`. 42 | * ``x_prox`` is used in place of :math:`\tilde{x}_{n+1}`. 43 | * ``x_temp`` is used for intermediate operations.
44 | 45 | """ 46 | 47 | from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM 48 | from .primal_dual import Condat 49 | from .gradient_descent import ( 50 | ADAMGradOpt, 51 | AdaGenericGradOpt, 52 | GenericGradOpt, 53 | MomentumGradOpt, 54 | RMSpropGradOpt, 55 | SAGAOptGradOpt, 56 | VanillaGenericGradOpt, 57 | ) 58 | from .admm import ADMM, FastADMM 59 | 60 | __all__ = [ 61 | "FISTA", 62 | "ForwardBackward", 63 | "GenForwardBackward", 64 | "POGM", 65 | "Condat", 66 | "ADAMGradOpt", 67 | "AdaGenericGradOpt", 68 | "GenericGradOpt", 69 | "MomentumGradOpt", 70 | "RMSpropGradOpt", 71 | "SAGAOptGradOpt", 72 | "VanillaGenericGradOpt", 73 | "ADMM", 74 | "FastADMM", 75 | ] 76 | -------------------------------------------------------------------------------- /.github/workflows/cd-build.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - main 8 | 9 | jobs: 10 | 11 | coverage: 12 | name: Deploy Coverage Results 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | cache: pip 23 | 24 | - name: Install dependencies 25 | shell: bash -l {0} 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install twine 29 | python -m pip install .[doc,test] 30 | 31 | - name: Run Tests 32 | shell: bash -l {0} 33 | run: | 34 | pytest 35 | 36 | - name: Check distribution 37 | shell: bash -l {0} 38 | run: | 39 | twine check dist/* 40 | 41 | - name: Upload coverage to Codecov 42 | uses: codecov/codecov-action@v1 43 | with: 44 | token: ${{ secrets.CODECOV_TOKEN }} 45 | file: coverage.xml 46 | flags: unittests 47 | 48 | api: 49 | name: Deploy API Documentation 50 | needs: coverage 51 | runs-on: ubuntu-latest 52 | if: success() 53 | 54 | steps: 55 | - name: Checkout 56 | uses: actions/checkout@v4 57 | 58 | 59 | - name: Install 
def set_up_log(filename, verbose=True):
    """Set up log.

    Create a logger that writes every message (``DEBUG`` and above) to the
    file ``<filename>.log``.

    Parameters
    ----------
    filename : str
        Log file name without extension; ``".log"`` is appended
    verbose : bool
        Option for verbose output (default is ``True``)

    Returns
    -------
    logging.Logger
        Logging instance, registered under the full log file name

    Notes
    -----
    The log file is opened in ``"w"`` mode, so an existing file with the same
    name is overwritten. Capturing of Python warnings by the logging system
    is switched on as a side effect.

    """
    # Add file extension.
    filename = f"{filename}.log"

    if verbose:
        print("Preparing log file:", filename)

    # Capture warnings.
    logging.captureWarnings(True)

    # Set output format.
    formatter = logging.Formatter(
        fmt="%(asctime)s %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S",
    )

    # Create file handler.
    fh = logging.FileHandler(filename=filename, mode="w")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # Create log.
    log = logging.getLogger(filename)
    log.setLevel(logging.DEBUG)
    log.addHandler(fh)

    # Send opening message.
    log.info("The log file has been set-up.")

    return log


def close_log(log, verbose=True):
    """Close log.

    Write a closing message to an active logging.Logger instance, then detach
    and close all of its handlers.

    Parameters
    ----------
    log : logging.Logger
        Logging instance
    verbose : bool
        Option for verbose output (default is ``True``)

    """
    if verbose:
        print("Closing log file:", log.name)

    # Send closing message.
    log.info("The log file has been closed.")

    # Iterate over a snapshot: removeHandler() mutates ``log.handlers``, and
    # looping over the live list would skip every second handler.
    for log_handler in list(log.handlers):
        log.removeHandler(log_handler)
        # Release the resource held by the handler (e.g. the open log file).
        log_handler.close()
class cwbReweight:
    """Candes, Wakin and Boyd reweighting class.

    Iterative reweighting following the scheme of :cite:`candes2007`.

    Parameters
    ----------
    weights : numpy.ndarray
        Array of weights
    thresh_factor : float
        Threshold factor (default is ``1.0``)
    verbose : bool
        Print the reweighting iteration number on every call to
        ``reweight`` (default is ``False``)

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.opt.reweight import cwbReweight
    >>> a = np.arange(9).reshape(3, 3).astype(float) + 1
    >>> rw = cwbReweight(a)
    >>> rw.weights
    array([[1., 2., 3.],
           [4., 5., 6.],
           [7., 8., 9.]])
    >>> rw.reweight(a)
    >>> rw.weights
    array([[0.5, 1. , 1.5],
           [2. , 2.5, 3. ],
           [3.5, 4. , 4.5]])

    """

    def __init__(self, weights, thresh_factor=1.0, verbose=False):
        # Current weights, updated in place by every reweight() call.
        self.weights = check_float(weights)
        # Snapshot of the initial weights; used to build the threshold.
        self.original_weights = np.copy(self.weights)
        self.thresh_factor = check_float(thresh_factor)
        self.verbose = verbose
        # Counter of reweighting iterations, only used for verbose output.
        self._rw_num = 1

    def reweight(self, input_data):
        r"""Reweight.

        Update the weights in place from the current estimate, following
        section 4 of :cite:`candes2007`.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data

        Raises
        ------
        ValueError
            For invalid input shape

        Notes
        -----
        The update applied is:

        .. math::

            w \leftarrow w \left( \frac{1}{1 + \frac{|x|}{n w_0}} \right)

        where :math:`w` are the current weights, :math:`w_0` the weights
        given at construction, :math:`x` is the ``input_data`` and
        :math:`n` is the ``thresh_factor``.

        """
        if self.verbose:
            print(f" - Reweighting: {self._rw_num}")

        self._rw_num += 1

        data = check_float(input_data)

        if data.shape != self.weights.shape:
            raise ValueError(
                "Input data must have the same shape as the initial weights.",
            )

        threshold = self.thresh_factor * self.original_weights
        scaling = np.array(1.0 / (1.0 + np.abs(data) / threshold))

        self.weights *= scaling
|link-to-sklearn| raw:: html 72 | 73 | Scikit-Learn 75 | 76 | .. |link-to-termcolor| raw:: html 77 | 78 | Termcolor 80 | 81 | For GPU compliance the following packages can also be installed: 82 | 83 | * |link-to-cupy| 84 | * |link-to-torch| 85 | * |link-to-tf| 86 | 87 | .. |link-to-cupy| raw:: html 88 | 89 | CuPy 91 | 92 | .. |link-to-torch| raw:: html 93 | 94 | Torch 96 | 97 | .. |link-to-tf| raw:: html 98 | 99 | TensorFlow 101 | 102 | .. note:: 103 | 104 | Note that none of these are required for running on a CPU. 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | pytest.xml 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | docs/source/fortuna.* 75 | docs/source/scripts.* 76 | docs/source/auto_examples/ 77 | docs/source/*.nblink 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
def pos_thresh(input_data):
    """Positive Threshold.

    Keep only positive coefficients from input data; all other coefficients
    are multiplied by zero.

    Parameters
    ----------
    input_data : int, float or numpy.ndarray
        Input data

    Returns
    -------
    int, float, or numpy.ndarray
        Positive coefficients

    """
    return input_data * (input_data > 0)


def pos_recursive(input_data):
    """Positive Recursive.

    Run pos_thresh on input array or recursively for ragged nested arrays.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data; arrays with ``object`` dtype are processed element by
        element

    Returns
    -------
    numpy.ndarray
        Positive coefficients

    """
    if input_data.dtype == "O":
        # Object arrays hold the ragged sub-arrays: threshold each of them
        # and pack the results back into an object array.
        return np.array(
            [pos_recursive(elem) for elem in input_data],
            dtype="object",
        )

    return pos_thresh(input_data)


def positive(input_data, ragged=False):
    """Positivity operator.

    This method preserves only the positive coefficients of the input data, all
    negative coefficients are set to zero.

    Parameters
    ----------
    input_data : int, float, list, tuple or numpy.ndarray
        Input data; NumPy integer and floating scalars are accepted as well
    ragged : bool, optional
        Specify if the input_data is a ragged nested array
        (default is ``False``)

    Returns
    -------
    int or float, or numpy.ndarray
        Positive coefficients

    Raises
    ------
    TypeError
        For invalid input type.

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.positivity import positive
    >>> a = np.arange(9).reshape(3, 3) - 5
    >>> a
    array([[-5, -4, -3],
           [-2, -1,  0],
           [ 1,  2,  3]])
    >>> positive(a)
    array([[0, 0, 0],
           [0, 0, 0],
           [1, 2, 3]])

    """
    # NumPy scalars (e.g. the result of indexing an array) are handled like
    # the built-in numeric types instead of being rejected.
    scalar_types = (int, float, np.integer, np.floating)

    if not isinstance(input_data, (*scalar_types, list, tuple, np.ndarray)):
        raise TypeError(
            "Invalid data type, input must be `int`, `float`, `list`, "
            + "`tuple` or `np.ndarray`.",
        )

    if isinstance(input_data, scalar_types):
        return pos_thresh(input_data)

    if ragged:
        # An explicit object dtype keeps sub-arrays of different length.
        input_data = np.array(input_data, dtype="object")

    else:
        input_data = np.array(input_data)

    return pos_recursive(input_data)
def gaussian_filter(data_point, sigma, norm=True):
    """Gaussian filter.

    Evaluate a zero-centred Gaussian of width ``sigma`` at ``data_point``.

    Parameters
    ----------
    data_point : float
        Input data point
    sigma : float
        Standard deviation (filter scale)
    norm : bool
        Option to return normalised data (default is ``True``)

    Returns
    -------
    float
        Gaussian filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import gaussian_filter
    >>> gaussian_filter(1, 1)
    0.24197072451914337

    >>> gaussian_filter(1, 1, False)
    0.6065306597126334

    """
    data_point = check_float(data_point)
    sigma = check_float(sigma)

    gauss = np.exp(-0.5 * (data_point / sigma) ** 2)

    if not norm:
        return gauss

    # Normalise by sqrt(2 * pi) * sigma.
    return gauss / (np.sqrt(2 * np.pi) * sigma)


def mex_hat(data_point, sigma):
    """Mexican hat.

    Evaluate a Mexican hat (or Ricker) wavelet of scale ``sigma`` at
    ``data_point``.

    Parameters
    ----------
    data_point : float
        Input data point
    sigma : float
        Standard deviation (filter scale)

    Returns
    -------
    float
        Mexican hat filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import mex_hat
    >>> round(mex_hat(2, 1), 15)
    -0.352139052257134

    """
    data_point = check_float(data_point)
    sigma = check_float(sigma)

    # Squared ratio of the data point to the filter scale.
    ratio_sq = (data_point / sigma) ** 2
    amplitude = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
    envelope = np.exp(-0.5 * ratio_sq)

    return amplitude * (1 - ratio_sq) * envelope


def mex_hat_dir(data_gauss, data_mex, sigma):
    """Directional Mexican hat.

    Evaluate a directional Mexican hat (or Ricker) wavelet: the Mexican hat
    value at ``data_mex`` scaled by ``-0.5 * (data_gauss / sigma) ** 2``.

    Parameters
    ----------
    data_gauss : float
        Input data point for Gaussian
    data_mex : float
        Input data point for Mexican hat
    sigma : float
        Standard deviation (filter scale)

    Returns
    -------
    float
        Directional Mexican hat filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import mex_hat_dir
    >>> round(mex_hat_dir(1, 2, 1), 16)
    0.1760695261285668

    """
    data_gauss = check_float(data_gauss)
    sigma = check_float(sigma)

    gauss_sq = (data_gauss / sigma) ** 2

    # data_mex is converted to float inside mex_hat.
    return -0.5 * gauss_sq * mex_hat(data_mex, sigma)
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
def warn(warn_string, log=None):
    """Warning.

    This method creates custom warning messages.

    Parameters
    ----------
    warn_string : str
        Warning message string
    log : logging.Logger, optional
        Logging structure instance (default is ``None``)

    Notes
    -----
    The message is always written to ``stderr``. When ``log`` is provided, a
    Python warning is issued as well; the ``log`` object itself is not
    written to by this function.

    """
    # Colour the label only if termcolor could be imported.
    warn_txt = "WARNING" if import_fail else colored("WARNING", "yellow")

    # Print warning to stderr.
    sys.stderr.write(f"{warn_txt}: {warn_string}\n")

    # Check if a logging structure is provided.
    if log is not None:
        warnings.warn(warn_string, stacklevel=2)


def catch_error(exception, log=None):
    """Catch error.

    This method catches errors and prints them to the terminal. It also saves
    the errors to a log if provided.

    Parameters
    ----------
    exception : str
        Exception message string
    log : logging.Logger, optional
        Logging structure instance (default is ``None``)

    """
    # Colour the label only if termcolor could be imported.
    err_txt = "ERROR" if import_fail else colored("ERROR", "red")

    # Print exception to stderr.
    sys.stderr.write(f"{err_txt}: {exception}\n")

    # Check if a logging structure is provided.
    if log is not None:
        # The log entry always uses the plain (uncoloured) label.
        log.exception(f"ERROR: {exception}\n")


def file_name_error(file_name):
    """File name error.

    This method checks if the input file name is valid.

    Parameters
    ----------
    file_name : str
        File name string

    Raises
    ------
    OSError
        If file name not specified or file not found

    """
    # An empty name, or one starting with "-" (i.e. an option flag rather
    # than a path), counts as "not specified".
    if file_name == "" or file_name[0][0] == "-":
        raise OSError("Input file name not specified.")

    if not os.path.isfile(file_name):
        raise OSError(f"Input file name {file_name} not found!")


def is_exe(fpath):
    """Is Executable.

    Check if the input file path corresponds to an executable on the system.

    Parameters
    ----------
    fpath : str
        File path

    Returns
    -------
    bool
        True if the path is an existing file with execute permission

    """
    return os.path.isfile(fpath) and os.access(fpath, os.X_OK)


def is_executable(exe_name):
    """Check if Input is Executable.

    This method checks if the input executable exists.

    Parameters
    ----------
    exe_name : str
        Executable name, either a path or a bare name looked up on the
        executable search path

    Raises
    ------
    TypeError
        For invalid input type
    OSError
        For invalid system executable

    """
    if not isinstance(exe_name, str):
        raise TypeError("Executable name must be a string.")

    if os.path.dirname(exe_name):
        # A directory component was given: test that exact path.
        res = is_exe(exe_name)

    else:
        # os.get_exec_path() splits PATH on os.pathsep and falls back to
        # os.defpath when PATH is not set, instead of raising KeyError.
        res = any(
            is_exe(os.path.join(path, exe_name))
            for path in os.get_exec_path()
        )

    if not res:
        message = "{0} does not appear to be a valid executable on this system."
        raise OSError(message.format(exe_name))
[![Updates](https://pyup.io/repos/github/CEA-COSMIC/modopt/shield.svg)](https://pyup.io/repos/github/CEA-COSMIC/ModOpt/) | | 12 | 13 | ModOpt is a series of **Modular Optimisation** tools for solving inverse problems. 14 | 15 | See [documentation](https://CEA-COSMIC.github.io/ModOpt/) for more details. 16 | 17 | ## Installation 18 | 19 | To install using `pip` run the following command: 20 | 21 | ```bash 22 | $ pip install modopt 23 | ``` 24 | 25 | To clone the ModOpt repository from GitHub run the following command: 26 | 27 | ```bash 28 | $ git clone https://github.com/CEA-COSMIC/ModOpt.git 29 | ``` 30 | 31 | ## Dependencies 32 | 33 | All packages required by ModOpt should be installed automatically. Optional packages, however, will need to be installed manually. 34 | 35 | ### Required Packages 36 | 37 | In order to run the code in this repository the following packages must be 38 | installed: 39 | 40 | * [Python](https://www.python.org/) [> 3.7] 41 | * [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0] 42 | * [Numpy](http://www.numpy.org/) [==1.19.5] 43 | * [Scipy](http://www.scipy.org/) [==1.5.4] 44 | * [tqdm](https://tqdm.github.io/) [>=4.64.0] 45 | 46 | ### Optional Packages 47 | 48 | The following packages can optionally be installed to add extra functionality: 49 | 50 | * [Astropy](http://www.astropy.org/) 51 | * [Matplotlib](http://matplotlib.org/) 52 | * [Scikit-Image](https://scikit-image.org/) 53 | * [Scikit-Learn](https://scikit-learn.org/) 54 | * [Termcolor](https://pypi.python.org/pypi/termcolor) 55 | 56 | For (partial) GPU compliance the following packages can also be installed. 57 | Note that none of these are required for running on a CPU. 
def check_float(input_obj):
    """Check Float.

    Return the input as floating point data, converting it when required.

    Parameters
    ----------
    input_obj : int, float, list, tuple or numpy.ndarray
        Input value

    Returns
    -------
    float or numpy.ndarray
        A ``float`` for integer input, a float array for list, tuple or
        non-floating array input; floats and floating point arrays are
        returned unchanged

    Raises
    ------
    TypeError
        For invalid input type

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.types import check_float
    >>> a = np.arange(5)
    >>> a
    array([0, 1, 2, 3, 4])
    >>> check_float(a)
    array([0., 1., 2., 3., 4.])

    See Also
    --------
    check_int : related function

    """
    # Scalars first: integers are promoted, floats pass straight through.
    if isinstance(input_obj, int):
        return float(input_obj)
    if isinstance(input_obj, float):
        return input_obj

    # Python sequences are always materialised as float arrays.
    if isinstance(input_obj, (list, tuple)):
        return np.array(input_obj, dtype=float)

    # Arrays are only copied when their dtype is not already floating.
    if isinstance(input_obj, np.ndarray):
        if np.issubdtype(input_obj.dtype, np.floating):
            return input_obj
        return input_obj.astype(float)

    raise TypeError("Invalid input type.")
def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
    """Check Numpy ND-Array.

    Check that the input object is a numpy array, optionally check its data
    type, and set its ``writeable`` flag.

    Parameters
    ----------
    input_obj : numpy.ndarray
        Input object
    dtype : type, optional
        Expected numpy data type (or abstract type such as
        ``numpy.floating``); no dtype check is made when ``None``
        (default is ``None``)
    writeable : bool, optional
        Value assigned to the ``writeable`` flag of the array; ``False``
        makes the array immutable (default is ``True``)
    verbose : bool, optional
        Option to warn when a writeable array is made immutable
        (default is ``True``)

    Raises
    ------
    TypeError
        For invalid input type
    TypeError
        For invalid numpy.ndarray dtype

    Notes
    -----
    The array is modified in place (only its flags); nothing is returned.

    """
    if not isinstance(input_obj, np.ndarray):
        raise TypeError("Input is not a numpy array.")

    if dtype is not None and not np.issubdtype(input_obj.dtype, dtype):
        # Raise the exception itself: raising a tuple wrapping it would
        # replace this message with "exceptions must derive from
        # BaseException".
        raise TypeError(
            f"The numpy array elements are not of type: {dtype}",
        )

    # Only warn when the call actually changes a writeable array.
    if not writeable and verbose and input_obj.flags.writeable:
        warn("Making input data immutable.")

    input_obj.flags.writeable = writeable
import numpy as np
import matplotlib.pyplot as plt

from modopt.opt.algorithms import ForwardBackward, POGM
from modopt.opt.cost import costObj
from modopt.opt.linear import LinearParent, Identity
from modopt.opt.gradient import GradBasic
from modopt.opt.proximity import SparseThreshold
from modopt.math.matrix import PowerMethod
from modopt.math.stats import mse

# %%
# Here we create an instance of the LASSO Problem

BETA_TRUE = np.array(
    [3.0, 1.5, 0, 0, 2, 0, 0, 0]
)  # 8 original values from the LASSO paper
DIM = len(BETA_TRUE)


rng = np.random.default_rng()
sigma_noise = 1
obs = 20
# create a measurement matrix with decaying covariance matrix.
cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM))
x = rng.multivariate_normal(np.zeros(DIM), cov, obs)

y = x @ BETA_TRUE
# Draw the noise from the same generator as the measurement matrix, so that
# seeding ``rng`` is enough to make the whole example reproducible.
y_noise = y + (sigma_noise * rng.standard_normal(obs))


# %%
# Next we create Operators for solving the problem.

# MatrixOperator could also work here.
lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb)
grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op)

prox_op = SparseThreshold(Identity(), 1, thresh_type="soft")

# %%
# In order to get the best convergence rate, we first determine the
# Lipschitz constant of the gradient Operator
#

calc_lips = PowerMethod(
    grad_op.trans_op_op,
    DIM,
    data_type="float32",
    auto_run=True,
)
lip = calc_lips.spec_rad
print("lipschitz constant:", lip)

# %%
# Solving using FISTA algorithm
# -----------------------------
#
# TODO: Add description/Reference of FISTA.

cost_op_fista = costObj([grad_op, prox_op], verbose=False)

fb_fista = ForwardBackward(
    np.zeros(DIM),
    beta_param=1 / lip,
    grad=grad_op,
    prox=prox_op,
    cost=cost_op_fista,
    metric_call_period=1,
    auto_iterate=False,  # Just to give us the pleasure of doing things by ourselves.
)

fb_fista.iterate()

# %%
# After the run we can have a look at the results

print(fb_fista.x_final)
mse_fista = mse(fb_fista.x_final, BETA_TRUE)
# Use a dedicated figure so the FISTA and POGM stem plots do not overlap when
# the example is run as a plain script.
plt.figure()
plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-")
plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
plt.legend()
plt.title(f"FISTA Estimation MSE={mse_fista:.4f}")

# sphinx_gallery_start_ignore
assert mse(fb_fista.x_final, BETA_TRUE) < 1
# sphinx_gallery_end_ignore


# %%
# Solving Using the POGM Algorithm
# --------------------------------
#
# TODO: Add description/Reference to POGM.


cost_op_pogm = costObj([grad_op, prox_op], verbose=False)

fb_pogm = POGM(
    np.zeros(DIM),
    np.zeros(DIM),
    np.zeros(DIM),
    np.zeros(DIM),
    beta_param=1 / lip,
    grad=grad_op,
    prox=prox_op,
    cost=cost_op_pogm,
    metric_call_period=1,
    auto_iterate=False,  # Just to give us the pleasure of doing things by ourselves.
)

fb_pogm.iterate()

# %%
# After the run we can have a look at the results

print(fb_pogm.x_final)
mse_pogm = mse(fb_pogm.x_final, BETA_TRUE)

plt.figure()
plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-")
plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
plt.legend()
# This figure shows the POGM estimate, so label it accordingly.
plt.title(f"POGM Estimation MSE={mse_pogm:.4f}")
#
# sphinx_gallery_start_ignore
assert mse(fb_pogm.x_final, BETA_TRUE) < 1
# sphinx_gallery_end_ignore

# %%
# Comparing the Two algorithms
# ----------------------------

plt.figure()
plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence")
plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence")
plt.xlabel("iterations")
plt.ylabel("Cost Function")
plt.legend()
plt.show()


# %%
# We can see that both algorithms converge quickly, and that POGM requires
# fewer iterations. However the POGM iterations are more costly, so a proper
# benchmark with time measurement is needed.
# Check the benchopt benchmark for more details.
def convolve(input_data, kernel, method="scipy"):
    """Convolve data with kernel.

    FFT based convolution of the input data with the given kernel. The
    astropy implementation is used on request when astropy could be
    imported, otherwise the scipy implementation is used.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, normally a 2D image
    kernel : numpy.ndarray
        Input kernel array, normally a 2D kernel
    method : {'scipy', 'astropy'}, optional
        Convolution method (default is ``'scipy'``)

    Returns
    -------
    numpy.ndarray
        Convolved data

    Raises
    ------
    ValueError
        If ``input_data`` and ``kernel`` do not have the same number of
        dimensions
    ValueError
        If ``method`` is not ``'astropy'`` or ``'scipy'``

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.convolve import convolve
    >>> a = np.arange(9).reshape(3, 3)
    >>> b = a + 10
    >>> convolve(a, b)
    array([[ 86., 170., 146.],
           [246., 444., 354.],
           [290., 494., 374.]])

    >>> convolve(a, b, method='astropy')
    array([[534., 525., 534.],
           [453., 444., 453.],
           [534., 525., 534.]])

    See Also
    --------
    scipy.signal.fftconvolve : scipy FFT convolution
    astropy.convolution.convolve_fft : astropy FFT convolution

    """
    if input_data.ndim != kernel.ndim:
        raise ValueError("Data and kernel must have the same dimensions.")

    if method not in {"astropy", "scipy"}:
        raise ValueError('Invalid method. Options are "astropy" or "scipy".')

    # A request for astropy silently falls back to scipy when the optional
    # astropy import failed at module load time.
    use_astropy = import_astropy and method == "astropy"

    if not use_astropy:
        return scipy.signal.fftconvolve(input_data, kernel, mode="same")

    astropy_options = {
        "boundary": "wrap",
        "crop": False,
        "nan_treatment": "fill",
        "normalize_kernel": False,
    }

    return convolve_fft(input_data, kernel, **astropy_options)
def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
    """Add noise to data.

    This method adds Gaussian or Poisson noise to the input data.

    Parameters
    ----------
    input_data : numpy.ndarray, list or tuple
        Input data array
    sigma : float, list, tuple or numpy.ndarray, optional
        Scale applied to the random draw. Either a single value (Python or
        NumPy scalar) or one value per element along the first dimension of
        the input data (default is ``1.0``)
    noise_type : {'gauss', 'poisson'}
        Type of noise to be added (default is ``'gauss'``)
    rng : numpy.random.Generator or int, optional
        A random number generator or a seed to initialise one
        (default is ``None``)

    Returns
    -------
    numpy.ndarray
        Input data with added noise

    Raises
    ------
    ValueError
        If ``noise_type`` is not ``'gauss'`` or ``'poisson'``
    ValueError
        If number of ``sigma`` values does not match the first dimension of the
        input data

    Notes
    -----
    For ``'gauss'`` the draw is a standard normal sample with the shape of the
    input; for ``'poisson'`` it is a Poisson sample whose rate is the absolute
    value of the input. In both cases the draw is multiplied by ``sigma``
    before being added to the input.

    The random draw comes from ``rng`` only, so seeding the global
    ``numpy.random`` state has no effect; pass a seed or a generator to get
    reproducible results.

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.noise import add_noise
    >>> x = np.arange(9).reshape(3, 3).astype(float)
    >>> noisy = add_noise(x, sigma=2.0, rng=1)
    >>> noisy.shape
    (3, 3)
    >>> np.array_equal(noisy, add_noise(x, sigma=2.0, rng=1))
    True

    """
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(rng)

    input_data = np.array(input_data)

    if noise_type not in {"gauss", "poisson"}:
        raise ValueError(
            'Invalid noise type. Options are "gauss" or "poisson"',
        )

    # NumPy scalars (e.g. numpy.float32) and 0-d arrays are single values too;
    # treating them as sequences would fail in ``len``/``zip`` below.
    scalar_sigma = isinstance(sigma, (int, float, np.number)) or (
        isinstance(sigma, np.ndarray) and sigma.ndim == 0
    )

    if not scalar_sigma and isinstance(sigma, (list, tuple, np.ndarray)):
        if len(sigma) != input_data.shape[0]:
            raise ValueError(
                "Number of sigma values must match first dimension of input data",
            )

    if noise_type == "gauss":
        random = rng.standard_normal(input_data.shape)

    elif noise_type == "poisson":
        random = rng.poisson(np.abs(input_data))

    if scalar_sigma:
        return input_data + sigma * random

    # One sigma per slice along the first dimension of the draw.
    noise = np.array([sig * rand for sig, rand in zip(sigma, random)])

    return input_data + noise
102 | 103 | This method performs hard or soft thresholding on the input data. 104 | 105 | Parameters 106 | ---------- 107 | input_data : numpy.ndarray, list or tuple 108 | Input data array 109 | threshold : float or numpy.ndarray 110 | Threshold level(s) 111 | threshold_type : {'hard', 'soft'} 112 | Type of thresholding to apply (default is ``'hard'``) 113 | 114 | Returns 115 | ------- 116 | numpy.ndarray 117 | Thresholded data 118 | 119 | Raises 120 | ------ 121 | ValueError 122 | If ``threshold_type`` is not ``'hard'`` or ``'soft'`` 123 | 124 | Notes 125 | ----- 126 | Implements one of the following two equations: 127 | 128 | * Hard Threshold 129 | .. math:: 130 | \mathrm{HT}_\lambda(x) = 131 | \begin{cases} 132 | x & \text{if } |x|\geq\lambda \\ 133 | 0 & \text{otherwise} 134 | \end{cases} 135 | 136 | * Soft Threshold 137 | .. math:: 138 | \mathrm{ST}_\lambda(x) = 139 | \begin{cases} 140 | x-\lambda\text{sign}(x) & \text{if } |x|\geq\lambda \\ 141 | 0 & \text{otherwise} 142 | \end{cases} 143 | 144 | Examples 145 | -------- 146 | >>> import numpy as np 147 | >>> from modopt.signal.noise import thresh 148 | >>> np.random.seed(1) 149 | >>> x = np.random.randint(-9, 9, 10) 150 | >>> x 151 | array([-4, 2, 3, -1, 0, 2, -4, 6, -9, 7]) 152 | >>> thresh(x, 4) 153 | array([-4, 0, 0, 0, 0, 0, -4, 6, -9, 7]) 154 | 155 | >>> import numpy as np 156 | >>> from modopt.signal.noise import thresh 157 | >>> np.random.seed(1) 158 | >>> x = np.random.ranf((3, 3)) 159 | >>> x 160 | array([[4.17022005e-01, 7.20324493e-01, 1.14374817e-04], 161 | [3.02332573e-01, 1.46755891e-01, 9.23385948e-02], 162 | [1.86260211e-01, 3.45560727e-01, 3.96767474e-01]]) 163 | >>> thresh(x, 0.2, threshold_type='soft') 164 | array([[0.217022 , 0.52032449, 0. ], 165 | [0.10233257, 0. , 0. ], 166 | [0. 
, 0.14556073, 0.19676747]]) 167 | 168 | """ 169 | xp = get_array_module(input_data) 170 | 171 | input_data = xp.array(input_data) 172 | 173 | if threshold_type not in {"hard", "soft"}: 174 | raise ValueError( 175 | 'Invalid threshold type. Options are "hard" or "soft"', 176 | ) 177 | 178 | if threshold_type == "soft": 179 | denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data)) 180 | max_value = xp.maximum((1.0 - threshold / denominator), 0) 181 | 182 | return xp.around(max_value * input_data, decimals=15) 183 | 184 | return input_data * (xp.abs(input_data) >= threshold) 185 | -------------------------------------------------------------------------------- /src/modopt/base/backend.py: -------------------------------------------------------------------------------- 1 | """BACKEND MODULE. 2 | 3 | This module contains methods for GPU Compatiblity. 4 | 5 | :Author: Chaithya G R 6 | 7 | """ 8 | 9 | from importlib import util 10 | 11 | import numpy as np 12 | 13 | from modopt.interface.errors import warn 14 | 15 | try: 16 | import torch 17 | from torch.utils.dlpack import from_dlpack as torch_from_dlpack 18 | from torch.utils.dlpack import to_dlpack as torch_to_dlpack 19 | 20 | except ImportError: # pragma: no cover 21 | import_torch = False 22 | else: 23 | import_torch = True 24 | 25 | # Handle the compatibility with variable 26 | LIBRARIES = { 27 | "cupy": None, 28 | "tensorflow": None, 29 | "numpy": np, 30 | } 31 | 32 | if util.find_spec("cupy") is not None: 33 | try: 34 | import cupy as cp 35 | 36 | LIBRARIES["cupy"] = cp 37 | except ImportError: 38 | pass 39 | 40 | if util.find_spec("tensorflow") is not None: 41 | try: 42 | from tensorflow.experimental import numpy as tnp 43 | 44 | LIBRARIES["tensorflow"] = tnp 45 | except ImportError: 46 | pass 47 | 48 | 49 | def get_backend(backend): 50 | """Get backend. 51 | 52 | Returns the backend module for input specified by string. 
def get_array_module(input_data):
    """Get Array Module.

    Return the array module matching the type of the input data, which tells
    which backend the data belongs to.

    Parameters
    ----------
    input_data : numpy.ndarray, cupy.ndarray or tf.experimental.numpy.ndarray
        Input data array

    Returns
    -------
    module
        The TensorFlow NumPy or CuPy module when that backend was imported
        and the input is one of its arrays, otherwise the NumPy module

    """
    # TensorFlow is checked before CuPy; backends that could not be imported
    # are stored as ``None`` and skipped.
    for backend_name in ("tensorflow", "cupy"):
        module = LIBRARIES[backend_name]
        if module is not None and isinstance(input_data, module.ndarray):
            return module

    return np
def move_to_cpu(input_data):
    """Move data to CPU.

    Convert the input array to a NumPy array. NumPy input is returned as is.

    Parameters
    ----------
    input_data : numpy.ndarray, cupy.ndarray or tf.experimental.numpy.ndarray
        Input data array to be moved

    Returns
    -------
    numpy.ndarray
        The NumPy array residing on CPU

    Raises
    ------
    ValueError
        If the array module of the input matches none of the known backends

    """
    module = get_array_module(input_data)

    # Backend name paired with the conversion used for arrays of that
    # backend; checked in this order.
    converters = (
        ("numpy", lambda array: array),
        ("cupy", lambda array: array.get()),
        ("tensorflow", lambda array: array.data.numpy()),
    )

    for backend_name, to_numpy in converters:
        if module == LIBRARIES[backend_name]:
            return to_numpy(input_data)

    raise ValueError("Cannot identify the array type.")
def convert_to_cupy_array(input_data):
    """Convert Tensor data to a CuPy Array.

    This method converts a CUDA tensor to a CuPy array through DLPack. A
    tensor that is not on a CUDA device is converted to a NumPy array
    instead.

    Parameters
    ----------
    input_data : torch.Tensor
        Input Tensor to be converted

    Returns
    -------
    cupy.ndarray or numpy.ndarray
        The tensor data as a CuPy array for CUDA tensors, otherwise as a
        NumPy array

    Raises
    ------
    ImportError
        If Torch package not found
    ImportError
        If the tensor is on a CUDA device and the CuPy package is not found

    """
    if not import_torch:
        raise ImportError(
            "Required version of Torch package not found, "
            + "see documentation for details: https://cea-cosmic."
            + "github.io/ModOpt/#optional-packages",
        )

    if input_data.is_cuda:
        # Look CuPy up in the backend registry: the ``cp`` alias only exists
        # when the optional import succeeded, so using it directly would
        # fail with a NameError on machines without CuPy.
        cupy_module = LIBRARIES["cupy"]
        if cupy_module is None:
            raise ImportError(
                "Required version of CuPy package not found, "
                + "see documentation for details: https://cea-cosmic."
                + "github.io/ModOpt/#optional-packages",
            )
        return cupy_module.fromDlpack(torch_to_dlpack(input_data))

    return input_data.detach().numpy()
25 | 26 | Parameters 27 | ---------- 28 | img : numpy.ndarray 29 | Input image 30 | 31 | Returns 32 | ------- 33 | numpy.ndarray 34 | normalized array 35 | 36 | """ 37 | min_img = img.min() 38 | max_img = img.max() 39 | 40 | return (img - min_img) / (max_img - min_img) 41 | 42 | 43 | def _preprocess_input(test, ref, mask=None): 44 | """Proprocess Input. 45 | 46 | Wrapper to the metric. 47 | 48 | Parameters 49 | ---------- 50 | ref : numpy.ndarray 51 | The reference image 52 | test : numpy.ndarray 53 | The tested image 54 | mask : numpy.ndarray, optional 55 | The mask for the ROI (default is ``None``) 56 | 57 | Raises 58 | ------ 59 | ValueError 60 | For invalid mask value 61 | 62 | Notes 63 | ----- 64 | Compute the metric only on magnetude. 65 | 66 | Returns 67 | ------- 68 | float 69 | The SNR 70 | 71 | """ 72 | test = np.abs(np.copy(test)).astype("float64") 73 | ref = np.abs(np.copy(ref)).astype("float64") 74 | test = min_max_normalize(test) 75 | ref = min_max_normalize(ref) 76 | 77 | if (not isinstance(mask, np.ndarray)) and (mask is not None): 78 | message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.' 79 | raise ValueError(message.format(mask)) 80 | 81 | if mask is None: 82 | return test, ref, None 83 | 84 | return test, ref, mask 85 | 86 | 87 | def ssim(test, ref, mask=None): 88 | """Structural Similarity (SSIM). 89 | 90 | Calculate the SSIM between a test image and a reference image. 91 | 92 | Parameters 93 | ---------- 94 | ref : numpy.ndarray 95 | The reference image 96 | test : numpy.ndarray 97 | The tested image 98 | mask : numpy.ndarray, optional 99 | The mask for the ROI (default is ``None``) 100 | 101 | Raises 102 | ------ 103 | ImportError 104 | If Scikit-Image package not found 105 | 106 | Notes 107 | ----- 108 | Compute the metric only on magnetude. 
def snr(test, ref, mask=None):
    """Signal-to-Noise Ratio (SNR).

    Calculate the SNR between a test image and a reference image.

    Parameters
    ----------
    test : numpy.ndarray
        The tested image
    ref : numpy.ndarray
        The reference image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Returns
    -------
    float
        The SNR, in decibels

    Notes
    -----
    Compute the metric only on magnitude.

    The signal power is the mean square of the preprocessed test image and
    the noise power is the MSE between the test and reference images.

    """
    test, ref, mask = _preprocess_input(test, ref, mask)

    if mask is not None:
        # Restrict both images to the ROI, as the other metrics do; masking
        # only the test image would count the reference values outside the
        # ROI as error.
        test = mask * test
        ref = mask * ref

    num = np.mean(np.square(test))
    deno = mse(test, ref)

    return 10.0 * np.log10(num / deno)
def mse(test, ref, mask=None):
    r"""Mean Squared Error (MSE).

    Calculate the MSE between a test image and a reference image.

    Parameters
    ----------
    test : numpy.ndarray
        The tested image
    ref : numpy.ndarray
        The reference image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Returns
    -------
    float
        The MSE

    Notes
    -----
    Compute the metric only on magnitude.

    .. math::
        \frac{1}{N} \|ref - test\|_2^2

    """
    test, ref, mask = _preprocess_input(test, ref, mask)

    # With a mask, both images are restricted to the ROI before differencing.
    if mask is None:
        residual = test - ref
    else:
        residual = mask * test - mask * ref

    return np.mean(np.square(residual))
255 | 256 | Returns 257 | ------- 258 | float 259 | The NRMSE 260 | 261 | """ 262 | test, ref, mask = _preprocess_input(test, ref, mask) 263 | 264 | if mask is not None: 265 | test = mask * test 266 | ref = mask * ref 267 | 268 | num = np.sqrt(mse(test, ref)) 269 | deno = np.sqrt(np.mean(np.square(test))) 270 | 271 | return num / deno 272 | -------------------------------------------------------------------------------- /src/modopt/base/np_adjust.py: -------------------------------------------------------------------------------- 1 | """NUMPY ADJUSTMENT ROUTINES. 2 | 3 | This module contains methods for adjusting the default output for certain 4 | Numpy functions. 5 | 6 | :Author: Samuel Farrens 7 | 8 | """ 9 | 10 | import numpy as np 11 | 12 | 13 | def rotate(input_data): 14 | """Rotate. 15 | 16 | This method rotates an input numpy array by 180 degrees. 17 | 18 | Parameters 19 | ---------- 20 | input_data : numpy.ndarray 21 | Input data array (at least 2D) 22 | 23 | Returns 24 | ------- 25 | numpy.ndarray 26 | Rotated data 27 | 28 | Notes 29 | ----- 30 | Adjustment to numpy.rot90 31 | 32 | Examples 33 | -------- 34 | >>> import numpy as np 35 | >>> from modopt.base.np_adjust import rotate 36 | >>> x = np.arange(9).reshape((3, 3)) 37 | >>> x 38 | array([[0, 1, 2], 39 | [3, 4, 5], 40 | [6, 7, 8]]) 41 | >>> rotate(x) 42 | array([[8, 7, 6], 43 | [5, 4, 3], 44 | [2, 1, 0]]) 45 | 46 | 47 | See Also 48 | -------- 49 | numpy.rot90 : base function 50 | 51 | """ 52 | return np.rot90(input_data, 2) 53 | 54 | 55 | def rotate_stack(input_data): 56 | """Rotate stack. 57 | 58 | This method rotates each array in a stack of arrays by 180 degrees. 
59 | 60 | Parameters 61 | ---------- 62 | input_data : numpy.ndarray 63 | Input data array (at least 3D) 64 | 65 | Returns 66 | ------- 67 | numpy.ndarray 68 | Rotated data 69 | 70 | Examples 71 | -------- 72 | >>> import numpy as np 73 | >>> from modopt.base.np_adjust import rotate_stack 74 | >>> x = np.arange(18).reshape((2, 3, 3)) 75 | >>> x 76 | array([[[ 0, 1, 2], 77 | [ 3, 4, 5], 78 | [ 6, 7, 8]], 79 | 80 | [[ 9, 10, 11], 81 | [12, 13, 14], 82 | [15, 16, 17]]]) 83 | >>> rotate_stack(x) 84 | array([[[ 8, 7, 6], 85 | [ 5, 4, 3], 86 | [ 2, 1, 0]], 87 | 88 | [[17, 16, 15], 89 | [14, 13, 12], 90 | [11, 10, 9]]]) 91 | 92 | See Also 93 | -------- 94 | rotate : looped function 95 | 96 | """ 97 | return np.array([rotate(array) for array in input_data]) 98 | 99 | 100 | def pad2d(input_data, padding): 101 | """Pad array. 102 | 103 | This method pads an input numpy array with zeros in all directions. 104 | 105 | Parameters 106 | ---------- 107 | input_data : numpy.ndarray 108 | Input data array (at least 2D) 109 | padding : int, tuple, list or numpy.ndarray 110 | Amount of padding in x and y directions, respectively 111 | 112 | Returns 113 | ------- 114 | numpy.ndarray 115 | Padded data 116 | 117 | Raises 118 | ------ 119 | ValueError 120 | For invalid padding type 121 | 122 | Notes 123 | ----- 124 | Adjustment to numpy.pad() 125 | 126 | Examples 127 | -------- 128 | >>> import numpy as np 129 | >>> from modopt.base.np_adjust import pad2d 130 | >>> x = np.arange(9).reshape((3, 3)) 131 | >>> x 132 | array([[0, 1, 2], 133 | [3, 4, 5], 134 | [6, 7, 8]]) 135 | >>> pad2d(x, (1, 1)) 136 | array([[0, 0, 0, 0, 0], 137 | [0, 0, 1, 2, 0], 138 | [0, 3, 4, 5, 0], 139 | [0, 6, 7, 8, 0], 140 | [0, 0, 0, 0, 0]]) 141 | 142 | See Also 143 | -------- 144 | numpy.pad : base function 145 | 146 | """ 147 | input_data = np.array(input_data) 148 | 149 | if isinstance(padding, int): 150 | padding = np.array([padding]) 151 | elif isinstance(padding, (tuple, list)): 152 | padding = np.array(padding) 153 | elif not isinstance(padding,
np.ndarray): 154 | raise ValueError( 155 | "Padding must be an integer or a tuple (or list, np.ndarray) of integers", 156 | ) 157 | 158 | if padding.size == 1: 159 | padding = np.repeat(padding, 2) 160 | 161 | pad_x = (padding[0], padding[0]) 162 | pad_y = (padding[1], padding[1]) 163 | 164 | return np.pad(input_data, (pad_x, pad_y), "constant") 165 | 166 | 167 | def ftr(input_data): 168 | """Fancy transpose right. 169 | 170 | Apply ``fancy_transpose`` to data with ``roll=1``. 171 | 172 | Parameters 173 | ---------- 174 | input_data : numpy.ndarray 175 | Input data array 176 | 177 | Returns 178 | ------- 179 | numpy.ndarray 180 | Transposed data 181 | 182 | See Also 183 | -------- 184 | fancy_transpose : base function 185 | 186 | """ 187 | return fancy_transpose(input_data) 188 | 189 | 190 | def ftl(input_data): 191 | """Fancy transpose left. 192 | 193 | Apply ``fancy_transpose`` to data with ``roll=-1``. 194 | 195 | Parameters 196 | ---------- 197 | input_data : numpy.ndarray 198 | Input data array 199 | 200 | Returns 201 | ------- 202 | numpy.ndarray 203 | Transposed data 204 | 205 | See Also 206 | -------- 207 | fancy_transpose : base function 208 | 209 | """ 210 | return fancy_transpose(input_data, -1) 211 | 212 | 213 | def fancy_transpose(input_data, roll=1): 214 | """Fancy transpose. 215 | 216 | This method transposes a multidimensional matrix. 
217 | 218 | Parameters 219 | ---------- 220 | input_data : numpy.ndarray 221 | Input data array 222 | roll : int 223 | Roll direction and amount (default is ``1``) 224 | 225 | Returns 226 | ------- 227 | numpy.ndarray 228 | Transposed data 229 | 230 | Notes 231 | ----- 232 | Adjustment to numpy.transpose 233 | 234 | Examples 235 | -------- 236 | >>> import numpy as np 237 | >>> from modopt.base.np_adjust import fancy_transpose 238 | >>> x = np.arange(27).reshape(3, 3, 3) 239 | >>> x 240 | array([[[ 0, 1, 2], 241 | [ 3, 4, 5], 242 | [ 6, 7, 8]], 243 | 244 | [[ 9, 10, 11], 245 | [12, 13, 14], 246 | [15, 16, 17]], 247 | 248 | [[18, 19, 20], 249 | [21, 22, 23], 250 | [24, 25, 26]]]) 251 | >>> fancy_transpose(x) 252 | array([[[ 0, 3, 6], 253 | [ 9, 12, 15], 254 | [18, 21, 24]], 255 | 256 | [[ 1, 4, 7], 257 | [10, 13, 16], 258 | [19, 22, 25]], 259 | 260 | [[ 2, 5, 8], 261 | [11, 14, 17], 262 | [20, 23, 26]]]) 263 | >>> fancy_transpose(x, roll=-1) 264 | array([[[ 0, 9, 18], 265 | [ 1, 10, 19], 266 | [ 2, 11, 20]], 267 | 268 | [[ 3, 12, 21], 269 | [ 4, 13, 22], 270 | [ 5, 14, 23]], 271 | 272 | [[ 6, 15, 24], 273 | [ 7, 16, 25], 274 | [ 8, 17, 26]]]) 275 | 276 | See Also 277 | -------- 278 | numpy.transpose : base function 279 | 280 | """ 281 | axis_roll = np.roll(np.arange(input_data.ndim), roll) 282 | 283 | return np.transpose(input_data, axes=axis_roll) 284 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test for base module. 
3 | 4 | :Authors: 5 | Samuel Farrens 6 | Pierre-Antoine Comby 7 | """ 8 | 9 | import numpy as np 10 | import numpy.testing as npt 11 | import pytest 12 | from test_helpers import failparam, skipparam 13 | 14 | from modopt.base import backend, np_adjust, transform, types 15 | from modopt.base.backend import LIBRARIES 16 | 17 | 18 | class TestNpAdjust: 19 | """Test for npadjust.""" 20 | 21 | array33 = np.arange(9).reshape((3, 3)) 22 | array233 = np.arange(18).reshape((2, 3, 3)) 23 | arraypad = np.array( 24 | [ 25 | [0, 0, 0, 0, 0], 26 | [0, 0, 1, 2, 0], 27 | [0, 3, 4, 5, 0], 28 | [0, 6, 7, 8, 0], 29 | [0, 0, 0, 0, 0], 30 | ] 31 | ) 32 | 33 | def test_rotate(self): 34 | """Test rotate.""" 35 | npt.assert_array_equal( 36 | np_adjust.rotate(self.array33), 37 | np.rot90(np.rot90(self.array33)), 38 | err_msg="Incorrect rotation.", 39 | ) 40 | 41 | def test_rotate_stack(self): 42 | """Test rotate_stack.""" 43 | npt.assert_array_equal( 44 | np_adjust.rotate_stack(self.array233), 45 | np.rot90(self.array233, k=2, axes=(1, 2)), 46 | err_msg="Incorrect stack rotation.", 47 | ) 48 | 49 | @pytest.mark.parametrize( 50 | "padding", 51 | [ 52 | 1, 53 | [1, 1], 54 | np.array([1, 1]), 55 | failparam("1", raises=ValueError), 56 | ], 57 | ) 58 | def test_pad2d(self, padding): 59 | """Test pad2d.""" 60 | npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad) 61 | 62 | def test_fancy_transpose(self): 63 | """Test fancy transpose.""" 64 | npt.assert_array_equal( 65 | np_adjust.fancy_transpose(self.array233), 66 | np.array( 67 | [ 68 | [[0, 3, 6], [9, 12, 15]], 69 | [[1, 4, 7], [10, 13, 16]], 70 | [[2, 5, 8], [11, 14, 17]], 71 | ] 72 | ), 73 | err_msg="Incorrect fancy transpose", 74 | ) 75 | 76 | def test_ftr(self): 77 | """Test ftr.""" 78 | npt.assert_array_equal( 79 | np_adjust.ftr(self.array233), 80 | np.array( 81 | [ 82 | [[0, 3, 6], [9, 12, 15]], 83 | [[1, 4, 7], [10, 13, 16]], 84 | [[2, 5, 8], [11, 14, 17]], 85 | ] 86 | ), 87 | err_msg="Incorrect fancy transpose: 
ftr", 88 | ) 89 | 90 | def test_ftl(self): 91 | """Test fancy transpose left.""" 92 | npt.assert_array_equal( 93 | np_adjust.ftl(self.array233), 94 | np.array( 95 | [ 96 | [[0, 9], [1, 10], [2, 11]], 97 | [[3, 12], [4, 13], [5, 14]], 98 | [[6, 15], [7, 16], [8, 17]], 99 | ] 100 | ), 101 | err_msg="Incorrect fancy transpose: ftl", 102 | ) 103 | 104 | 105 | class TestTransforms: 106 | """Test for the transform module.""" 107 | 108 | cube = np.arange(16).reshape((4, 2, 2)) 109 | map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]) 110 | matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) 111 | layout = (2, 2) 112 | fail_layout = (3, 3) 113 | 114 | @pytest.mark.parametrize( 115 | ("func", "indata", "layout", "outdata"), 116 | [ 117 | (transform.cube2map, cube, layout, map), 118 | failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError), 119 | (transform.map2cube, map, layout, cube), 120 | (transform.map2matrix, map, layout, matrix), 121 | (transform.matrix2map, matrix, matrix.shape, map), 122 | ], 123 | ) 124 | def test_map(self, func, indata, layout, outdata): 125 | """Test cube2map.""" 126 | npt.assert_array_equal( 127 | func(indata, layout), 128 | outdata, 129 | ) 130 | if func.__name__ != "map2matrix": 131 | npt.assert_raises(ValueError, func, indata, self.fail_layout) 132 | 133 | def test_cube2matrix(self): 134 | """Test cube2matrix.""" 135 | npt.assert_array_equal( 136 | transform.cube2matrix(self.cube), 137 | self.matrix, 138 | ) 139 | 140 | def test_matrix2cube(self): 141 | """Test matrix2cube.""" 142 | npt.assert_array_equal( 143 | transform.matrix2cube(self.matrix, self.cube[0].shape), 144 | self.cube, 145 | err_msg="Incorrect transformation: matrix2cube", 146 | ) 147 | 148 | 149 | class TestType: 150 | """Test for type module.""" 151 | 152 | data_list = list(range(5)) # noqa: RUF012 153 | data_int = np.arange(5) 154 | data_flt = np.arange(5).astype(float) 155 | 156 | 
@pytest.mark.parametrize( 157 | ("data", "checked"), 158 | [ 159 | (1.0, 1.0), 160 | (1, 1.0), 161 | (data_list, data_flt), 162 | (data_int, data_flt), 163 | failparam("1.0", 1.0, raises=TypeError), 164 | ], 165 | ) 166 | def test_check_float(self, data, checked): 167 | """Test check float.""" 168 | npt.assert_array_equal(types.check_float(data), checked) 169 | 170 | @pytest.mark.parametrize( 171 | ("data", "checked"), 172 | [ 173 | (1.0, 1), 174 | (1, 1), 175 | (data_list, data_int), 176 | (data_flt, data_int), 177 | failparam("1", None, raises=TypeError), 178 | ], 179 | ) 180 | def test_check_int(self, data, checked): 181 | """Test check int.""" 182 | npt.assert_array_equal(types.check_int(data), checked) 183 | 184 | @pytest.mark.parametrize( 185 | ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)] 186 | ) 187 | def test_check_npndarray(self, data, dtype): 188 | """Test check_npndarray.""" 189 | npt.assert_raises( 190 | TypeError, 191 | types.check_npndarray, 192 | data, 193 | dtype=dtype, 194 | ) 195 | 196 | def test_check_callable(self): 197 | """Test callable.""" 198 | npt.assert_raises(TypeError, types.check_callable, 1) 199 | 200 | 201 | @pytest.mark.parametrize( 202 | "backend_name", 203 | [ 204 | skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed") 205 | for name in LIBRARIES 206 | ], 207 | ) 208 | def test_tf_backend(backend_name): 209 | """Test Modopt computational backends.""" 210 | xp, checked_backend_name = backend.get_backend(backend_name) 211 | if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]: 212 | raise AssertionError(f"{backend_name} get_backend fails!") 213 | xp_input = backend.change_backend(np.array([10, 10]), backend_name) 214 | if ( 215 | backend.get_array_module(LIBRARIES[backend_name].ones(1)) 216 | != backend.LIBRARIES[backend_name] 217 | or backend.get_array_module(xp_input) != LIBRARIES[backend_name] 218 | ): 219 | raise AssertionError(f"{backend_name} backend fails!") 
220 | -------------------------------------------------------------------------------- /docs/source/plugin_example.rst: -------------------------------------------------------------------------------- 1 | Plugin Example 2 | ============== 3 | 4 | This is a quick example demonstrating the use of ModOpt in the context 5 | of a PySAP plugin. In this example the Forward Backward algorithm from 6 | ModOpt is used to find the best fitting line to a set of data points. 7 | 8 | -------------- 9 | 10 | First we import the required packages. 11 | 12 | .. code:: python 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from modopt import opt 17 | 18 | Then we define a couple of utility functions for the purposes of this 19 | example. 20 | 21 | .. code:: python 22 | 23 | def regression_plot(x1, y1, x2=None, y2=None): 24 | """Plot data points and a proposed fit. 25 | """ 26 | 27 | fig = plt.figure(figsize=(12, 8)) 28 | plt.plot(x1, y1, 'o', color="#19C3F5", label='Data') 29 | if not all([isinstance(var, type(None)) for var in (x2, y2)]): 30 | plt.plot(x2, y2, '--', color='#FF4F5B', label='Model') 31 | plt.title('Best Fit Line', fontsize=20) 32 | plt.xlabel('x', fontsize=18) 33 | plt.ylabel('y', fontsize=18) 34 | plt.legend() 35 | plt.show() 36 | 37 | def y_func(x, a): 38 | """Equation of a polynomial line. 39 | """ 40 | 41 | return sum([(a_i * x ** n) for a_i, n in zip(a, range(a.size))]) 42 | 43 | -------------- 44 | 45 | The Problem 46 | ----------- 47 | 48 | Imagine we have the following set of data points and we want to find the 49 | best fitting 3\ :math:`^{\textrm{rd}}` degree polynomial line. 50 | 51 | ..
code:: python 52 | 53 | # A range of positions on the x-axis 54 | x = np.linspace(0.0, 1.0, 21) 55 | 56 | # A set of positions on the y-axis 57 | y = np.array([0.486, 0.866, 0.944, 1.144, 1.103, 1.202, 1.166, 1.191, 1.124, 1.095, 1.122, 1.102, 1.099, 1.017, 58 | 1.111, 1.117, 1.152, 1.265, 1.380, 1.575, 1.857]) 59 | 60 | # Plot the points 61 | regression_plot(x, y) 62 | 63 | 64 | 65 | .. image:: output_5_0.png 66 | 67 | 68 | This corresponds to solving the inverse problem 69 | 70 | .. math:: y = Ha 71 | 72 | where :math:`y` are the points on the y-axis, :math:`H` is a matrix 73 | composed of the points in the x-axis 74 | 75 | .. code:: python 76 | 77 | H = np.array([np.ones(x.size), x, x ** 2, x ** 3]).T 78 | 79 | and :math:`a` is a model describing the best fit line that we aim to 80 | recover. 81 | 82 | Note: We could easily solve this problem analytically. 83 | 84 | -------------- 85 | 86 | ModOpt solution 87 | --------------- 88 | 89 | Let’s attempt to solve this problem using the Forward Backward algorithm 90 | implemented in ModOpt in a few simple steps. 91 | 92 | Step 1: Set an initial guess for :math:`a` 93 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 94 | 95 | We will start by assuming :math:`a` is a vector of zeros, which is 96 | obviously a very bad fit to the data. 97 | 98 | .. code:: python 99 | 100 | a_0 = np.zeros(4) 101 | 102 | regression_plot(x, y, x, y_func(x, a_0)) 103 | 104 | 105 | 106 | .. image:: output_9_0.png 107 | 108 | 109 | Step 2: Define a gradient operator for the problem 110 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 111 | 112 | For this problem we can use the basic gradient provided in ModOpt 113 | 114 | .. math:: \nabla F(x) = H^T(Hx - y) 115 | 116 | we simply need to define the operations of :math:`H` and :math:`H^T`. 117 | 118 | .. code:: python 119 | 120 | grad_op = opt.gradient.GradBasic(y, lambda x: np.dot(H, x), lambda x: np.dot(H.T, x)) 121 | 122 | 123 | ..
parsed-literal:: 124 | 125 | WARNING: Making input data immutable. 126 | 127 | 128 | Step 3: Define a proximity operator for the algorithm 129 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 130 | 131 | Since we don’t need to implement any regularisation we simply set the 132 | identity operator. 133 | 134 | .. code:: python 135 | 136 | prox_op = opt.proximity.IdentityProx() 137 | 138 | Step 4: Pass everything to the algorithm 139 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 140 | 141 | Here we pass the initial guess for :math:`a` along with the gradient and 142 | proximity operators we defined. We also specify a value for the 143 | :math:`\beta` parameter in the Forward Backward algorithm. Finally, we 144 | specify that we want a maximum of 500 iterations. 145 | 146 | .. code:: python 147 | 148 | alg = opt.algorithms.ForwardBackward(a_0, grad_op, prox_op, beta_param=0.01, auto_iterate=False) 149 | alg.iterate(max_iter=500) 150 | 151 | 152 | .. parsed-literal:: 153 | 154 | 100% (500 of 500) \|######################\| Elapsed Time: 0:00:00 Time: 0:00:00 155 | 156 | 157 | .. parsed-literal:: 158 | 159 | - ITERATION: 1 160 | - DATA FIDELITY (X): 6.931639464572723 161 | - COST: 6.931639464572723 162 | 163 | - ITERATION: 2 164 | - DATA FIDELITY (X): 3.4179113043303198 165 | - COST: 3.4179113043303198 166 | 167 | - ITERATION: 3 168 | - DATA FIDELITY (X): 1.7894732608136656 169 | - COST: 1.7894732608136656 170 | 171 | - ITERATION: 4 172 | - DATA FIDELITY (X): 0.8712577337041495 173 | - COST: 0.8712577337041495 174 | 175 | - CONVERGENCE TEST - 176 | - CHANGE IN COST: 2.8897396205130526 177 | 178 | ... 
179 | 180 | - ITERATION: 499 181 | - DATA FIDELITY (X): 0.05534231539479718 182 | - COST: 0.05534231539479718 183 | 184 | - ITERATION: 500 185 | - DATA FIDELITY (X): 0.05498126500595005 186 | - COST: 0.05498126500595005 187 | 188 | - CONVERGENCE TEST - 189 | - CHANGE IN COST: 0.013163437152771431 190 | 191 | 192 | 193 | Step 5: Extract the final result 194 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 195 | 196 | Once the algorithm has finished running we take the final result. We can 197 | see that it’s a pretty good fit! 198 | 199 | .. code:: python 200 | 201 | a = alg.x_final 202 | 203 | regression_plot(x, y, x, y_func(x, a)) 204 | 205 | 206 | 207 | .. image:: output_17_0.png 208 | 209 | 210 | -------------- 211 | 212 | PySAP Plugin 213 | ------------ 214 | 215 | Now imagine we want to implement the above solution as a PySAP plugin. 216 | To do so we would first need to create a new repository using the 217 | PySAP plugin template (https://github.com/CEA-COSMIC/pysap-extplugin). 218 | 219 | Afterwards, we could package our solution as a more user friendly 220 | function. 221 | 222 | .. code:: python 223 | 224 | def poly_fit(x, y, deg=3): 225 | 226 | H = np.array([np.ones(x.size), x, x ** 2, x ** 3]).T 227 | a_0 = np.zeros(4) 228 | grad_op = opt.gradient.GradBasic(y, lambda x: np.dot(H, x), lambda x: np.dot(H.T, x)) 229 | prox_op = opt.proximity.IdentityProx() 230 | cost_op = opt.cost.costObj(operators=[grad_op, prox_op], verbose=False) 231 | alg = opt.algorithms.ForwardBackward(a_0, grad_op, prox_op, cost_op, beta_param=0.01, auto_iterate=False, 232 | progress=False) 233 | alg.iterate(max_iter=500) 234 | 235 | return alg.x_final 236 | 237 | Once the plugin (let’s call it foo) was integrated it would be possible 238 | to call this function as follows: 239 | 240 | .. code:: python 241 | 242 | from pysap.plugins.foo import poly_fit 243 | 244 | Then we could use this function to fit our data directly. 245 | 246 | ..
code:: python 247 | 248 | a_new = poly_fit(x, y) 249 | 250 | regression_plot(x, y, x, y_func(x, a_new)) 251 | 252 | 253 | 254 | .. image:: output_21_0.png 255 | -------------------------------------------------------------------------------- /src/modopt/math/stats.py: -------------------------------------------------------------------------------- 1 | """STATISTICS ROUTINES. 2 | 3 | This module contains methods for basic statistics. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | import numpy as np 10 | 11 | try: 12 | from packaging import version 13 | from astropy import __version__ as astropy_version 14 | from astropy.convolution import Gaussian2DKernel 15 | except ImportError: # pragma: no cover 16 | import_astropy = False 17 | else: 18 | import_astropy = True 19 | 20 | 21 | def gaussian_kernel(data_shape, sigma, norm="max"): 22 | """Gaussian kernel. 23 | 24 | This method produces a Gaussian kernel of a specified size and dispersion. 25 | 26 | Parameters 27 | ---------- 28 | data_shape : tuple 29 | Desired shape of the kernel 30 | sigma : float 31 | Standard deviation of the kernel 32 | norm : {'max', 'sum'}, optional, default='max' 33 | Normalisation of the kernel 34 | 35 | Returns 36 | ------- 37 | numpy.ndarray 38 | Kernel 39 | 40 | Raises 41 | ------ 42 | ImportError 43 | If Astropy package not found 44 | ValueError 45 | For invalid norm 46 | 47 | Examples 48 | -------- 49 | >>> from modopt.math.stats import gaussian_kernel 50 | >>> gaussian_kernel((3, 3), 1) 51 | array([[0.36787944, 0.60653066, 0.36787944], 52 | [0.60653066, 1.
, 0.60653066], 53 | [0.36787944, 0.60653066, 0.36787944]]) 54 | 55 | >>> gaussian_kernel((3, 3), 1, norm='sum') 56 | array([[0.07511361, 0.1238414 , 0.07511361], 57 | [0.1238414 , 0.20417996, 0.1238414 ], 58 | [0.07511361, 0.1238414 , 0.07511361]]) 59 | 60 | """ 61 | if not import_astropy: # pragma: no cover 62 | raise ImportError("Astropy package not found.") 63 | 64 | if norm not in {"max", "sum"}: 65 | raise ValueError('Invalid norm, options are "max", "sum" or "none".') 66 | 67 | kernel = np.array( 68 | Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]), 69 | ) 70 | 71 | if norm == "max": 72 | return kernel / np.max(kernel) 73 | 74 | elif version.parse(astropy_version) < version.parse("5.2"): 75 | return kernel / np.sum(kernel) 76 | 77 | else: 78 | return kernel 79 | 80 | 81 | def mad(input_data): 82 | r"""Median absolute deviation. 83 | 84 | This method calculates the median absolute deviation of the input data. 85 | 86 | Parameters 87 | ---------- 88 | input_data : numpy.ndarray 89 | Input data array 90 | 91 | Returns 92 | ------- 93 | float 94 | MAD value 95 | 96 | Examples 97 | -------- 98 | >>> import numpy as np 99 | >>> from modopt.math.stats import mad 100 | >>> a = np.arange(9).reshape(3, 3) 101 | >>> mad(a) 102 | 2.0 103 | 104 | Notes 105 | ----- 106 | The MAD is calculated as follows: 107 | 108 | .. math:: 109 | 110 | \mathrm{MAD} = \mathrm{median}\left(|X_i - \mathrm{median}(X)|\right) 111 | 112 | See Also 113 | -------- 114 | numpy.median : median function used 115 | 116 | """ 117 | return np.median(np.abs(input_data - np.median(input_data))) 118 | 119 | 120 | def mse(data1, data2): 121 | """Mean Squared Error. 122 | 123 | This method returns the Mean Squared Error (MSE) between two data sets. 
124 | 125 | Parameters 126 | ---------- 127 | data1 : numpy.ndarray 128 | First data set 129 | data2 : numpy.ndarray 130 | Second data set 131 | 132 | Returns 133 | ------- 134 | float 135 | Mean squared error 136 | 137 | Examples 138 | -------- 139 | >>> import numpy as np 140 | >>> from modopt.math.stats import mse 141 | >>> a = np.arange(9).reshape(3, 3) 142 | >>> mse(a, a + 2) 143 | 4.0 144 | 145 | """ 146 | return np.mean((data1 - data2) ** 2) 147 | 148 | 149 | def psnr(data1, data2, method="starck", max_pix=255): 150 | r"""Peak Signal-to-Noise Ratio. 151 | 152 | This method calculates the Peak Signal-to-Noise Ratio between two data 153 | sets. 154 | 155 | Parameters 156 | ---------- 157 | data1 : numpy.ndarray 158 | First data set 159 | data2 : numpy.ndarray 160 | Second data set 161 | method : {'starck', 'wiki'}, optional 162 | PSNR implementation (default is ``'starck'``) 163 | max_pix : int, optional 164 | Maximum possible pixel value, used only by the ``'wiki'`` method (default is ``255``) 165 | 166 | Returns 167 | ------- 168 | float 169 | PSNR value 170 | 171 | Raises 172 | ------ 173 | ValueError 174 | For invalid PSNR method 175 | 176 | Examples 177 | -------- 178 | >>> import numpy as np 179 | >>> from modopt.math.stats import psnr 180 | >>> a = np.arange(9).reshape(3, 3) 181 | >>> psnr(a, a + 2) 182 | 12.041199826559248 183 | 184 | >>> psnr(a, a + 2, method='wiki') 185 | 42.11020369539948 186 | 187 | Notes 188 | ----- 189 | ``'starck'``: 190 | 191 | Implements eq.3.7 from :cite:`starck2010` 192 | 193 | ``'wiki'``: 194 | 195 | Implements PSNR equation on 196 | https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 197 | 198 | ..
math:: 199 | 200 | \mathrm{PSNR} = 20\log_{10}(\mathrm{MAX}_I) - 201 | 10\log_{10}(\mathrm{MSE}) 202 | 203 | """ 204 | if method == "starck": 205 | return 20 * np.log10( 206 | (data1.shape[0] * np.abs(np.max(data1) - np.min(data1))) 207 | / np.linalg.norm(data1 - data2), 208 | ) 209 | 210 | elif method == "wiki": 211 | return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)) 212 | 213 | raise ValueError( 214 | 'Invalid PSNR method. Options are "starck" and "wiki"', 215 | ) 216 | 217 | 218 | def psnr_stack(data1, data2, metric=np.mean, method="starck"): 219 | """Peak Signal-to-Noise for stack of images. 220 | 221 | This method calculates the PSNRs for two stacks of 2D arrays. 222 | By default the method returns the mean value of the PSNRs, but any other 223 | metric can be used. 224 | 225 | Parameters 226 | ---------- 227 | data1 : numpy.ndarray 228 | Stack of images, 3D array 229 | data2 : numpy.ndarray 230 | Stack of recovered images, 3D array 231 | metric : function 232 | The desired metric to be applied to the PSNR values (default is 233 | ``numpy.mean``) 234 | method : {'starck', 'wiki'}, optional 235 | PSNR implementation (default is ``'starck'``) 236 | 237 | Returns 238 | ------- 239 | float 240 | Metric result of PSNR values 241 | 242 | Raises 243 | ------ 244 | ValueError 245 | For invalid input data dimensions 246 | 247 | Examples 248 | -------- 249 | >>> import numpy as np 250 | >>> from modopt.math.stats import psnr_stack 251 | >>> a = np.arange(18).reshape(2, 3, 3) 252 | >>> psnr_stack(a, a + 2) 253 | 12.041199826559248 254 | 255 | See Also 256 | -------- 257 | numpy.mean : default metric 258 | 259 | """ 260 | if data1.ndim != 3 or data2.ndim != 3: 261 | raise ValueError("Input data must be a 3D np.ndarray") 262 | 263 | return metric( 264 | [psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)] 265 | ) 266 | 267 | 268 | def sigma_mad(input_data): 269 | r"""MAD Standard Deviation.
270 | 271 | This method calculates the standard deviation of the input data from the 272 | MAD. 273 | 274 | Parameters 275 | ---------- 276 | input_data : numpy.ndarray 277 | Input data array 278 | 279 | Returns 280 | ------- 281 | float 282 | Sigma value 283 | 284 | Examples 285 | -------- 286 | >>> import numpy as np 287 | >>> from modopt.math.stats import sigma_mad 288 | >>> a = np.arange(9).reshape(3, 3) 289 | >>> sigma_mad(a) 290 | 2.9652 291 | 292 | Notes 293 | ----- 294 | This function can be used for estimating the standard deviation of the noise in 295 | images. 296 | 297 | Sigma is calculated as follows: 298 | 299 | .. math:: 300 | 301 | \sigma = 1.4826 \mathrm{MAD}(X) 302 | 303 | """ 304 | return 1.4826 * mad(input_data) 305 | -------------------------------------------------------------------------------- /src/modopt/opt/gradient.py: -------------------------------------------------------------------------------- 1 | """GRADIENT CLASSES. 2 | 3 | This module contains classes for defining algorithm gradients. 4 | Based on work by Yinghao Ge and Fred Ngole. 5 | 6 | :Author: Samuel Farrens 7 | 8 | """ 9 | 10 | import numpy as np 11 | 12 | from modopt.base.types import check_callable, check_float, check_npndarray 13 | 14 | 15 | class GradParent: 16 | """Gradient Parent Class. 17 | 18 | This class defines the basic methods that will be inherited by specific 19 | gradient classes.
20 | 21 | Parameters 22 | ---------- 23 | input_data : numpy.ndarray 24 | The observed data 25 | op : callable 26 | The operator 27 | trans_op : callable 28 | The transpose operator 29 | get_grad : callable, optional 30 | Method for calculating the gradient (default is ``None``) 31 | cost: callable, optional 32 | Method for calculating the cost (default is ``None``) 33 | data_type : type, optional 34 | Expected data type of the input data (default is ``None``) 35 | input_data_writeable: bool, optional 36 | Option to make the observed data writeable (default is ``False``) 37 | verbose : bool, optional 38 | Option for verbose output (default is ``True``) 39 | 40 | Examples 41 | -------- 42 | >>> import numpy as np 43 | >>> from modopt.opt.gradient import GradParent 44 | >>> y = np.arange(9).reshape(3, 3).astype(float) 45 | >>> g = GradParent(y, lambda x: x ** 2, lambda x: x ** 3) 46 | >>> g.op(y) 47 | array([[ 0., 1., 4.], 48 | [ 9., 16., 25.], 49 | [36., 49., 64.]]) 50 | >>> g.trans_op(y) 51 | array([[ 0., 1., 8.], 52 | [ 27., 64., 125.], 53 | [216., 343., 512.]]) 54 | >>> g.trans_op_op(y) 55 | array([[0.00000e+00, 1.00000e+00, 6.40000e+01], 56 | [7.29000e+02, 4.09600e+03, 1.56250e+04], 57 | [4.66560e+04, 1.17649e+05, 2.62144e+05]]) 58 | 59 | """ 60 | 61 | def __init__( 62 | self, 63 | input_data, 64 | op, 65 | trans_op, 66 | get_grad=None, 67 | cost=None, 68 | data_type=None, 69 | input_data_writeable=False, 70 | verbose=True, 71 | ): 72 | self.verbose = verbose 73 | self._input_data_writeable = input_data_writeable 74 | self._grad_data_type = data_type 75 | self.obs_data = input_data 76 | self.op = op 77 | self.trans_op = trans_op 78 | 79 | if not isinstance(get_grad, type(None)): 80 | self.get_grad = get_grad 81 | if not isinstance(cost, type(None)): 82 | self.cost = cost 83 | 84 | @property 85 | def obs_data(self): 86 | r"""Observed Data. 87 | 88 | The observed data :math:`\mathbf{y}`. 
89 | 90 | Returns 91 | ------- 92 | numpy.ndarray 93 | The observed data 94 | 95 | """ 96 | return self._obs_data 97 | 98 | @obs_data.setter 99 | def obs_data(self, input_data): 100 | if self._grad_data_type in {float, np.floating}: 101 | input_data = check_float(input_data) 102 | check_npndarray( 103 | input_data, 104 | dtype=self._grad_data_type, 105 | writeable=self._input_data_writeable, 106 | verbose=self.verbose, 107 | ) 108 | 109 | self._obs_data = input_data 110 | 111 | @property 112 | def op(self): 113 | r"""Operator. 114 | 115 | The operator :math:`\mathbf{H}`. 116 | 117 | Returns 118 | ------- 119 | callable 120 | The operator function 121 | 122 | """ 123 | return self._op 124 | 125 | @op.setter 126 | def op(self, operator): 127 | self._op = check_callable(operator) 128 | 129 | @property 130 | def trans_op(self): 131 | r"""Transpose operator. 132 | 133 | The transpose operator :math:`\mathbf{H}^T`. 134 | 135 | Returns 136 | ------- 137 | callable 138 | The transpose operator function 139 | 140 | """ 141 | return self._trans_op 142 | 143 | @trans_op.setter 144 | def trans_op(self, operator): 145 | self._trans_op = check_callable(operator) 146 | 147 | @property 148 | def get_grad(self): 149 | """Get gradient value.""" 150 | return self._get_grad 151 | 152 | @get_grad.setter 153 | def get_grad(self, method): 154 | self._get_grad = check_callable(method) 155 | 156 | @property 157 | def grad(self): 158 | """Gradient value.""" 159 | return self._grad 160 | 161 | @grad.setter 162 | def grad(self, input_value): 163 | if self._grad_data_type in {float, np.floating}: 164 | input_value = check_float(input_value) 165 | self._grad = input_value 166 | 167 | @property 168 | def cost(self): 169 | """Cost contribution.""" 170 | return self._cost 171 | 172 | @cost.setter 173 | def cost(self, method): 174 | self._cost = check_callable(method) 175 | 176 | def trans_op_op(self, input_data): 177 | r"""Transpose Operation of the Operator. 
178 | 179 | This method calculates the action of the transpose operator on 180 | the action of the operator on the data. 181 | 182 | Parameters 183 | ---------- 184 | input_data : numpy.ndarray 185 | Input data array 186 | 187 | Returns 188 | ------- 189 | numpy.ndarray 190 | Result 191 | 192 | Notes 193 | ----- 194 | Implements the following equation: 195 | 196 | .. math:: 197 | \mathbf{H}^T(\mathbf{H}\mathbf{x}) 198 | 199 | where :math:`\mathbf{x}` is the ``input_data``. 200 | 201 | """ 202 | return self.trans_op(self.op(input_data)) 203 | 204 | 205 | class GradBasic(GradParent): 206 | """Basic Gradient Class. 207 | 208 | This class defines the gradient calculation and costs methods for 209 | common inverse problems. 210 | 211 | Parameters 212 | ---------- 213 | *args : tuple 214 | Positional arguments 215 | **kwargs : dict 216 | Keyword arguments 217 | 218 | Examples 219 | -------- 220 | >>> import numpy as np 221 | >>> from modopt.opt.gradient import GradBasic 222 | >>> y = np.arange(9).reshape(3, 3).astype(float) 223 | >>> g = GradBasic(y, lambda x: x ** 2, lambda x: x ** 3) 224 | >>> g.get_grad(y) 225 | >>> g.grad 226 | array([[0.00000e+00, 0.00000e+00, 8.00000e+00], 227 | [2.16000e+02, 1.72800e+03, 8.00000e+03], 228 | [2.70000e+04, 7.40880e+04, 1.75616e+05]]) 229 | 230 | See Also 231 | -------- 232 | GradParent : parent class 233 | 234 | """ 235 | 236 | def __init__(self, *args, **kwargs): 237 | super().__init__(*args, **kwargs) 238 | self.get_grad = self._get_grad_method 239 | self.cost = self._cost_method 240 | 241 | def _get_grad_method(self, input_data): 242 | r"""Get the gradient. 243 | 244 | This method calculates the gradient step from the input data. 245 | 246 | Parameters 247 | ---------- 248 | input_data : numpy.ndarray 249 | Input data array 250 | 251 | Notes 252 | ----- 253 | Implements the following equation: 254 | 255 | .. 
"""Base classes for linear operators."""

import numpy as np

from modopt.base.types import check_callable
from modopt.base.backend import get_array_module


class LinearParent:
    """Linear Operator Parent Class.

    Container pairing a forward linear operation with its adjoint. Both
    callables are validated with ``check_callable`` when assigned.

    Parameters
    ----------
    op : callable
        Callable function that implements the linear operation
    adj_op : callable
        Callable function that implements the linear adjoint operation

    Examples
    --------
    >>> from modopt.opt.linear import LinearParent
    >>> a = LinearParent(lambda x: x * 2, lambda x: x ** 3)
    >>> a.op(2)
    4
    >>> a.adj_op(2)
    8

    """

    def __init__(self, op, adj_op):
        self.op = op
        self.adj_op = adj_op

    @property
    def op(self):
        """Linear Operator."""
        return self._op

    @op.setter
    def op(self, operator):
        # Validation is delegated to the project helper.
        self._op = check_callable(operator)

    @property
    def adj_op(self):
        """Linear Adjoint Operator."""
        return self._adj_op

    @adj_op.setter
    def adj_op(self, operator):
        self._adj_op = check_callable(operator)


class Identity(LinearParent):
    """Identity Operator Class.

    Placeholder operator that returns its input unchanged; its ``cost``
    always evaluates to ``0``.

    See Also
    --------
    LinearParent : parent class

    """

    def __init__(self):
        def _passthrough(input_data):
            return input_data

        def _zero_cost(*args, **kwargs):
            return 0

        self.op = _passthrough
        # The adjoint reuses whatever the ``op`` property hands back.
        self.adj_op = self.op
        self.cost = _zero_cost


class MatrixOperator(LinearParent):
    """Matrix Operator class.

    Wraps an array so that ``op`` is a left matrix product and ``adj_op``
    is the product with the (conjugate) transpose.

    Parameters
    ----------
    array : array-like
        Matrix defining the operator; must support ``@`` and ``.T``

    """

    def __init__(self, array):
        def _forward(x):
            return array @ x

        def _adjoint_complex(x):
            return array.T.conjugate() @ x

        def _adjoint_real(x):
            return array.T @ x

        self.op = _forward
        xp = get_array_module(array)

        # Only conjugate when the array actually holds complex values.
        has_complex_entries = xp.any(xp.iscomplex(array))
        self.adj_op = _adjoint_complex if has_complex_entries else _adjoint_real


class LinearCombo(LinearParent):
    """Linear Combination Class.

    This class defines a combination of linear transform operators.

    Parameters
    ----------
    operators : list, tuple or numpy.ndarray
        List of linear operator class instances
    weights : list, tuple or numpy.ndarray, optional
        List of weights for combining the linear adjoint operator results

    Examples
    --------
    >>> from modopt.opt.linear import LinearCombo, LinearParent
    >>> a = LinearParent(lambda x: x * 2, lambda x: x ** 3)
    >>> b = LinearParent(lambda x: x * 4, lambda x: x ** 5)
    >>> c = LinearCombo([a, b])
    >>> a.op(2)
    4
    >>> b.op(2)
    8
    >>> c.op(2)
    array([4, 8], dtype=object)
    >>> a.adj_op(2)
    8
    >>> b.adj_op(2)
    32
    >>> c.adj_op([2, 2])
    20.0

    See Also
    --------
    LinearParent : parent class
    """

    def __init__(self, operators, weights=None):
        checked_operators, checked_weights = self._check_inputs(operators, weights)
        self.operators = checked_operators
        self.weights = checked_weights
        self.op = self._op_method
        self.adj_op = self._adj_op_method

    def _check_type(self, input_val):
        """Check input type.

        Accept a list, tuple or numpy array and return it converted to a
        numpy array.

        Parameters
        ----------
        input_val : any
            Any input object

        Returns
        -------
        numpy.ndarray
            Numpy array of inputs

        Raises
        ------
        TypeError
            For invalid input type
        ValueError
            If input list is empty

        """
        accepted_types = (list, tuple, np.ndarray)

        if not isinstance(input_val, accepted_types):
            raise TypeError(
                "Invalid input type, input must be a list, tuple or numpy array.",
            )

        as_array = np.array(input_val)

        if as_array.size == 0:
            raise ValueError("Input list is empty.")

        return as_array

    def _check_inputs(self, operators, weights):
        """Check inputs.

        This method checks that the input operators and weights are
        correctly formatted.

        Parameters
        ----------
        operators : list, tuple or numpy.ndarray
            List of linear operator class instances
        weights : list, tuple or numpy.ndarray
            List of weights for combining the linear adjoint operator results

        Returns
        -------
        tuple
            Operators and weights

        Raises
        ------
        ValueError
            If an operator lacks ``op`` or ``adj_op``, or if the number of
            weights does not match the number of operators
        TypeError
            If the individual weight values are not floats

        """
        operators = self._check_type(operators)

        for operator in operators:
            if not hasattr(operator, "op"):
                raise ValueError('Operators must contain "op" method.')

            if not hasattr(operator, "adj_op"):
                raise ValueError('Operators must contain "adj_op" method.')

            # Re-assign through the validator so non-callables are rejected.
            operator.op = check_callable(operator.op)
            operator.adj_op = check_callable(operator.adj_op)

        if weights is None:
            return operators, weights

        weights = self._check_type(weights)

        if weights.size != operators.size:
            raise ValueError(
                "The number of weights must match the number of operators.",
            )

        if not np.issubdtype(weights.dtype, np.floating):
            raise TypeError("The weights must be a list of float values.")

        return operators, weights

    def _op_method(self, input_data):
        """Operator.

        Apply every operator to the same input and collect the results.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array

        Returns
        -------
        numpy.ndarray
            Object array with one result per operator

        """
        results = np.empty(len(self.operators), dtype=np.ndarray)

        for position, operator in enumerate(self.operators):
            results[position] = operator.op(input_data)

        return results

    def _adj_op_method(self, input_data):
        """Adjoint operator.

        Combine the results of all adjoint operators. If weights are
        provided the combination is the sum of the weighted results,
        otherwise the combination is the mean.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array, one element per operator

        Returns
        -------
        numpy.ndarray
            Adjoint operation results

        """
        if self.weights is None:
            unweighted = [
                operator.adj_op(elem)
                for elem, operator in zip(input_data, self.operators)
            ]
            return np.mean(unweighted, axis=0)

        weighted = [
            weight * operator.adj_op(elem)
            for elem, operator, weight in zip(
                input_data,
                self.operators,
                self.weights,
            )
        ]
        return np.sum(weighted, axis=0)
and Boyd, Stephen P.}, 13 | title = {Enhancing Sparsity by Reweighted L1 Minimization}, 14 | journal = {arXiv e-prints}, 15 | keywords = {Statistics - Methodology, Mathematics - Statistics, 49N30, 49N45, 94A12}, 16 | year = "2007", 17 | month = "Nov", 18 | eid = {arXiv:0711.1612}, 19 | pages = {arXiv:0711.1612}, 20 | archivePrefix = {arXiv}, 21 | eprint = {0711.1612}, 22 | primaryClass = {stat.ME}, 23 | adsurl = {https://ui.adsabs.harvard.edu/abs/2007arXiv0711.1612C}, 24 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 25 | } 26 | 27 | @article{chambolle2015, 28 | TITLE = {{On the convergence of the iterates of ``FISTA''}}, 29 | AUTHOR = {Chambolle, Antonin and Dossal, Charles}, 30 | URL = {https://hal.inria.fr/hal-01060130}, 31 | JOURNAL = {{Journal of Optimization Theory and Applications}}, 32 | PUBLISHER = {{Springer Verlag}}, 33 | VOLUME = {Volume 166}, 34 | NUMBER = { Issue 3}, 35 | PAGES = {25}, 36 | YEAR = {2015}, 37 | MONTH = Aug, 38 | KEYWORDS = {proximal ; FISTA ; optimization ; convergence ; forward backward}, 39 | PDF = {https://hal.inria.fr/hal-01060130/file/Fista10.pdf}, 40 | HAL_ID = {hal-01060130}, 41 | HAL_VERSION = {v3}, 42 | } 43 | 44 | @article{combettes2005, 45 | author = {Combettes, Patrick L. 
and Wajs, Valérie R.}, 46 | title = {Signal Recovery by Proximal Forward-Backward Splitting}, 47 | journal = {Multiscale Modeling \& Simulation}, 48 | volume = {4}, 49 | number = {4}, 50 | pages = {1168-1200}, 51 | year = {2005}, 52 | doi = {10.1137/050626090}, 53 | URL = {https://doi.org/10.1137/050626090}, 54 | eprint = {https://doi.org/10.1137/050626090} 55 | } 56 | 57 | @article{condat2013, 58 | author = {Condat, Laurent}, 59 | year = {2013}, 60 | month = {08}, 61 | pages = {}, 62 | title = {A Primal–Dual Splitting Method for Convex Optimization Involving Lipschitzian, Proximable and Linear Composite Terms}, 63 | volume = {158}, 64 | journal = {Journal of Optimization Theory and Applications}, 65 | doi = {10.1007/s10957-012-0245-9} 66 | } 67 | 68 | @ARTICLE{defazio2014, 69 | author = {{Defazio}, Aaron and {Bach}, Francis and {Lacoste-Julien}, Simon}, 70 | title = "{SAGA: A Fast Incremental Gradient Method With Support for Non-Strongly Convex Composite Objectives}", 71 | journal = {arXiv e-prints}, 72 | keywords = {Computer Science - Machine Learning, Mathematics - Optimization and Control, Statistics - Machine Learning}, 73 | year = 2014, 74 | month = jul, 75 | eid = {arXiv:1407.0202}, 76 | pages = {arXiv:1407.0202}, 77 | archivePrefix = {arXiv}, 78 | eprint = {1407.0202}, 79 | primaryClass = {cs.LG}, 80 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1407.0202D}, 81 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 82 | } 83 | 84 | @ARTICLE{figueiredo2014, 85 | author = {Figueiredo, Mario A.~T. 
and Nowak, Robert D.}, 86 | title = {Sparse Estimation with Strongly Correlated Variables using Ordered Weighted L1 Regularization}, 87 | journal = {arXiv e-prints}, 88 | keywords = {Statistics - Machine Learning}, 89 | year = "2014", 90 | month = "Sep", 91 | eid = {arXiv:1409.4005}, 92 | pages = {arXiv:1409.4005}, 93 | archivePrefix = {arXiv}, 94 | eprint = {1409.4005}, 95 | primaryClass = {stat.ML}, 96 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1409.4005F}, 97 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 98 | } 99 | 100 | @ARTICLE{kim2017, 101 | author = {Kim, Donghwan and Fessler, Jeffrey A.}, 102 | title = {Adaptive Restart of the Optimized Gradient Method for Convex Optimization}, 103 | journal = {arXiv e-prints}, 104 | keywords = {Mathematics - Optimization and Control}, 105 | year = "2017", 106 | month = "Mar", 107 | eid = {arXiv:1703.04641}, 108 | pages = {arXiv:1703.04641}, 109 | archivePrefix = {arXiv}, 110 | eprint = {1703.04641}, 111 | primaryClass = {math.OC}, 112 | adsurl = {https://ui.adsabs.harvard.edu/abs/2017arXiv170304641K}, 113 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 114 | } 115 | 116 | @ARTICLE{liang2018, 117 | author = {Liang, Jingwei and Luo, Tao and Schonlieb, Carola-Bibiane}, 118 | title = {Improving Fast Iterative Shrinkage-Thresholding Algorithm: Faster, Smarter and Greedier}, 119 | journal = {arXiv e-prints}, 120 | keywords = {Mathematics - Optimization and Control}, 121 | year = "2018", 122 | month = "Nov", 123 | eid = {arXiv:1811.01430}, 124 | pages = {arXiv:1811.01430}, 125 | archivePrefix = {arXiv}, 126 | eprint = {1811.01430}, 127 | primaryClass = {math.OC}, 128 | adsurl = {https://ui.adsabs.harvard.edu/abs/2018arXiv181101430L}, 129 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 130 | } 131 | 132 | @ARTICLE{mcdonald2014, 133 | author = {McDonald, Andrew M. 
and Pontil, Massimiliano and Stamos, Dimitris}, 134 | title = {New Perspectives on k-Support and Cluster Norms}, 135 | journal = {arXiv e-prints}, 136 | keywords = {Statistics - Machine Learning}, 137 | year = "2014", 138 | month = "Mar", 139 | eid = {arXiv:1403.1481}, 140 | pages = {arXiv:1403.1481}, 141 | archivePrefix = {arXiv}, 142 | eprint = {1403.1481}, 143 | primaryClass = {stat.ML}, 144 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1403.1481M}, 145 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 146 | } 147 | 148 | @ARTICLE{raguet2011, 149 | author = {Raguet, Hugo and Fadili, Jalal and Peyr{\'e}, Gabriel}, 150 | title = {Generalized Forward-Backward Splitting}, 151 | journal = {arXiv e-prints}, 152 | keywords = {Mathematics - Optimization and Control, 65K05}, 153 | year = "2011", 154 | month = "Aug", 155 | eid = {arXiv:1108.4404}, 156 | pages = {arXiv:1108.4404}, 157 | archivePrefix = {arXiv}, 158 | eprint = {1108.4404}, 159 | primaryClass = {math.OC}, 160 | adsurl = {https://ui.adsabs.harvard.edu/abs/2011arXiv1108.4404R}, 161 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 162 | } 163 | 164 | @ARTICLE{ruder2017, 165 | author = {{Ruder}, Sebastian}, 166 | title = "{An overview of gradient descent optimization algorithms}", 167 | journal = {arXiv e-prints}, 168 | keywords = {Computer Science - Machine Learning}, 169 | year = 2016, 170 | month = sep, 171 | eid = {arXiv:1609.04747}, 172 | pages = {arXiv:1609.04747}, 173 | archivePrefix = {arXiv}, 174 | eprint = {1609.04747}, 175 | primaryClass = {cs.LG}, 176 | adsurl = {https://ui.adsabs.harvard.edu/abs/2016arXiv160904747R}, 177 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 178 | } 179 | 180 | @book{starck2010, 181 | place={Cambridge}, 182 | title={Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity}, 183 | DOI={10.1017/CBO9780511730344}, 184 | publisher={Cambridge University Press}, 185 | author={Starck, Jean-Luc and 
"""Observable.

This module contains observable classes

:Author: Benoir Sarthou

"""

import time

import numpy as np


class SignalObject:
    """Dummy class for signals.

    Plain attribute holder; ``Observable.notify_observers`` fills it with
    the emitting object, the signal name and the notification keywords.
    """

    pass


class Observable:
    """Base class for observable classes.

    This class defines a simple interface to add or remove observers
    on an object.

    Parameters
    ----------
    signals : list
        The allowed signals

    """

    def __init__(self, signals):
        # Define class parameters
        self._allowed_signals = []
        self._observers = {}

        # Set allowed signals
        for signal in signals:
            self._allowed_signals.append(signal)
            self._observers[signal] = []

        # Set a lock option to avoid multiple observer notifications
        self._locked = False

    def add_observer(self, signal, observer):
        """Add an observer to the object.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            A function that will be called when the signal is emitted

        Raises
        ------
        ValueError
            If the signal is not allowed

        """
        self._is_allowed_signal(signal)
        self._add_observer(signal, observer)

    def remove_observer(self, signal, observer):
        """Remove an observer from the object.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function to be removed

        Raises
        ------
        ValueError
            If the signal is not allowed

        """
        # The validator is ``_is_allowed_signal``; the previously used name
        # ``_is_allowed_event`` does not exist on this class.
        self._is_allowed_signal(signal)
        self._remove_observer(signal, observer)

    def notify_observers(self, signal, **kwargs):
        """Notify observers of a given signal.

        Parameters
        ----------
        signal : str
            A valid signal
        **kwargs : dict
            The parameters that will be sent to the observers

        Returns
        -------
        bool
            ``False`` if a notification is in progress, otherwise ``True``

        """
        # Check if a notification is in progress
        if self._locked:
            return False

        # Set the lock
        self._locked = True

        try:
            # Create a signal object
            signal_to_be_notified = SignalObject()
            signal_to_be_notified.object = self
            signal_to_be_notified.signal = signal

            for name, key_value in kwargs.items():
                setattr(signal_to_be_notified, name, key_value)

            # Notify all the observers
            for observer in self._observers[signal]:
                observer(signal_to_be_notified)
        finally:
            # Always unlock, otherwise one failing observer (or an unknown
            # signal) would silence every later notification.
            self._locked = False

        return True

    def _get_allowed_signals(self):
        """Get allowed signals.

        Signals allowed for the current object.

        Returns
        -------
        list
            List of allowed signals

        """
        return self._allowed_signals

    allowed_signals = property(_get_allowed_signals)

    def _is_allowed_signal(self, signal):
        """Check if a signal is valid.

        Parameters
        ----------
        signal : str
            A signal

        Raises
        ------
        ValueError
            For invalid signal

        """
        if signal not in self._allowed_signals:
            message = 'Signal "{0}" is not allowed for "{1}"'
            raise ValueError(message.format(signal, type(self)))

    def _add_observer(self, signal, observer):
        """Associate an observer to a valid signal.

        An observer already registered for the signal is not added twice.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function

        """
        if observer not in self._observers[signal]:
            self._observers[signal].append(observer)

    def _remove_observer(self, signal, observer):
        """Remove an observer from a valid signal.

        Does nothing if the observer is not registered for the signal.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function to be removed

        """
        if observer in self._observers[signal]:
            self._observers[signal].remove(observer)


class MetricObserver:
    """Metric observer.

    Wrapper of the metric to the observer object notified by the Observable
    class.

    Parameters
    ----------
    name : str
        The name of the metric
    metric : callable
        Metric function, called with keyword arguments built from
        ``mapping`` and ``cst_kwargs``
    mapping : dict
        Define the mapping between the iterate variable and the metric
        keyword: ``{'x_new':'name_var_1', 'y_new':'name_var_2'}``. To cancel
        the need of a variable, the dict value should be None:
        ``'y_new': None``.
    cst_kwargs : dict
        Keyword arguments of constant argument for the metric computation
    early_stopping : bool
        If True it will compute the convergence flag (default is ``False``)
    wind : int
        Window on which the convergence criterion is computed (default is
        ``6``)
    eps : float
        The level of criteria of convergence (default is ``1.0e-3``)

    """

    def __init__(
        self,
        name,
        metric,
        mapping,
        cst_kwargs,
        early_stopping=False,
        wind=6,
        eps=1.0e-3,
    ):
        self.name = name
        self.metric = metric
        self.mapping = mapping
        self.cst_kwargs = cst_kwargs
        self.list_cv_values = []
        self.list_iters = []
        self.list_dates = []
        self.eps = eps
        self.wind = wind
        self.converge_flag = False
        self.early_stopping = early_stopping

    def __call__(self, signal):
        """Call Method.

        Translate an observer notification into a metric evaluation and
        record the iteration index, the wall-clock time and the value.

        Parameters
        ----------
        signal : SignalObject
            Notification object; must carry ``idx`` and every attribute
            named by a non-``None`` entry of ``mapping``

        """
        kwargs = {}
        for key, key_value in self.mapping.items():
            if key_value is not None:
                kwargs[key_value] = getattr(signal, key)
        # Constant keywords are applied last, so they win on a name clash.
        kwargs.update(self.cst_kwargs)
        self.list_iters.append(signal.idx)
        self.list_dates.append(time.time())
        self.list_cv_values.append(self.metric(**kwargs))

        if self.early_stopping:
            self.is_converge()

    def is_converge(self):
        """Check convergence.

        Update ``converge_flag``: it becomes true when the relative change
        between the mean of the older half and the mean of the newer half
        of the last ``wind`` metric values is below ``eps``. Nothing is
        returned, and the flag is left untouched while fewer than ``wind``
        values have been recorded.

        """
        if len(self.list_cv_values) < self.wind:
            return
        start_idx = -self.wind
        mid_idx = -(self.wind // 2)
        old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
        current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
        normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
        self.converge_flag = normalize_residual_metrics < self.eps

    def retrieve_metrics(self):
        """Retrieve metrics.

        Return the convergence metrics saved with the corresponding
        iterations.

        Returns
        -------
        dict
            ``"time"`` (seconds elapsed since the first recorded call),
            ``"index"`` (iteration indices) and ``"values"`` (metric values)

        """
        time_val = np.array(self.list_dates)

        if time_val.size >= 1:
            # Express the timestamps relative to the first one.
            time_val -= time_val[0]

        return {
            "time": time_val,
            "index": self.list_iters,
            "values": self.list_cv_values,
        }
[Coverage](#coverage) 21 | g. [Style Guide](#style-guide) 22 | 23 | ## Introduction 24 | 25 | ModOpt is fully open-source and as such users are welcome to fork, clone and/or reuse the software freely. Users wishing to contribute to the development of this package, however, are kindly requested to adhere to the following guidelines and the [code of conduct](./CODE_OF_CONDUCT.md). 26 | 27 | ## Issues 28 | 29 | The easiest way to contribute to ModOpt is by raising a "New issue". This will give you the opportunity to ask questions, report bugs or even request new features. 30 | 31 | Remember to use clear and descriptive titles for issues. This will help other users that encounter similar problems find quick solutions. We also ask that you read the available documentation and browse existing issues on similar topics before raising a new issue in order to avoid repetition. 32 | 33 | ### Asking Questions 34 | 35 | Users are of course welcome to ask any question relating to ModOpt and we will endeavour to reply as soon as possible. 36 | 37 | These issues should include the `help wanted` label. 38 | 39 | ### Installation Issues 40 | 41 | If you encounter difficulties installing ModOpt be sure to re-read the installation instructions provided. If you are still unable to install the package please remember to include the following details in the issue you raise: 42 | 43 | * your operating system and the corresponding version (*e.g.* macOS v10.14.1, Ubuntu v16.04.1, *etc.*), 44 | * the version of Python you are using (*e.g* v3.6.7, *etc.*), 45 | * the python environment you are using (if any) and the corresponding version (*e.g.* virtualenv v16.1.0, conda v4.5.11, *etc.*), 46 | * the exact steps followed while attempting to install ModOpt 47 | * and the error message printed or a screen capture of the terminal output. 48 | 49 | These issues should include the `installation` label. 
50 | 51 | ### Reporting Bugs 52 | 53 | If you discover a bug while using ModOpt please provide the same information requested for installation issues. Be sure to list the exact steps you followed that led to the bug you encountered so that we can attempt to recreate the conditions. 54 | 55 | If you are aware of the source of the bug we would very much appreciate if you could provide the module(s) and line number(s) affected. This will enable us to more rapidly fix the problem. 56 | 57 | These issues should include the `bug` label. 58 | 59 | ### Requesting Features 60 | 61 | If you believe ModOpt could be improved with the addition of extra functionality or features feel free to let us know. We cannot guarantee that we will include these features, but we will certainly take your suggestions into consideration. 62 | 63 | In order to increase your chances of having a feature included, be sure to be as clear and specific as possible as to the properties this feature should have. 64 | 65 | These issues should include the `enhancement` label. 66 | 67 | ## Pull Requests 68 | 69 | If you would like to take a more active role in the development of ModOpt you can do so by submitting a "Pull request". A Pull Request (PR) is a way by which a user can submit modifications or additions to the ModOpt package directly. PRs need to be reviewed by the package moderators and if accepted are merged into the master branch of the repository. 70 | 71 | Before making a PR, be sure to carefully read the following guidelines. 72 | 73 | ### Before Making a PR 74 | 75 | The following steps should be followed before making a pull request: 76 | 77 | 1. Log into your GitHub account or create an account if you do not already have one. 78 | 79 | 1. Go to the main ModOpt repository page: [https://github.com/CEA-COSMIC/ModOpt](https://github.com/CEA-COSMIC/ModOpt) 80 | 81 | 1. Fork the repository, *i.e.* press the button on the top right with this symbol . 
This will create an independent copy of the repository on your account. 82 | 83 | 1. Clone your fork of ModOpt. 84 | 85 | ```bash 86 | git clone https://github.com/YOUR_USERNAME/ModOpt 87 | ``` 88 | 89 | 5. Add the original repository (*upstream*) to remote. 90 | 91 | ```bash 92 | git remote add upstream https://github.com/CEA-COSMIC/ModOpt 93 | ``` 94 | 95 | ### Making a PR 96 | 97 | The following steps should be followed to make a pull request: 98 | 99 | 1. Pull the latest updates to the original repository. 100 | 101 | ```bash 102 | git pull upstream master 103 | ``` 104 | 105 | 2. Create a new branch for your modifications. 106 | 107 | ```bash 108 | git checkout -b BRANCH_NAME 109 | ``` 110 | 111 | 3. Make the desired modifications to the relevant modules. 112 | 113 | 4. Add the modified files to the staging area. 114 | 115 | ```bash 116 | git add . 117 | ``` 118 | 119 | 5. Make sure all of the appropriate files have been staged. Note that all files listed in green will be included in the following commit. 120 | 121 | ```bash 122 | git status 123 | ``` 124 | 125 | 6. Commit the changes with an appropriate description. 126 | 127 | ```bash 128 | git commit -m "Description of commit" 129 | ``` 130 | 131 | 7. Push the commits to a branch on your fork of ModOpt. 132 | 133 | ```bash 134 | git push origin BRANCH_NAME 135 | ``` 136 | 137 | 8. Make a pull request for your branch with a clear description of what has been done, why and what issues this relates to. 138 | 139 | 9. Wait for feedback and repeat steps 3 through 7 if necessary. 140 | 141 | ### After Making a PR 142 | 143 | If your PR is accepted and merged it is recommended that the following steps be followed to keep your fork up to date. 144 | 145 | 1. Make sure you switch back to your local master branch. 146 | 147 | ```bash 148 | git checkout master 149 | ``` 150 | 151 | 2. Delete the local branch you used for the PR. 152 | 153 | ```bash 154 | git branch -d BRANCH_NAME 155 | ``` 156 | 157 | 3. 
Pull the latest updates to the original repository, which include your PR changes. 158 | 159 | ```bash 160 | git pull upstream master 161 | ``` 162 | 163 | 4. Push the commits to your fork. 164 | 165 | ```bash 166 | git push origin master 167 | ``` 168 | 169 | ### Content 170 | 171 | Every PR should correspond to a bug fix or new feature issue that has already been raised. When you make a PR be sure to tag the issue that it resolves (*e.g.* this PR relates to issue #1). This way the issue can be closed once the PR has been merged. 172 | 173 | The content of a given PR should be as concise as possible. To that end, aim to restrict modifications to those needed to resolve a single issue. Additional bug fixes or features should be made as separate PRs. 174 | 175 | ### CI Tests 176 | 177 | Continuous Integration (CI) tests are implemented via [Travis CI](https://travis-ci.org/). All PRs must pass the CI tests before being merged. Your PR may not be reviewed by a moderator until all CI tests have passed. Therefore, try to resolve any issues in your PR that may cause the tests to fail. 178 | 179 | In some cases it may be necessary to modify the unit tests, but this should be clearly justified in the PR description. 180 | 181 | ### Coverage 182 | 183 | Coverage tests are implemented via [Coveralls](https://coveralls.io/). These tests will fail if the coverage, *i.e.* the number of lines of code covered by unit tests, decreases. When submitting new code in a PR, contributors should aim to write appropriate unit tests. If the coverage drops significantly moderators may request unit tests be added before the PR is merged. 184 | 185 | ### Style Guide 186 | 187 | All contributions should adhere to the following style guides currently implemented in ModOpt: 188 | 189 | 1. All code should be compatible with the Python versions listed in `README.md`. 190 | 191 | 1. All code should adhere to [PEP8](https://www.python.org/dev/peps/pep-0008/) standards. 192 | 193 | 1. 
"""DATA TRANSFORM ROUTINES.

This module contains methods for transforming data.

:Author: Samuel Farrens

"""

import numpy as np


def cube2map(data_cube, layout):
    """Cube to Map.

    This method transforms the input data from a 3D cube to a 2D map with a
    specified layout.

    Parameters
    ----------
    data_cube : numpy.ndarray
        Input data cube, 3D array of 2D images
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        2D map

    Raises
    ------
    ValueError
        For invalid data dimensions
    ValueError
        For invalid layout

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import cube2map
    >>> a = np.arange(16).reshape((4, 2, 2))
    >>> cube2map(a, (2, 2))
    array([[ 0,  1,  4,  5],
           [ 2,  3,  6,  7],
           [ 8,  9, 12, 13],
           [10, 11, 14, 15]])

    See Also
    --------
    map2cube : complementary function

    """
    if data_cube.ndim != 3:
        raise ValueError("The input data must have 3 dimensions.")

    if data_cube.shape[0] != np.prod(layout):
        raise ValueError(
            "The desired layout must match the number of input data layers.",
        )

    # Each map row is ``layout[1]`` consecutive images placed side by side
    res = [
        np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
        for elem in range(layout[0])
    ]

    return np.vstack(res)


def map2cube(data_map, layout):
    """Map to cube.

    This method transforms the input data from a 2D map with given layout to
    a 3D cube.

    Parameters
    ----------
    data_map : numpy.ndarray
        Input data map, 2D array
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        3D cube

    Raises
    ------
    ValueError
        If either map dimension is not a multiple of the corresponding
        layout dimension

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import map2cube
    >>> a = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13],
    ... [10, 11, 14, 15]])
    >>> map2cube(a, (2, 2))
    array([[[ 0,  1],
            [ 2,  3]],
    <BLANKLINE>
           [[ 4,  5],
            [ 6,  7]],
    <BLANKLINE>
           [[ 8,  9],
            [10, 11]],
    <BLANKLINE>
           [[12, 13],
            [14, 15]]])

    See Also
    --------
    cube2map : complementary function

    """
    # ``np.any`` rather than ``np.all``: one non-divisible axis is already
    # enough to make the layout invalid.
    if np.any(np.array(data_map.shape) % np.array(layout)):
        raise ValueError(
            "The dimensions of the data map must be a multiple of the "
            "desired layout.",
        )

    d_shape = np.array(data_map.shape) // np.array(layout)

    return np.array(
        [
            data_map[
                (
                    slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
                    slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
                )
            ]
            for i_elem in range(layout[0])
            for j_elem in range(layout[1])
        ]
    )


def map2matrix(data_map, layout):
    """Map to Matrix.

    This method transforms a 2D map to a 2D matrix. Each image of the map
    becomes one column of the matrix.

    Parameters
    ----------
    data_map : numpy.ndarray
        Input data map, 2D array
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        2D matrix

    Notes
    -----
    The image size is taken from the first axis only and each image is
    flattened to ``image_shape ** 2`` values, i.e. the images are treated as
    square.

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import map2matrix
    >>> a = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13],
    ... [10, 11, 14, 15]])
    >>> map2matrix(a, (2, 2))
    array([[ 0,  4,  8, 12],
           [ 1,  5,  9, 13],
           [ 2,  6, 10, 14],
           [ 3,  7, 11, 15]])

    See Also
    --------
    matrix2map : complementary function

    """
    layout = np.array(layout)

    # Get the (square) side length of the images
    image_shape = (np.array(data_map.shape) // layout)[0]

    # Stack objects from map
    data_matrix = []

    for i_elem in range(np.prod(layout)):
        # Position of the image in the layout grid
        row, col = divmod(i_elem, layout[1])
        image = data_map[
            image_shape * row : image_shape * (row + 1),
            image_shape * col : image_shape * (col + 1),
        ]
        data_matrix.append(image.reshape(image_shape**2))

    return np.array(data_matrix).T


def matrix2map(data_matrix, map_shape):
    """Matrix to Map.

    This method transforms a 2D matrix to a 2D map. Each column of the matrix
    is reshaped to a square image and placed in the map.

    Parameters
    ----------
    data_matrix : numpy.ndarray
        Input data matrix, 2D array
    map_shape : tuple
        2D shape of the output map

    Returns
    -------
    numpy.ndarray
        2D map, with the same data type as ``data_matrix``

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import matrix2map
    >>> a = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14],
    ... [3, 7, 11, 15]])
    >>> matrix2map(a, (4, 4))
    array([[ 0,  1,  4,  5],
           [ 2,  3,  6,  7],
           [ 8,  9, 12, 13],
           [10, 11, 14, 15]])

    See Also
    --------
    map2matrix : complementary function

    """
    map_shape = np.array(map_shape)

    # Get the shape and layout of the images
    image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
    layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")

    # Allocate the map with the input data type so that floating point (or
    # complex) values are not truncated to integers on output.
    data_map = np.zeros(map_shape, dtype=data_matrix.dtype)

    temp = data_matrix.reshape(image_shape, image_shape, data_matrix.shape[1])

    for i_elem in range(data_matrix.shape[1]):
        # Position of the image in the layout grid
        row, col = divmod(i_elem, layout[1])
        data_map[
            image_shape * row : image_shape * (row + 1),
            image_shape * col : image_shape * (col + 1),
        ] = temp[:, :, i_elem]

    return data_map


def cube2matrix(data_cube):
    """Cube to Matrix.

    This method transforms a 3D cube to a 2D matrix. Each image of the cube
    is flattened into one column of the matrix.

    Parameters
    ----------
    data_cube : numpy.ndarray
        Input data cube, 3D array

    Returns
    -------
    numpy.ndarray
        2D matrix

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import cube2matrix
    >>> a = np.arange(16).reshape((4, 2, 2))
    >>> cube2matrix(a)
    array([[ 0,  4,  8, 12],
           [ 1,  5,  9, 13],
           [ 2,  6, 10, 14],
           [ 3,  7, 11, 15]])

    See Also
    --------
    matrix2cube : complementary function

    """
    return data_cube.reshape(
        [data_cube.shape[0], np.prod(data_cube.shape[1:])],
    ).T


def matrix2cube(data_matrix, im_shape):
    """Matrix to Cube.

    This method transforms a 2D matrix to a 3D cube. Each column of the
    matrix is reshaped to an image of shape ``im_shape``.

    Parameters
    ----------
    data_matrix : numpy.ndarray
        Input data matrix, 2D array
    im_shape : tuple
        2D shape of the individual images

    Returns
    -------
    numpy.ndarray
        3D cube

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import matrix2cube
    >>> a = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14],
    ... [3, 7, 11, 15]])
    >>> matrix2cube(a, (2, 2))
    array([[[ 0,  1],
            [ 2,  3]],
    <BLANKLINE>
           [[ 4,  5],
            [ 6,  7]],
    <BLANKLINE>
           [[ 8,  9],
            [10, 11]],
    <BLANKLINE>
           [[12, 13],
            [14, 15]]])

    See Also
    --------
    cube2matrix : complementary function

    """
    return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)])
27 | extensions = [ 28 | "sphinx.ext.autodoc", 29 | "sphinx.ext.autosummary", 30 | "sphinx.ext.coverage", 31 | "sphinx.ext.doctest", 32 | "sphinx.ext.ifconfig", 33 | "sphinx.ext.intersphinx", 34 | "sphinx.ext.mathjax", 35 | "sphinx.ext.napoleon", 36 | "sphinx.ext.todo", 37 | "sphinx.ext.viewcode", 38 | "sphinxawesome_theme.highlighting", 39 | "sphinxcontrib.bibtex", 40 | "myst_parser", 41 | "nbsphinx", 42 | "nbsphinx_link", 43 | "numpydoc", 44 | "sphinx_gallery.gen_gallery", 45 | ] 46 | 47 | # Include module names for objects 48 | add_module_names = False 49 | 50 | # Set class documentation standard. 51 | autoclass_content = "class" 52 | 53 | # Audodoc options 54 | autodoc_default_options = { 55 | "member-order": "bysource", 56 | "private-members": True, 57 | "show-inheritance": True, 58 | } 59 | 60 | # Generate summaries 61 | autosummary_generate = True 62 | 63 | # Suppress class members in toctree. 64 | numpydoc_show_class_members = False 65 | 66 | # The suffix(es) of source filenames. 67 | # You can specify multiple suffix as a list of string: 68 | source_suffix = [".rst", ".md"] 69 | 70 | # The master toctree document. 71 | master_doc = "index" 72 | 73 | # If true, sectionauthor and moduleauthor directives will be shown in the 74 | # output. They are ignored by default. 75 | show_authors = True 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = "default" 79 | 80 | # If true, `todo` and `todoList` produce output, else they produce nothing. 81 | todo_include_todos = True 82 | 83 | # -- Options for HTML output ---------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | html_theme = "sphinxawesome_theme" 88 | # html_theme = 'sphinx_book_theme' 89 | 90 | # Theme options are theme-specific and customize the look and feel of a theme 91 | # further. 
For a list of options available for each theme, see the 92 | # documentation. 93 | html_theme_options = { 94 | "nav_include_hidden": True, 95 | "show_nav": True, 96 | "show_breadcrumbs": True, 97 | "breadcrumbs_separator": "/", 98 | "show_prev_next": True, 99 | "show_scrolltop": True, 100 | } 101 | html_collapsible_definitions = True 102 | html_awesome_headerlinks = True 103 | html_logo = "modopt_logo.png" 104 | html_permalinks_icon = ( 105 | '' 107 | '' 113 | ) 114 | # The name for this set of Sphinx documents. If None, it defaults to 115 | # " v documentation". 116 | html_title = f"{project} v{version}" 117 | 118 | # A shorter title for the navigation bar. Default is the same as html_title. 119 | # html_short_title = None 120 | 121 | # The name of an image file (relative to this directory) to place at the top 122 | # of the sidebar. 123 | # html_logo = None 124 | 125 | # The name of an image file (within the static path) to use as favicon of the 126 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 127 | # pixels large. 128 | # html_favicon = None 129 | 130 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 131 | # using the given strftime format. 132 | html_last_updated_fmt = "%d %b, %Y" 133 | 134 | # If true, SmartyPants will be used to convert quotes and dashes to 135 | # typographically correct entities. 136 | html_use_smartypants = True 137 | 138 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 139 | html_show_sphinx = True 140 | 141 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
def add_notebooks(nb_path="../../notebooks"):
    """Link the notebooks found in ``nb_path`` into the documentation.

    For every ``.ipynb`` file in ``nb_path`` a ``<name>.nblink`` file is
    written to the current working directory, and the notebook name is
    appended to ``notebooks.rst`` unless that file already mentions it.

    Parameters
    ----------
    nb_path : str, optional
        Directory searched for notebooks (default is ``'../../notebooks'``)

    Returns
    -------
    list
        Sorted file names of the notebooks found in ``nb_path``

    """
    print("Looking for notebooks")
    nb_ext = ".ipynb"
    nb_rst_file_name = "notebooks.rst"
    # NOTE(review): whitespace inside this literal (and the one written to
    # the rst file below) may have been collapsed in the listing this was
    # reviewed from; confirm the indentation against the repository.
    nb_link_format = '{{\n "path": "{0}/{1}"\n}}'

    nbs = sorted(nb for nb in os.listdir(nb_path) if nb.endswith(nb_ext))

    for list_pos, nb in enumerate(nbs):
        # Remove the extension as a suffix. ``str.rstrip`` strips a *set* of
        # characters and would turn e.g. "python.ipynb" into "pytho".
        nb_name = nb[: -len(nb_ext)]

        nb_link_file_name = nb_name + ".nblink"
        print(f"Writing {nb_link_file_name}")
        with open(nb_link_file_name, "w") as nb_link_file:
            nb_link_file.write(nb_link_format.format(nb_path, nb))

        print(f"Looking for {nb_name} in {nb_rst_file_name}")
        with open(nb_rst_file_name) as nb_rst_file:
            check_name = nb_name not in nb_rst_file.read()

        if check_name:
            print(f"Adding {nb_name} to {nb_rst_file_name}")
            with open(nb_rst_file_name, "a") as nb_rst_file:
                # A blank line is only inserted ahead of the first notebook
                if list_pos == 0:
                    nb_rst_file.write("\n")
                nb_rst_file.write(f" {nb_name}\n")

    return nbs
env.metadata[env.docname]['nbsphinx-link-target'] %} 203 | {% else %} 204 | {% set docpath = env.doc2path(env.docname, base='docs/source/') %} 205 | {% endif %} 206 | 207 | .. raw:: html 208 | 209 | 215 | 216 | """ 217 | nb_header_pt2 = ( 218 | r"""

""" 219 | rf"""""" 221 | + rf"""Binder badge
""" 223 | r"""

""" 228 | ) 229 | 230 | nbsphinx_prolog = nb_header_pt1 + nb_header_pt2 231 | 232 | # -- Intersphinx Mapping ---------------------------------------------- 233 | 234 | # Refer to the package libraries for type definitions 235 | intersphinx_mapping = { 236 | "python": ("http://docs.python.org/3", None), 237 | "numpy": ("https://numpy.org/doc/stable/", None), 238 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 239 | "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None), 240 | "matplotlib": ("https://matplotlib.org", None), 241 | "astropy": ("http://docs.astropy.org/en/latest/", None), 242 | "cupy": ("https://docs-cupy.chainer.org/en/stable/", None), 243 | "torch": ("https://pytorch.org/docs/stable/", None), 244 | "sklearn": ( 245 | "http://scikit-learn.org/stable", 246 | (None, "./_intersphinx/sklearn-objects.inv"), 247 | ), 248 | "tensorflow": ( 249 | "https://www.tensorflow.org/api_docs/python", 250 | ( 251 | "https://github.com/GPflow/tensorflow-intersphinx/" 252 | + "raw/master/tf2_py_objects.inv" 253 | ), 254 | ), 255 | } 256 | 257 | # -- BibTeX Setting ---------------------------------------------- 258 | 259 | bibtex_bibfiles = ["refs.bib", "my_ref.bib"] 260 | bibtex_default_style = "alpha" 261 | -------------------------------------------------------------------------------- /src/modopt/signal/svd.py: -------------------------------------------------------------------------------- 1 | """SVD ROUTINES. 2 | 3 | This module contains methods for thresholding singular values. 4 | 5 | :Author: Samuel Farrens 6 | 7 | """ 8 | 9 | import numpy as np 10 | from scipy.linalg import svd 11 | from scipy.sparse.linalg import svds 12 | 13 | from modopt.base.transform import matrix2cube 14 | from modopt.interface.errors import warn 15 | from modopt.math.convolve import convolve 16 | from modopt.signal.noise import thresh 17 | 18 | 19 | def find_n_pc(u_vec, factor=0.5): 20 | """Find number of principal components. 
def calculate_svd(input_data):
    """Calculate Singular Value Decomposition.

    This method calculates the Singular Value Decomposition (SVD) of the input
    data using SciPy.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix

    Returns
    -------
    tuple
        Left singular vector, singular values and right singular vector

    Raises
    ------
    TypeError
        For invalid data type

    """
    if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
        raise TypeError("Input data must be a 2D np.ndarray.")

    return svd(
        input_data,
        check_finite=False,
        lapack_driver="gesvd",
        full_matrices=False,
    )


def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
    """Threshold the singular values.

    This method thresholds the input data using singular value decomposition.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    threshold : float or numpy.ndarray, optional
        Threshold value(s) (default is ``None``)
    n_pc : int or str, optional
        Number of principal components, specify an integer value or ``'all'``
        (default is ``None``)
    thresh_type : {'hard', 'soft'}, optional
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        For invalid n_pc value

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.svd import svd_thresh
    >>> x = np.arange(18).reshape(9, 2).astype(float)
    >>> svd_thresh(x, n_pc=1)
    array([[ 0.49815487,  0.54291537],
           [ 2.40863386,  2.62505584],
           [ 4.31911286,  4.70719631],
           [ 6.22959185,  6.78933678],
           [ 8.14007085,  8.87147725],
           [10.05054985, 10.95361772],
           [11.96102884, 13.03575819],
           [13.87150784, 15.11789866],
           [15.78198684, 17.20003913]])

    """
    invalid_type = not isinstance(n_pc, (int, str, type(None)))
    less_than_zero = isinstance(n_pc, int) and n_pc <= 0
    str_not_all = isinstance(n_pc, str) and n_pc != "all"

    if invalid_type or less_than_zero or str_not_all:
        raise ValueError(
            'Invalid value for "n_pc", specify a positive integer value or "all"',
        )

    # Get SVD of input data.
    u_vec, s_values, v_vec = calculate_svd(input_data)

    # Find the threshold if not provided.
    if threshold is None:
        # Find the required number of principal components if not specified.
        if n_pc is None:
            n_pc = find_n_pc(u_vec, factor=0.1)

        # If the number of PCs is too large use all of the singular values.
        if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
            isinstance(n_pc, str) and n_pc == "all"
        ):
            n_pc = s_values.size
            warn("Using all singular values.")

        # The threshold is the n_pc-th singular value. If find_n_pc returned
        # 0, index -1 selects the last entry of s_values.
        threshold = s_values[n_pc - 1]

    # Threshold the singular values.
    s_new = thresh(s_values, threshold, thresh_type)

    if np.all(s_new == s_values):
        warn("No change to singular values.")

    # Diagonalize the svd
    s_new = np.diag(s_new)

    # Return the thresholded data.
    return np.dot(u_vec, np.dot(s_new, v_vec))


def svd_thresh_coef_fast(
    input_data,
    threshold,
    n_vals=-1,
    extra_vals=5,
    thresh_type="hard",
):
    """Threshold the singular values coefficients.

    This method thresholds the input data by using singular value
    decomposition, but only computing the greatest ``n_vals``
    values.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    threshold : float or numpy.ndarray
        Threshold value(s)
    n_vals : int, optional
        Number of singular values to compute.
        If ``None`` or ``-1``, compute as many singular values as
        ``scipy.sparse.linalg.svds`` allows, i.e. ``min(input_data.shape) - 1``
        (default is ``-1``)
    extra_vals : int, optional
        If the number of values computed is not enough to perform thresholding,
        recompute by using ``n_vals + extra_vals`` (default is ``5``)
    thresh_type : {'hard', 'soft'}
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    tuple
        The thresholded data (numpy.ndarray) and the estimated rank after
        thresholding (int)

    """
    # svds cannot compute more than min(shape) - 1 singular values, so this
    # is the upper bound for every request made below.
    max_vals = min(input_data.shape) - 1

    if n_vals is None or n_vals == -1:
        n_vals = max_vals
    else:
        n_vals = min(n_vals, max_vals)

    # Always make progress between attempts, even if extra_vals <= 0
    step = max(extra_vals, 1)

    ok = False
    while not ok:
        (u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
        # Enough values were computed once the first one falls below the
        # threshold, or when no more values can be requested.
        # NOTE(review): this and the slicing below rely on svds returning the
        # singular values in ascending order; confirm for the SciPy versions
        # in use.
        ok = s_values[0] <= threshold or n_vals == max_vals
        n_vals = min(n_vals + step, max_vals)

    s_values = thresh(
        s_values,
        threshold,
        threshold_type=thresh_type,
    )
    rank = np.count_nonzero(s_values)
    # For rank == 0 the slices select everything, but all remaining singular
    # values are zero so the product is a zero matrix.
    return (
        np.dot(
            u_vec[:, -rank:] * s_values[-rank:],
            v_vec[-rank:, :],
        ),
        rank,
    )


def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
    """Threshold the singular values coefficients.

    This method thresholds the input data using singular value decomposition.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    operator : class
        Operator class instance
    threshold : float or numpy.ndarray
        Threshold value(s), the input is not modified
    thresh_type : {'hard', 'soft'}
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    TypeError
        If operator not callable

    """
    if not callable(operator):
        raise TypeError("Operator must be a callable function.")

    # Get SVD of data matrix
    u_vec, s_values, v_vec = calculate_svd(input_data)

    # Diagonalise s
    s_values = np.diag(s_values)

    # Compute coefficients
    a_matrix = np.dot(s_values, v_vec)

    # Get the shape of the array (the left singular vectors are reshaped to
    # square images)
    array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)

    # Compute threshold matrix.
    ti = np.array(
        [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
    )
    # Build a new array instead of using ``*=`` so that a threshold array
    # passed in by the caller is left untouched.
    threshold = threshold * np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)

    # Threshold coefficients.
    a_new = thresh(a_matrix, threshold, thresh_type)

    # Return the thresholded image.
    return np.dot(u_vec, a_new)
import numpy as np
import numpy.testing as npt
from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
from pytest_cases import (
    fixture,
    parametrize,
    parametrize_with_cases,
)


# Flag recording whether scikit-learn can be imported.
SKLEARN_AVAILABLE = True
try:
    import sklearn
except ImportError:
    SKLEARN_AVAILABLE = False


# Unseeded generator, used below to add small noise to the dual variable.
rng = np.random.default_rng()


@fixture
def idty():
    """Identity function."""
    return lambda x: x


@fixture
def reweight_op():
    """Reweight operator."""
    data3 = np.arange(9).reshape(3, 3).astype(float) + 1
    return reweight.cwbReweight(data3)


def build_kwargs(kwargs, use_metrics):
    """Build the kwargs for each algorithm, replacing placeholders by true values.

    This function has to be called for each test, as direct parameterization
    somehow is not working with pytest-xdist and pytest-cases.
    It also adds a dummy metric measurement to validate the metric api.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments for the algorithm; values equal to one of the
        placeholder strings ``"idty"``, ``"lin_idty"`` or ``"reweight_op"``
        are replaced by the corresponding object
    use_metrics : bool
        If ``True``, add a ``"linear"`` operator and a ``"metrics"`` entry

    Returns
    -------
    dict
        New keyword argument dictionary, the input dictionary is not modified
    """
    # Placeholder name -> freshly built object
    update_value = {
        "idty": lambda x: x,
        "lin_idty": linear.Identity(),
        "reweight_op": reweight.cwbReweight(
            np.arange(9).reshape(3, 3).astype(float) + 1
        ),
    }
    new_kwargs = dict()
    print(kwargs)
    # Update the value of the dict if possible, otherwise keep it as it is.
    for key in kwargs:
        new_kwargs[key] = update_value.get(kwargs[key], kwargs[key])

    if use_metrics:
        new_kwargs["linear"] = linear.Identity()
        # Dummy metric: sum of the differences between ``x_new`` and a
        # constant reference array.
        new_kwargs["metrics"] = {
            "diff": {
                "metric": lambda test, ref: np.sum(test - ref),
                "mapping": {"x_new": "test"},
                "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))},
                "early_stopping": False,
            }
        }

    return new_kwargs


@parametrize(use_metrics=[True, False])
class AlgoCases:
    r"""Cases for algorithms.

    Most of the tests solve the trivial problem

    .. math::
        \min_x \frac{1}{2} \| y - x \|_2^2 \quad\text{s.t.} x \geq 0

    More complex and concrete use cases are shown in examples.

    Every case returns the algorithm instance together with the keyword
    arguments it was built with.
    """

    # Reference data, also used as the initial guess
    data1 = np.arange(9).reshape(3, 3).astype(float)
    # Reference data plus noise of scale 1e-6
    data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
    # Number of iterations used when a case iterates manually
    max_iter = 20

    @parametrize(
        kwargs=[
            {"beta_update": "idty", "auto_iterate": False, "cost": None},
            {"beta_update": "idty"},
            {"cost": None, "lambda_update": None},
            {"beta_update": "idty", "a_cd": 3},
            {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7},
            {"restart_strategy": "adaptive", "xi_restart": 0.9},
            {
                "restart_strategy": "greedy",
                "xi_restart": 0.9,
                "min_beta": 1.0,
                "s_greedy": 1.1,
            },
        ]
    )
    def case_forward_backward(self, kwargs, idty, use_metrics):
        """Forward Backward case."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        algo = algorithms.ForwardBackward(
            self.data1,
            grad=gradient.GradBasic(self.data1, idty, idty),
            prox=proximity.Positivity(),
            **update_kwargs,
        )
        # Iterate here when the case explicitly disables auto iteration
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        kwargs=[
            {
                "cost": None,
                "auto_iterate": False,
                "gamma_update": "idty",
                "beta_update": "idty",
            },
            {"gamma_update": "idty", "lambda_update": "idty"},
            {"cost": True},
            {"cost": True, "step_size": 2},
        ]
    )
    def case_gen_forward_backward(self, kwargs, use_metrics, idty):
        """General FB setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        prox_dual_inst = proximity.IdentityProx()
        # ``cost=True`` is a placeholder for a cost object built from the
        # operators of this case
        if update_kwargs.get("cost", None) is True:
            update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
        algo = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            **update_kwargs,
        )
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        kwargs=[
            {
                "sigma_dual": "idty",
                "tau_update": "idty",
                "rho_update": "idty",
                "auto_iterate": False,
            },
            {
                "sigma_dual": "idty",
                "tau_update": "idty",
                "rho_update": "idty",
            },
            {
                "linear": "lin_idty",
                "cost": True,
                "reweight": "reweight_op",
            },
        ]
    )
    def case_condat(self, kwargs, use_metrics, idty):
        """Condat Vu Algorithm setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        prox_dual_inst = proximity.IdentityProx()
        # ``cost=True`` is a placeholder for a cost object built from the
        # operators of this case
        if update_kwargs.get("cost", None) is True:
            update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])

        algo = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            prox_dual=prox_dual_inst,
            **update_kwargs,
        )
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}])
    def case_pogm(self, kwargs, use_metrics, idty):
        """POGM setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        algo = algorithms.POGM(
            u=self.data1,
            x=self.data1,
            y=self.data1,
            z=self.data1,
            grad=grad_inst,
            prox=prox_inst,
            **update_kwargs,
        )

        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        GradDescent=[
            algorithms.VanillaGenericGradOpt,
            algorithms.AdaGenericGradOpt,
            algorithms.ADAMGradOpt,
            algorithms.MomentumGradOpt,
            algorithms.RMSpropGradOpt,
            algorithms.SAGAOptGradOpt,
        ]
    )
    def case_grad(self, GradDescent, use_metrics, idty):
        """Gradient Descent algorithm test."""
        update_kwargs = build_kwargs({}, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        cost_inst = cost.costObj([grad_inst, prox_inst])

        algo = GradDescent(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
            **update_kwargs,
        )
        # This case always calls iterate explicitly
        algo.iterate()
        return algo, update_kwargs

    @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
    def case_admm(self, admm, use_metrics, idty):
        """ADMM setup."""

        # Both sub-problem solvers simply return the observation
        def optim1(init, obs):
            return obs

        def optim2(init, obs):
            return obs

        update_kwargs = build_kwargs({}, use_metrics)
        algo = admm(
            u=self.data1,
            v=self.data1,
            mu=np.zeros_like(self.data1),
            A=linear.Identity(),
            B=linear.Identity(),
            b=self.data1,
            optimizers=(optim1, optim2),
            **update_kwargs,
        )
        algo.iterate()
        return algo, update_kwargs


@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
def test_algo(algo, kwargs):
    """Test algorithms.

    Checks the iteration counter for cases that iterated manually and the
    final estimate otherwise; when metrics were requested, also checks that
    the last ``"diff"`` value is close to zero.
    """
    if kwargs.get("auto_iterate") is False:
        # algo already run
        npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1)
    else:
        npt.assert_almost_equal(algo.x_final, AlgoCases.data1)

    if kwargs.get("metrics"):
        print(algo.metrics)
        # Last recorded difference must be zero to 3 decimal places
        npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3)
-------------------------------------------------------------------------------- /src/modopt/opt/algorithms/primal_dual.py: -------------------------------------------------------------------------------- 1 | """Primal-Dual Algorithms.""" 2 | 3 | from modopt.opt.algorithms.base import SetUp 4 | from modopt.opt.cost import costObj 5 | from modopt.opt.linear import Identity 6 | 7 | 8 | class Condat(SetUp): 9 | r"""Condat optimisation. 10 | 11 | This class implements algorithm 3.1 from :cite:`condat2013`. 12 | 13 | Parameters 14 | ---------- 15 | x : numpy.ndarray 16 | Initial guess for the primal variable 17 | y : numpy.ndarray 18 | Initial guess for the dual variable 19 | grad 20 | Gradient operator class instance 21 | prox 22 | Proximity primal operator class instance 23 | prox_dual 24 | Proximity dual operator class instance 25 | linear : class instance, optional 26 | Linear operator class instance (default is ``None``) 27 | cost : class instance or str, optional 28 | Cost function class instance (default is ``'auto'``); Use ``'auto'`` to 29 | automatically generate a ``costObj`` instance 30 | reweight : class instance, optional 31 | Reweighting class instance 32 | rho : float, optional 33 | Relaxation parameter, :math:`\rho` (default is ``0.5``) 34 | sigma : float, optional 35 | Proximal dual parameter, :math:`\sigma` (default is ``1.0``) 36 | tau : float, optional 37 | Proximal primal paramater, :math:`\tau` (default is ``1.0``) 38 | rho_update : callable, optional 39 | Relaxation parameter update method (default is ``None``) 40 | sigma_update : callable, optional 41 | Proximal dual parameter update method (default is ``None``) 42 | tau_update : callable, optional 43 | Proximal primal parameter update method (default is ``None``) 44 | auto_iterate : bool, optional 45 | Option to automatically begin iterations upon initialisation (default 46 | is ``True``) 47 | max_iter : int, optional 48 | Maximum number of iterations (default is ``150``) 49 | n_rewightings : 
int, optional 50 | Number of reweightings to perform (default is ``1``) 51 | 52 | Notes 53 | ----- 54 | The ``tau_param`` can also be set using the keyword `step_size`, which will 55 | override the value of ``tau_param``. 56 | 57 | The following state variable are available for metrics measurememts at 58 | each iteration : 59 | 60 | * ``'x_new'`` : new estimate of :math:`x` (primal variable) 61 | * ``'y_new'`` : new estimate of :math:`y` (dual variable) 62 | * ``'idx'`` : index of the iteration. 63 | 64 | See Also 65 | -------- 66 | modopt.opt.algorithms.base.SetUp : parent class 67 | modopt.opt.cost.costObj : cost object class 68 | modopt.opt.gradient : gradient operator classes 69 | modopt.opt.proximity : proximity operator classes 70 | modopt.opt.linear : linear operator classes 71 | modopt.opt.reweight : reweighting classes 72 | 73 | """ 74 | 75 | def __init__( 76 | self, 77 | x, 78 | y, 79 | grad, 80 | prox, 81 | prox_dual, 82 | linear=None, 83 | cost="auto", 84 | reweight=None, 85 | rho=0.5, 86 | sigma=1.0, 87 | tau=1.0, 88 | rho_update=None, 89 | sigma_update=None, 90 | tau_update=None, 91 | auto_iterate=True, 92 | max_iter=150, 93 | n_rewightings=1, 94 | metric_call_period=5, 95 | metrics=None, 96 | **kwargs, 97 | ): 98 | # Set default algorithm properties 99 | super().__init__( 100 | metric_call_period=metric_call_period, 101 | metrics=metrics, 102 | **kwargs, 103 | ) 104 | 105 | # Set the initial variable values 106 | for input_data in (x, y): 107 | self._check_input_data(input_data) 108 | 109 | self._x_old = self.xp.copy(x) 110 | self._y_old = self.xp.copy(y) 111 | 112 | # Set the algorithm operators 113 | for operator in (grad, prox, prox_dual, linear, cost): 114 | self._check_operator(operator) 115 | 116 | self._grad = grad 117 | self._prox = prox 118 | self._prox_dual = prox_dual 119 | self._reweight = reweight 120 | if isinstance(linear, type(None)): 121 | self._linear = Identity() 122 | else: 123 | self._linear = linear 124 | if cost == "auto": 125 
| self._cost_func = costObj( 126 | [ 127 | self._grad, 128 | self._prox, 129 | self._prox_dual, 130 | ] 131 | ) 132 | else: 133 | self._cost_func = cost 134 | 135 | # Set the algorithm parameters 136 | for param_val in (rho, sigma, tau): 137 | self._check_param(param_val) 138 | 139 | self._rho = rho 140 | self._sigma = sigma 141 | self._tau = self.step_size or tau 142 | 143 | # Set the algorithm parameter update methods 144 | for param_update in (rho_update, sigma_update, tau_update): 145 | self._check_param_update(param_update) 146 | 147 | self._rho_update = rho_update 148 | self._sigma_update = sigma_update 149 | self._tau_update = tau_update 150 | 151 | # Automatically run the algorithm 152 | if auto_iterate: 153 | self.iterate(max_iter=max_iter, n_rewightings=n_rewightings) 154 | 155 | def _update_param(self): 156 | """Update parameters. 157 | 158 | This method updates the values of the algorthm parameters with the 159 | methods provided. 160 | 161 | """ 162 | # Update relaxation parameter. 163 | if not isinstance(self._rho_update, type(None)): 164 | self._rho = self._rho_update(self._rho) 165 | 166 | # Update proximal dual parameter. 167 | if not isinstance(self._sigma_update, type(None)): 168 | self._sigma = self._sigma_update(self._sigma) 169 | 170 | # Update proximal primal parameter. 171 | if not isinstance(self._tau_update, type(None)): 172 | self._tau = self._tau_update(self._tau) 173 | 174 | def _update(self): 175 | """Update. 176 | 177 | This method updates the current reconstruction. 178 | 179 | Notes 180 | ----- 181 | Implements equation 9 (algorithm 3.1) from :cite:`condat2013`. 182 | 183 | - Primal proximity operator set up for positivity constraint. 184 | 185 | """ 186 | # Step 1 from eq.9. 187 | self._grad.get_grad(self._x_old) 188 | 189 | x_prox = self._prox.op( 190 | self._x_old 191 | - self._tau * self._grad.grad 192 | - self._tau * self._linear.adj_op(self._y_old), 193 | ) 194 | 195 | # Step 2 from eq.9. 
196 | y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old) 197 | 198 | y_prox = y_temp - self._sigma * self._prox_dual.op( 199 | y_temp / self._sigma, 200 | extra_factor=(1.0 / self._sigma), 201 | ) 202 | 203 | # Step 3 from eq.9. 204 | self._x_new = self._rho * x_prox + (1 - self._rho) * self._x_old 205 | self._y_new = self._rho * y_prox + (1 - self._rho) * self._y_old 206 | 207 | del x_prox, y_prox, y_temp 208 | 209 | # Update old values for next iteration. 210 | self.xp.copyto(self._x_old, self._x_new) 211 | self.xp.copyto(self._y_old, self._y_new) 212 | 213 | # Update parameter values for next iteration. 214 | self._update_param() 215 | 216 | # Test cost function for convergence. 217 | if self._cost_func: 218 | self.converge = self.any_convergence_flag() or self._cost_func.get_cost( 219 | self._x_new, self._y_new 220 | ) 221 | 222 | def iterate(self, max_iter=150, n_rewightings=1, progbar=None): 223 | """Iterate. 224 | 225 | This method calls update until either convergence criteria is met or 226 | the maximum number of iterations is reached. 227 | 228 | Parameters 229 | ---------- 230 | max_iter : int, optional 231 | Maximum number of iterations (default is ``150``) 232 | n_rewightings : int, optional 233 | Number of reweightings to perform (default is ``1``) 234 | progbar: tqdm.tqdm 235 | Progress bar handle (default is ``None``) 236 | """ 237 | self._run_alg(max_iter, progbar) 238 | 239 | if not isinstance(self._reweight, type(None)): 240 | for _ in range(n_rewightings): 241 | self._reweight.reweight(self._linear.op(self._x_new)) 242 | if progbar: 243 | progbar.reset(total=max_iter) 244 | self._run_alg(max_iter, progbar) 245 | 246 | # retrieve metrics results 247 | self.retrieve_outputs() 248 | # rename outputs as attributes 249 | self.x_final = self._x_new 250 | self.y_final = self._y_new 251 | 252 | def get_notify_observers_kwargs(self): 253 | """Notify observers. 
254 | 255 | Return the mapping between the metrics call and the iterated 256 | variables. 257 | 258 | Returns 259 | ------- 260 | notify_observers_kwargs : dict, 261 | The mapping between the iterated variables 262 | 263 | """ 264 | return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx} 265 | 266 | def retrieve_outputs(self): 267 | """Retrieve outputs. 268 | 269 | Declare the outputs of the algorithms as attributes: ``x_final``, 270 | ``y_final``, ``metrics``. 271 | 272 | """ 273 | metrics = {} 274 | for obs in self._observers["cv_metrics"]: 275 | metrics[obs.name] = obs.retrieve_metrics() 276 | self.metrics = metrics 277 | --------------------------------------------------------------------------------