├── notebooks
└── .gitkeep
├── docs
├── source
│ ├── z_ref.rst
│ ├── modopt_logo.png
│ ├── output_17_0.png
│ ├── output_21_0.png
│ ├── output_5_0.png
│ ├── output_9_0.png
│ ├── cosmostat_logo.jpg
│ ├── neurospin_logo.png
│ ├── notebooks.rst
│ ├── citing.rst
│ ├── quickstart.rst
│ ├── toc.rst
│ ├── contributing.rst
│ ├── index.rst
│ ├── my_ref.bib
│ ├── installation.rst
│ ├── about.rst
│ ├── dependencies.rst
│ ├── plugin_example.rst
│ ├── refs.bib
│ └── conf.py
├── _templates
│ ├── toc.rst_t
│ ├── module.rst_t
│ └── package.rst_t
└── requirements.txt
├── examples
├── README.rst
├── __init__.py
├── conftest.py
└── example_lasso_forward_backward.py
├── .coveragerc
├── tests
├── test_helpers
│ ├── __init__.py
│ └── utils.py
├── test_base.py
└── test_algorithms.py
├── src
└── modopt
│ ├── plot
│ ├── __init__.py
│ └── cost_plot.py
│ ├── interface
│ ├── __init__.py
│ ├── log.py
│ └── errors.py
│ ├── math
│ ├── __init__.py
│ ├── convolve.py
│ ├── metrics.py
│ └── stats.py
│ ├── signal
│ ├── __init__.py
│ ├── validation.py
│ ├── positivity.py
│ ├── filter.py
│ ├── noise.py
│ └── svd.py
│ ├── opt
│ ├── __init__.py
│ ├── linear
│ │ ├── __init__.py
│ │ └── base.py
│ ├── algorithms
│ │ ├── __init__.py
│ │ └── primal_dual.py
│ ├── reweight.py
│ └── gradient.py
│ ├── base
│ ├── __init__.py
│ ├── types.py
│ ├── backend.py
│ ├── np_adjust.py
│ ├── observable.py
│ └── transform.py
│ └── __init__.py
├── .github
├── ISSUE_TEMPLATE
│ ├── help-.md
│ ├── installation-issue.md
│ ├── feature_request.md
│ └── bug_report.md
└── workflows
│ ├── style.yml
│ ├── ci-build.yml
│ └── cd-build.yml
├── LICENCE.txt
├── pyproject.toml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── README.md
└── CONTRIBUTING.md
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/z_ref.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 |
4 | .. bibliography:: refs.bib
5 |
--------------------------------------------------------------------------------
/docs/source/modopt_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/modopt_logo.png
--------------------------------------------------------------------------------
/docs/source/output_17_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_17_0.png
--------------------------------------------------------------------------------
/docs/source/output_21_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_21_0.png
--------------------------------------------------------------------------------
/docs/source/output_5_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_5_0.png
--------------------------------------------------------------------------------
/docs/source/output_9_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/output_9_0.png
--------------------------------------------------------------------------------
/docs/source/cosmostat_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/cosmostat_logo.jpg
--------------------------------------------------------------------------------
/docs/source/neurospin_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CEA-COSMIC/ModOpt/HEAD/docs/source/neurospin_logo.png
--------------------------------------------------------------------------------
/examples/README.rst:
--------------------------------------------------------------------------------
1 | ========
2 | Examples
3 | ========
4 |
5 | This is a collection of Python scripts demonstrating the use of ModOpt.
6 |
--------------------------------------------------------------------------------
/docs/source/notebooks.rst:
--------------------------------------------------------------------------------
1 | Notebooks
2 | =========
3 |
4 | List of Notebooks
5 | -----------------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit=
3 | *interface*
4 | *plot*
5 |
6 | [report]
7 | show_missing=True
8 | exclude_lines =
9 | pragma: no cover
10 |
--------------------------------------------------------------------------------
/tests/test_helpers/__init__.py:
--------------------------------------------------------------------------------
1 | """Utilities for tests."""
2 |
3 | from .utils import Dummy, failparam, skipparam
4 |
5 | __all__ = ["Dummy", "failparam", "skipparam"]
6 |
--------------------------------------------------------------------------------
/docs/_templates/toc.rst_t:
--------------------------------------------------------------------------------
1 | {{ header | heading }}
2 |
3 | .. toctree::
4 | :maxdepth: {{ maxdepth }}
5 | {% for docname in docnames %}
6 | {{ docname }}
7 | {%- endfor %}
8 |
9 |
--------------------------------------------------------------------------------
/docs/_templates/module.rst_t:
--------------------------------------------------------------------------------
1 | {%- if show_headings %}
2 | {{- basename | e | heading(2) }}
3 |
4 | {% endif -%}
5 | .. automodule:: {{ qualname }}
6 | {%- for option in automodule_options %}
7 | :{{ option }}:
8 | {%- endfor %}
9 |
--------------------------------------------------------------------------------
/docs/source/citing.rst:
--------------------------------------------------------------------------------
1 | Citing this Package
2 | ===================
3 |
4 | We kindly request that any academic work making use of this package to cite :cite:`farrens:2020`.
5 |
6 | .. bibliography:: my_ref.bib
7 | :style: alpha
8 |
--------------------------------------------------------------------------------
/src/modopt/plot/__init__.py:
--------------------------------------------------------------------------------
1 | """PLOTTING ROUTINES.
2 |
3 | This module contains submodules for plotting applications.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | __all__ = ["cost_plot"]
10 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | jupyter==1.0.0
2 | myst-parser==0.16.1
3 | nbsphinx==0.8.7
4 | nbsphinx-link==1.3.0
5 | numpydoc==1.1.0
6 | sphinx==4.3.1
7 | sphinxcontrib-bibtex==2.4.1
8 | sphinxawesome-theme==3.2.1
9 | sphinx-gallery==0.11.1
10 |
--------------------------------------------------------------------------------
/docs/source/quickstart.rst:
--------------------------------------------------------------------------------
1 | Quickstart Tutorial
2 | ===================
3 |
4 | You can import the package as follows:
5 |
6 | .. code-block:: python
7 |
8 | import modopt
9 |
10 | .. note::
11 |
12 | Examples coming soon!
13 |
--------------------------------------------------------------------------------
/src/modopt/interface/__init__.py:
--------------------------------------------------------------------------------
1 | """INTERFACE ROUTINES.
2 |
3 | This module contains submodules for error handling, logging and IO interaction.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | __all__ = ["errors", "log"]
10 |
--------------------------------------------------------------------------------
/src/modopt/math/__init__.py:
--------------------------------------------------------------------------------
1 | """MATHEMATICS ROUTINES.
2 |
3 | This module contains submodules for mathematical applications.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | __all__ = ["convolve", "matrix", "stats", "metrics"]
10 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """EXAMPLES.
2 |
3 | This module contains documented examples that demonstrate the usage of various
4 | ModOpt tools.
5 |
6 | These examples also serve as integration tests for various methods.
7 |
8 | :Author: Pierre-Antoine Comby
9 |
10 | """
11 |
--------------------------------------------------------------------------------
/src/modopt/signal/__init__.py:
--------------------------------------------------------------------------------
1 | """SIGNAL PROCESSING ROUTINES.
2 |
3 | This module contains submodules for signal processing.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | __all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
10 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/help-.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Help!
3 | about: Users are welcome to ask any question relating to ModOpt and we will endeavour
4 | to reply as soon as possible.
5 | title: "[HELP]"
6 | labels: help wanted
7 | assignees: ''
8 |
9 | ---
10 |
11 | Let us know how we can help.
12 |
--------------------------------------------------------------------------------
/src/modopt/opt/__init__.py:
--------------------------------------------------------------------------------
1 | """OPTIMISATION PROBLEM MODULES.
2 |
3 | This module contains submodules for solving optimisation problems.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | __all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
10 |
--------------------------------------------------------------------------------
/src/modopt/base/__init__.py:
--------------------------------------------------------------------------------
1 | """BASE ROUTINES.
2 |
3 | This module contains submodules for basic operations such as type
4 | transformations and adjustments to the default output of Numpy functions.
5 |
6 | :Author: Samuel Farrens
7 |
8 | """
9 |
10 | __all__ = ["np_adjust", "transform", "types", "observable"]
11 |
--------------------------------------------------------------------------------
/src/modopt/opt/linear/__init__.py:
--------------------------------------------------------------------------------
1 | """LINEAR OPERATORS.
2 |
3 | This module contains linear operator classes.
4 |
5 | :Author: Samuel Farrens
6 | :Author: Pierre-Antoine Comby
7 | """
8 |
9 | from .base import LinearParent, Identity, MatrixOperator, LinearCombo
10 |
11 | from .wavelet import WaveletConvolve, WaveletTransform
12 |
13 |
14 | __all__ = [
15 | "LinearParent",
16 | "Identity",
17 | "MatrixOperator",
18 | "LinearCombo",
19 | "WaveletConvolve",
20 | "WaveletTransform",
21 | ]
22 |
--------------------------------------------------------------------------------
/docs/source/toc.rst:
--------------------------------------------------------------------------------
1 | .. Toctrees define sidebar contents
2 |
3 | .. toctree::
4 | :hidden:
5 | :titlesonly:
6 | :caption: Getting Started
7 |
8 | about
9 | installation
10 | dependencies
11 | quickstart
12 |
13 | .. toctree::
14 | :hidden:
15 | :titlesonly:
16 | :caption: API Documentation
17 |
18 | modopt
19 | z_ref
20 |
21 | .. toctree::
22 | :hidden:
23 | :titlesonly:
24 | :caption: Examples
25 |
26 | plugin_example
27 | notebooks
28 | auto_examples/index
29 |
30 | .. toctree::
31 | :hidden:
32 | :titlesonly:
33 | :caption: Guidelines
34 |
35 | contributing
36 | citing
37 |
--------------------------------------------------------------------------------
/docs/source/contributing.rst:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 |
4 | Read our |link-to-contrib|.
5 | for details on how to contribute to the development of this package.
6 |
7 | All contributors are kindly asked to adhere to the |link-to-conduct|
8 | at all times to ensure a safe and inclusive environment for everyone.
9 |
10 | .. |link-to-contrib| raw:: html
11 |
12 | Contribution Guidelines
13 |
14 | .. |link-to-conduct| raw:: html
15 |
16 | Code of Conduct
17 |
--------------------------------------------------------------------------------
/src/modopt/__init__.py:
--------------------------------------------------------------------------------
1 | """MODOPT PACKAGE.
2 |
3 | ModOpt is a series of Modular Optimisation tools for solving inverse problems.
4 |
5 | """
6 |
7 | from warnings import warn
8 |
9 | from importlib_metadata import version
10 |
11 | from modopt.base import np_adjust, transform, types, observable
12 |
13 | __all__ = ["np_adjust", "transform", "types", "observable"]
14 |
15 | try:
16 | _version = version("modopt")
17 | except Exception: # pragma: no cover
18 | _version = "Unkown"
19 | warn(
20 | "Could not extract package metadata. Make sure the package is "
21 | + "correctly installed.",
22 | stacklevel=1,
23 | )
24 |
25 | __version__ = _version
26 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. modopt documentation master file, created by
2 | sphinx-quickstart on Mon Oct 24 16:46:22 2016.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | ModOpt Documentation
7 | ======================
8 |
9 | .. image:: modopt_logo.png
10 | :width: 100%
11 | :alt: ModOpt logo
12 |
13 | .. Include table of contents
14 | .. include:: toc.rst
15 |
16 | :Author: Samuel Farrens `(samuel.farrens@cea.fr) `_
17 | :Version: 1.6.0
18 | :Release Date: 17/12/2021
19 | :Repository: |link-to-repo|
20 |
21 | .. |link-to-repo| raw:: html
22 |
23 | https://github.com/CEA-COSMIC/ModOpt
25 |
26 | ModOpt is a series of **Modular Optimisation** tools for solving inverse problems.
27 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/installation-issue.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Installation issue
3 | about: If you encounter difficulties installing ModOpt be sure to re-read the installation
4 | instructions provided before submitting an issue.
5 | title: "[INSTALLATION ERROR]"
6 | labels: installation
7 | assignees: ''
8 |
9 | ---
10 |
11 | **System setup**
12 | OS: [e.g.] macOS v10.14.1
13 | Python version: [e.g.] v3.6.7
14 | Python environment (if any): [e.g.] conda v4.5.11
15 |
16 | **Describe the bug**
17 | A clear and concise description of what the problem is.
18 |
19 | **To Reproduce**
20 | List the exact steps you followed that led to the problem you encountered so that we can attempt to recreate the conditions.
21 |
22 | **Screenshots**
23 | If applicable, add screenshots to help explain your problem.
24 |
25 | **Are you planning to submit a Pull Request?**
26 | - [ ] Yes
27 | - [X] No
28 |
--------------------------------------------------------------------------------
/tests/test_helpers/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some helper functions for the test parametrization.
3 |
4 | They should be used inside ``@pytest.mark.parametrize`` call.
5 |
6 | :Author: Pierre-Antoine Comby
7 | """
8 |
9 | import pytest
10 |
11 |
def failparam(*args, raises=None):
    """Return a pytest parameterization that is expected to raise an error.

    Parameters
    ----------
    *args
        Positional values forwarded to ``pytest.param``.
    raises : type
        Exception class the test case is expected to raise.

    Raises
    ------
    ValueError
        If ``raises`` is not an Exception subclass (including the default
        ``None``).
    """
    # Guard None explicitly: ``issubclass(None, Exception)`` would raise
    # TypeError instead of the intended ValueError.
    if raises is None or not issubclass(raises, Exception):
        raise ValueError("raises should be an expected Exception.")
    return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)])
17 |
18 |
def skipparam(*args, cond=True, reason=""):
    """Return a pytest parameterization that is skipped if ``cond`` is True.

    Parameters
    ----------
    *args
        Positional values forwarded to ``pytest.param``.
    cond : bool
        Condition under which the case is skipped (default ``True``).
    reason : str
        Explanation reported when the case is skipped.
    """
    return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)])
22 |
23 |
class Dummy:
    """Empty placeholder class used as a stand-in object in tests."""
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: If you believe ModOpt could be improved with the addition of extra functionality
4 | or features feel free to let us know.
5 | title: "[NEW FEATURE]"
6 | labels: enhancement
7 | assignees: ''
8 |
9 | ---
10 |
11 | **Is your feature request related to a problem? Please describe.**
12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
13 |
14 | **Describe the solution you'd like**
15 | In order to increase your chances of having a feature included, be sure to be as clear and specific as possible as to the properties this feature should have.
16 |
17 | **Describe alternatives you've considered**
18 | A clear and concise description of any alternative solutions or features you've considered.
19 |
20 | **Are you planning to submit a Pull Request?**
21 | - [ ] Yes
22 | - [X] No
23 |
--------------------------------------------------------------------------------
/docs/source/my_ref.bib:
--------------------------------------------------------------------------------
1 | @ARTICLE{farrens:2020,
2 | author = {{Farrens}, S. and {Grigis}, A. and {El Gueddari}, L. and {Ramzi}, Z. and {Chaithya}, G.~R. and {Starck}, S. and {Sarthou}, B. and {Cherkaoui}, H. and {Ciuciu}, P. and {Starck}, J. -L.},
3 | title = "{PySAP: Python Sparse Data Analysis Package for multidisciplinary image processing}",
4 | journal = {Astronomy and Computing},
5 | keywords = {Image processing, Convex optimisation, Reconstruction, Open-source software, Astrophysics - Instrumentation and Methods for Astrophysics},
6 | year = 2020,
7 | month = jul,
8 | volume = {32},
9 | eid = {100402},
10 | pages = {100402},
11 | doi = {10.1016/j.ascom.2020.100402},
12 | archivePrefix = {arXiv},
13 | eprint = {1910.08465},
14 | primaryClass = {astro-ph.IM},
15 | adsurl = {https://ui.adsabs.harvard.edu/abs/2020A&C....3200402F},
16 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
17 | }
18 |
--------------------------------------------------------------------------------
/.github/workflows/style.yml:
--------------------------------------------------------------------------------
1 | name: Style checking
2 |
3 | on:
4 | push:
5 | branches: [ "master", "main", "develop" ]
6 | pull_request:
7 | branches: [ "master", "main", "develop" ]
8 |
9 | workflow_dispatch:
10 |
11 | env:
12 | PYTHON_VERSION: "3.10"
13 |
14 | jobs:
15 | linter-check:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@v4
20 | - name: Set up Python ${{ env.PYTHON_VERSION }}
21 | uses: actions/setup-python@v4
22 | with:
23 | python-version: ${{ env.PYTHON_VERSION }}
24 | cache: pip
25 |
26 | - name: Install Python deps
27 | shell: bash
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install -e .[test,dev]
31 |
32 | - name: Black Check
33 | shell: bash
34 | run: black . --diff --color --check
35 |
36 | - name: ruff Check
37 | shell: bash
38 | run: ruff check
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: If you discover a bug while using ModOpt please provide the following information
4 | to help us resolve the issue.
5 | title: "[BUG]"
6 | labels: bug
7 | assignees: ''
8 |
9 | ---
10 |
11 | **System setup**
12 | OS: [e.g.] macOS v10.14.1
13 | Python version: [e.g.] v3.6.7
14 | Python environment (if any): [e.g.] conda v4.5.11
15 |
16 | **Describe the bug**
17 | A clear and concise description of what the bug is.
18 |
19 | **To Reproduce**
20 | List the exact steps you followed that led to the bug you encountered so that we can attempt to recreate the conditions.
21 |
22 | **Expected behavior**
23 | A clear and concise description of what you expected to happen.
24 |
25 | **Screenshots**
26 | If applicable, add screenshots to help explain your problem.
27 |
28 | **Module and lines involved**
29 | If you are aware of the source of the bug we would very much appreciate if you could provide the module(s) and line number(s) affected. This will enable us to more rapidly fix the problem.
30 |
31 | **Are you planning to submit a Pull Request?**
32 | - [ ] Yes
33 | - [X] No
34 |
--------------------------------------------------------------------------------
/LICENCE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Samuel Farrens
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/_templates/package.rst_t:
--------------------------------------------------------------------------------
1 | {%- macro automodule(modname, options) -%}
2 | .. automodule:: {{ modname }}
3 | {%- for option in options %}
4 | :{{ option }}:
5 | {%- endfor %}
6 | {%- endmacro %}
7 |
8 | {%- macro toctree(docnames) -%}
9 | .. toctree::
10 | :maxdepth: {{ maxdepth }}
11 | {% for docname in docnames %}
12 | {{ docname }}
13 | {%- endfor %}
14 | {%- endmacro %}
15 |
16 | {%- if is_namespace %}
17 | {{- [pkgname, "namespace"] | join(" ") | e | heading }}
18 | {% else %}
19 | {{- pkgname | e | heading }}
20 |
21 | {% endif %}
22 |
23 | {%- if modulefirst and not is_namespace %}
24 | {{ automodule(pkgname, automodule_options) }}
25 | {% endif %}
26 |
27 | {%- if subpackages %}
28 | Subpackages
29 | -----------
30 |
31 | {{ toctree(subpackages) }}
32 | {% endif %}
33 |
34 | {%- if submodules %}
35 | Submodules
36 | ----------
37 | {% if separatemodules %}
38 | {{ toctree(submodules) }}
39 | {% else %}
40 | {%- for submodule in submodules %}
41 | {% if show_headings %}
42 | {{- submodule | e | heading(2) }}
43 | {% endif %}
44 | {{ automodule(submodule, automodule_options) }}
45 | {% endfor %}
46 | {%- endif %}
47 | {%- endif %}
48 |
49 | {%- if not modulefirst and not is_namespace %}
50 | Module contents
51 | ---------------
52 |
53 | {{ automodule(pkgname, automodule_options) }}
54 | {% endif %}
55 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | .. note::
5 |
6 | ModOpt is automatically installed by |link-to-pysap|. The following steps are
7 | intended for those that wish to use ModOpt independently of PySAP or those
8 | aiming to contribute to the development of the package.
9 |
10 | .. |link-to-pysap| raw:: html
11 |
12 | PySAP
13 |
14 | Users
15 | -----
16 |
17 | You can install the latest release of ModOpt from `PyPi `_
18 | as follows:
19 |
20 | .. code-block:: bash
21 |
22 | pip install modopt
23 |
24 |
25 | Alternatively clone the repository and build the package locally as follows:
26 |
27 | .. code-block:: bash
28 |
29 | pip install .
30 |
31 |
32 | Developers
33 | ----------
34 |
35 | Developers are recommended to clone the repository and build the package locally
36 | in development mode with testing and documentation packages as follows:
37 |
38 | .. code-block:: bash
39 |
40 | pip install -e ".[develop]"
41 |
42 | Troubleshooting
43 | ---------------
44 | If you encounter any difficulties installing ModOpt we recommend that you
45 | open a |link-to-issue| and we will do our best to help you.
46 |
47 | .. |link-to-issue| raw:: html
48 |
49 | new issue
51 |
--------------------------------------------------------------------------------
/src/modopt/plot/cost_plot.py:
--------------------------------------------------------------------------------
1 | """PLOTTING ROUTINES.
2 |
3 | This module contains methods for making plots.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
import numpy as np

# Matplotlib is an optional dependency: record whether the import succeeded
# so plotCost can raise a clear ImportError lazily at call time.
try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover
    import_fail = True
else:
    import_fail = False
17 |
18 |
def plotCost(cost_list, output=None):
    """Plot cost function.

    Plot the final cost function on a log10 scale and save it to a PNG
    file as a side effect.

    Parameters
    ----------
    cost_list : list
        List of cost function values
    output : str, optional
        Output file name prefix (default is ``None``, which saves to
        ``cost_function.png``)

    Raises
    ------
    ImportError
        If Matplotlib package not found

    """
    # Guard clause: fail fast when matplotlib is unavailable.
    if import_fail:
        raise ImportError("Matplotlib package not found")

    # ``output`` acts as a file-name prefix when provided.
    if output is None:
        file_name = "cost_function.png"
    else:
        file_name = f"{output}_cost_function.png"

    plt.figure()
    plt.plot(np.log10(cost_list), "r-")
    plt.title("Cost Function")
    plt.xlabel("Iteration")
    plt.ylabel(r"$\log_{10}$ Cost")
    plt.savefig(file_name)
    plt.close()

    print(" - Saving cost function data to:", file_name)
55 |
--------------------------------------------------------------------------------
/examples/conftest.py:
--------------------------------------------------------------------------------
1 | """TEST CONFIGURATION.
2 |
3 | This module contains methods for configuring the testing of the example
4 | scripts.
5 |
6 | :Author: Pierre-Antoine Comby
7 |
8 | Notes
9 | -----
10 | Based on:
11 | https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
12 |
13 | """
14 |
15 | from pathlib import Path
16 | import runpy
17 | import pytest
18 |
19 |
def pytest_collect_file(path, parent):
    """Pytest hook.

    Create a collector for the given path, or None if not relevant.
    The new node needs to have the specified parent as parent.

    Parameters
    ----------
    path : py.path.local
        Candidate file considered for collection.  NOTE(review): this is the
        legacy hook signature; newer pytest versions prefer the ``file_path``
        (pathlib.Path) argument — confirm the supported pytest range.
    parent : pytest.Collector
        Parent node for any collector returned.
    """
    # Only collect Python files whose name marks them as example scripts.
    p = Path(path)
    if p.suffix == ".py" and "example" in p.name:
        return Script.from_parent(parent, path=p, name=p.name)
29 |
30 |
class Script(pytest.File):
    """Script files collected by pytest."""

    def collect(self):
        """Collect the script as its own item.

        Yields
        ------
        ScriptItem
            A single test item representing the whole script.
        """
        yield ScriptItem.from_parent(self, name=self.name)
37 |
38 |
class ScriptItem(pytest.Item):
    """Item script collected by pytest."""

    def runtest(self):
        """Run the script as a test.

        The script passes if it executes top-to-bottom without raising.
        """
        runpy.run_path(str(self.path))

    def repr_failure(self, excinfo):
        """Return only the error traceback of the script.

        Trim the traceback to frames inside the script itself so the
        failure report omits pytest/runpy internals.
        """
        excinfo.traceback = excinfo.traceback.cut(path=self.path)
        return super().repr_failure(excinfo)
50 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name="modopt"
3 | description = 'Modular Optimisation tools for solving inverse problems.'
4 | version = "1.7.2"
5 | requires-python= ">=3.8"
6 |
7 | authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
8 | {name="Chaithya GR", email="chaithyagr@gmail.com"},
9 | {name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"},
10 | {name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"}
11 | ]
12 | readme="README.md"
13 | license={file="LICENCE.txt"}
14 |
15 | dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"]
16 |
17 | [project.optional-dependencies]
18 | gpu=["torch", "ptwt"]
19 | doc=["myst-parser",
20 | "nbsphinx",
21 | "nbsphinx-link",
22 | "sphinx-gallery",
23 | "numpydoc",
24 | "sphinxawesome-theme",
25 | "sphinxcontrib-bibtex"]
26 | dev=["black", "ruff"]
27 | test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"]
28 |
29 | [build-system]
30 | requires=["setuptools", "setuptools-scm[toml]", "wheel"]
31 |
32 | [tool.coverage.run]
33 | omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
34 |
35 | [tool.coverage.report]
36 | precision = 2
37 | exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
38 |
39 | [tool.black]
40 |
41 |
42 | [tool.ruff]
43 | exclude = ["examples", "docs"]
44 | [tool.ruff.lint]
45 | select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
46 |
47 | ignore = ["F401"] # we like the try: import ... expect: ...
48 |
49 | [tool.ruff.lint.pydocstyle]
50 | convention="numpy"
51 |
52 | [tool.isort]
53 | profile="black"
54 |
55 | [tool.pytest.ini_options]
56 | minversion = "6.0"
57 | norecursedirs = ["tests/test_helpers"]
58 | addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"]
59 |
--------------------------------------------------------------------------------
/docs/source/about.rst:
--------------------------------------------------------------------------------
1 | About
2 | =====
3 |
4 | ModOpt was developed as part of |link-to-cosmic|, a multi-disciplinary collaboration
5 | between |link-to-neurospin|, experts in biomedical imaging, and |link-to-cosmostat|,
6 | experts in astrophysical image processing. The package was
7 | designed to provide the backend optimisation algorithms for
8 | |link-to-pysap| :cite:`farrens:2020`, but also serves as a stand-alone library
9 | of inverse problem solving tools. While PySAP aims to provide
10 | application-specific tools for solving complex imaging problems, ModOpt can in
11 | principle be applied to any linear inverse problem.
12 |
13 | Contributors
14 | ------------
15 |
16 | You can find a |link-to-contributors|.
17 |
18 | |CS_LOGO| |NS_LOGO|
19 |
20 | .. |link-to-cosmic| raw:: html
21 |
22 | COSMIC
23 |
24 | .. |link-to-neurospin| raw:: html
25 |
26 | NeuroSpin
28 |
29 | .. |link-to-cosmostat| raw:: html
30 |
31 | CosmoStat
33 |
34 | .. |link-to-pysap| raw:: html
35 |
36 | PySAP
37 |
38 | .. |link-to-contributors| raw:: html
39 |
40 | list of ModOpt contributors here
42 |
43 | .. |CS_LOGO| image:: cosmostat_logo.jpg
44 | :width: 45%
45 | :alt: CosmoStat Logo
46 | :target: http://www.cosmostat.org/
47 |
48 | .. |NS_LOGO| image:: neurospin_logo.png
49 | :width: 45%
50 | :alt: NeuroSpin Logo
51 | :target: https://joliot.cea.fr/drf/joliot/en/Pages/research_entities/NeuroSpin.aspx
52 |
--------------------------------------------------------------------------------
/.github/workflows/ci-build.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - develop
7 | - master
8 | - main
9 |
10 | jobs:
11 | test-full:
12 | name: Full Test Suite
13 | runs-on: ${{ matrix.os }}
14 |
15 | strategy:
16 | fail-fast: false
17 | matrix:
18 | os: [ubuntu-latest, macos-latest]
19 | python-version: ["3.8", "3.9", "3.10"]
20 |
21 | steps:
22 | - uses: actions/checkout@v4
23 | - uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | cache: pip
27 |
28 | - name: Install Dependencies
29 | shell: bash -l {0}
30 | run: |
31 | python --version
32 | python -m pip install --upgrade pip
33 | python -m pip install .[test]
34 | python -m pip install astropy scikit-image scikit-learn matplotlib
35 | python -m pip install tensorflow>=2.4.1 torch
36 |
37 | - name: Run Tests
38 | shell: bash -l {0}
39 | run: |
40 | pytest -n 2
41 |
42 | - name: Save Test Results
43 | if: always()
44 | uses: actions/upload-artifact@v4
45 | with:
46 | name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
47 | path: coverage.xml
48 |
49 | - name: Check API Documentation build
50 | shell: bash -l {0}
51 | run: |
52 | apt install pandoc
53 | pip install .[doc] ipykernel
54 | sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
55 | sphinx-build -b doctest -E docs/source docs/_build
56 |
57 | - name: Upload Coverage to Codecov
58 | uses: codecov/codecov-action@v1
59 | with:
60 | token: ${{ secrets.CODECOV_TOKEN }}
61 | file: coverage.xml
62 | flags: unittests
63 |
64 |
--------------------------------------------------------------------------------
/src/modopt/opt/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | r"""OPTIMISATION ALGORITHMS.
2 |
3 | This module contains class implementations of various optimisation algoritms.
4 |
5 | :Authors:
6 |
7 | * Samuel Farrens ,
8 | * Zaccharie Ramzi ,
9 | * Pierre-Antoine Comby
10 |
11 | :Notes:
12 |
13 | Input classes must have the following properties:
14 |
15 | * **Gradient Operators**
16 |
17 | Must have the following methods:
18 |
19 | * ``get_grad()`` - calculate the gradient
20 |
21 | Must have the following variables:
22 |
23 | * ``grad`` - the gradient
24 |
25 | * **Linear Operators**
26 |
27 | Must have the following methods:
28 |
29 | * ``op()`` - operator
30 | * ``adj_op()`` - adjoint operator
31 |
32 | * **Proximity Operators**
33 |
34 | Must have the following methods:
35 |
36 | * ``op()`` - operator
37 |
38 | The following notation is used to implement the algorithms:
39 |
40 | * ``x_old`` is used in place of :math:`x_{n}`.
41 | * ``x_new`` is used in place of :math:`x_{n+1}`.
42 | * ``x_prox`` is used in place of :math:`\tilde{x}_{n+1}`.
43 | * ``x_temp`` is used for intermediate operations.
44 |
45 | """
46 |
47 | from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM
48 | from .primal_dual import Condat
49 | from .gradient_descent import (
50 | ADAMGradOpt,
51 | AdaGenericGradOpt,
52 | GenericGradOpt,
53 | MomentumGradOpt,
54 | RMSpropGradOpt,
55 | SAGAOptGradOpt,
56 | VanillaGenericGradOpt,
57 | )
58 | from .admm import ADMM, FastADMM
59 |
60 | __all__ = [
61 | "FISTA",
62 | "ForwardBackward",
63 | "GenForwardBackward",
64 | "POGM",
65 | "Condat",
66 | "ADAMGradOpt",
67 | "AdaGenericGradOpt",
68 | "GenericGradOpt",
69 | "MomentumGradOpt",
70 | "RMSpropGradOpt",
71 | "SAGAOptGradOpt",
72 | "VanillaGenericGradOpt",
73 | "ADMM",
74 | "FastADMM",
75 | ]
76 |
--------------------------------------------------------------------------------
/.github/workflows/cd-build.yml:
--------------------------------------------------------------------------------
1 | name: CD
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | - main
8 |
9 | jobs:
10 |
11 | coverage:
12 | name: Deploy Coverage Results
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v4
18 |
19 | - uses: actions/setup-python@v4
20 | with:
21 | python-version: "3.10"
22 | cache: pip
23 |
24 | - name: Install dependencies
25 | shell: bash -l {0}
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install twine
29 | python -m pip install .[doc,test]
30 |
31 | - name: Run Tests
32 | shell: bash -l {0}
33 | run: |
34 | pytest
35 |
36 | - name: Check distribution
37 | shell: bash -l {0}
38 | run: |
39 | twine check dist/*
40 |
41 | - name: Upload coverage to Codecov
42 | uses: codecov/codecov-action@v1
43 | with:
44 | token: ${{ secrets.CODECOV_TOKEN }}
45 | file: coverage.xml
46 | flags: unittests
47 |
48 | api:
49 | name: Deploy API Documentation
50 | needs: coverage
51 | runs-on: ubuntu-latest
52 | if: success()
53 |
54 | steps:
55 | - name: Checkout
56 | uses: actions/checkout@v4
57 |
58 |
59 | - name: Install dependencies
60 | shell: bash -l {0}
61 | run: |
62 | conda install -c conda-forge pandoc
63 | python -m pip install --upgrade pip
64 | python -m pip install .[doc]
65 |
66 | - name: Build API documentation
67 | shell: bash -l {0}
68 | run: |
69 | sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
70 | sphinx-build -E docs/source docs/_build
71 |
72 | - name: Deploy API documentation
73 | uses: peaceiris/actions-gh-pages@v3.5.9
74 | with:
75 | github_token: ${{ secrets.GITHUB_TOKEN }}
76 | publish_dir: docs/_build
77 |
--------------------------------------------------------------------------------
/src/modopt/interface/log.py:
--------------------------------------------------------------------------------
1 | """LOGGING ROUTINES.
2 |
This module contains methods for handling logging.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import logging
10 |
11 |
def set_up_log(filename, verbose=True):
    """Set up log.

    This method sets up a basic log.

    Parameters
    ----------
    filename : str
        Log file name (the ``.log`` extension is appended automatically)
    verbose : bool
        Option for verbose output (default is ``True``)

    Returns
    -------
    logging.Logger
        Logging instance

    """
    # Add file extension to the caller's base name.
    # (Bug fix: the input ``filename`` was previously discarded and replaced
    # by a hard-coded placeholder string.)
    filename = f"{filename}.log"

    if verbose:
        print("Preparing log file:", filename)

    # Capture warnings so they are recorded in the log file too.
    logging.captureWarnings(True)

    # Set output format.
    formatter = logging.Formatter(
        fmt="%(asctime)s %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S",
    )

    # Create file handler.
    fh = logging.FileHandler(filename=filename, mode="w")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # Create log, keyed by the log file name so repeat calls with the same
    # name reuse the same logger.
    log = logging.getLogger(filename)
    log.setLevel(logging.DEBUG)
    log.addHandler(fh)

    # Send opening message.
    log.info("The log file has been set-up.")

    return log
59 |
60 |
def close_log(log, verbose=True):
    """Close log.

    This method closes an active logging.Logger instance.

    Parameters
    ----------
    log : logging.Logger
        Logging instance
    verbose : bool
        Option for verbose output (default is ``True``)

    """
    if verbose:
        print("Closing log file:", log.name)

    # Send closing message.
    log.info("The log file has been closed.")

    # Remove all handlers from log.
    # (Bug fix: iterate over a *copy* of the handler list — removing
    # entries from the list while iterating it skips every other handler.)
    for log_handler in list(log.handlers):
        log.removeHandler(log_handler)
83 |
--------------------------------------------------------------------------------
/src/modopt/signal/validation.py:
--------------------------------------------------------------------------------
1 | """VALIDATION ROUTINES.
2 |
3 | This module contains methods for testing signal and operator properties.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 |
def transpose_test(
    operator,
    operator_t,
    x_shape,
    x_args=None,
    y_shape=None,
    y_args=None,
    rng=None,
):
    """Transpose test.

    This method tests two operators to see if they are the transpose of each
    other.

    Parameters
    ----------
    operator : callable
        Operator function
    operator_t : callable
        Transpose operator function
    x_shape : tuple
        Shape of operator input data
    x_args : tuple
        Arguments to be passed to operator (default is ``None``)
    y_shape : tuple, optional
        Shape of transpose operator input data (default is ``None``)
    y_args : tuple, optional
        Arguments to be passed to transpose operator (default is ``None``)
    rng: numpy.random.Generator or int or None (default is ``None``)
        Initialized random number generator or seed.

    Raises
    ------
    TypeError
        If input operators not callable

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.validation import transpose_test
    >>> np.random.seed(1)
    >>> a = np.random.ranf((3, 3))
    >>> transpose_test(lambda x, y: x.dot(y), lambda x, y: x.dot(y.T),
    ... a.shape, x_args=a)
     - | - | = 0.0

    """
    if not callable(operator) or not callable(operator_t):
        raise TypeError("The input operators must be callable functions.")

    # Default the transpose-side shape and arguments to the forward side.
    # (Idiom fix: plain ``is None`` checks instead of isinstance(type(None)).)
    if y_shape is None:
        y_shape = x_shape

    if y_args is None:
        y_args = x_args

    # Accept either a ready-made generator or a seed (including None).
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(rng)

    # Generate random arrays.
    x_val = rng.random(x_shape)
    y_val = rng.random(y_shape)

    # Calculate <M x, y>.
    mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val))

    # Calculate <x, M^T y>.
    x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))

    # Test the difference between the two; zero (to float precision) means
    # the operators are adjoint.
    print(" - | - | =", np.abs(mx_y - x_mty))
82 |
--------------------------------------------------------------------------------
/src/modopt/opt/reweight.py:
--------------------------------------------------------------------------------
1 | """REWEIGHTING CLASSES.
2 |
3 | This module contains classes for reweighting optimisation implementations.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 | from modopt.base.types import check_float
12 |
13 |
class cwbReweight:
    """Candes, Wakin and Boyd reweighting class.

    This class implements the reweighting scheme described in
    :cite:`candes2007`.

    Parameters
    ----------
    weights : numpy.ndarray
        Array of weights
    thresh_factor : float
        Threshold factor (default is ``1.0``)
    verbose : bool
        Option for verbose output (default is ``False``)

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.opt.reweight import cwbReweight
    >>> a = np.arange(9).reshape(3, 3).astype(float) + 1
    >>> rw = cwbReweight(a)
    >>> rw.weights
    array([[1., 2., 3.],
           [4., 5., 6.],
           [7., 8., 9.]])
    >>> rw.reweight(a)
    >>> rw.weights
    array([[0.5, 1. , 1.5],
           [2. , 2.5, 3. ],
           [3.5, 4. , 4.5]])

    """

    def __init__(self, weights, thresh_factor=1.0, verbose=False):
        self.weights = check_float(weights)
        # Keep a pristine copy of the starting weights; the threshold in
        # ``reweight`` is always based on these, not the current weights.
        self.original_weights = np.copy(self.weights)
        self.thresh_factor = check_float(thresh_factor)
        self.verbose = verbose
        # Iteration counter reported in verbose mode.
        self._rw_num = 1

    def reweight(self, input_data):
        r"""Reweight.

        This method implements the reweighting from section 4 in
        :cite:`candes2007`.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data

        Raises
        ------
        ValueError
            For invalid input shape

        Notes
        -----
        Reweighting implemented as:

        .. math::

            w = w \left( \frac{1}{1 + \frac{|x^w|}{n \sigma}} \right)

        where :math:`w` are the weights, :math:`x` is the ``input_data`` and
        :math:`n` is the ``thresh_factor``.

        """
        if self.verbose:
            print(f" - Reweighting: {self._rw_num}")

        self._rw_num += 1

        input_data = check_float(input_data)

        if input_data.shape != self.weights.shape:
            raise ValueError(
                "Input data must have the same shape as the initial weights.",
            )

        # Threshold scaled by the *original* weights.
        scaled_thresh = self.thresh_factor * self.original_weights

        # Shrink each weight according to the data magnitude.
        damping = 1.0 / (1.0 + np.abs(input_data) / scaled_thresh)
        self.weights *= np.array(damping)
97 |
--------------------------------------------------------------------------------
/docs/source/dependencies.rst:
--------------------------------------------------------------------------------
1 | Dependencies
2 | ============
3 |
4 | .. note::
5 |
6 | All packages required by ModOpt should be installed automatically. Optional
7 | packages, however, will need to be installed manually.
8 |
9 | Required Packages
10 | -----------------
11 |
12 | In order to use ModOpt the following packages must be installed:
13 |
14 | * |link-to-python| ``[>= 3.6]``
15 | * |link-to-metadata| ``[>=3.7.0]``
16 | * |link-to-numpy| ``[>=1.19.5]``
17 | * |link-to-scipy| ``[>=1.5.4]``
18 | * |link-to-progressbar| ``[>=3.53.1]``
19 |
20 | .. |link-to-python| raw:: html
21 |
22 | Python
24 |
25 | .. |link-to-metadata| raw:: html
26 |
27 | importlib_metadata
29 |
30 | .. |link-to-numpy| raw:: html
31 |
32 | Numpy
34 |
35 | .. |link-to-scipy| raw:: html
36 |
37 | Scipy
39 |
40 | .. |link-to-progressbar| raw:: html
41 |
42 | Progressbar 2
44 |
45 | Optional Packages
46 | -----------------
47 |
48 | The following packages can optionally be installed to add extra functionality:
49 |
50 | * |link-to-astropy|
51 | * |link-to-matplotlib|
52 | * |link-to-skimage|
53 | * |link-to-sklearn|
54 | * |link-to-termcolor|
55 |
56 | .. |link-to-astropy| raw:: html
57 |
58 | Astropy
60 |
61 | .. |link-to-matplotlib| raw:: html
62 |
63 | Matplotlib
65 |
66 | .. |link-to-skimage| raw:: html
67 |
68 | Scikit-Image
70 |
71 | .. |link-to-sklearn| raw:: html
72 |
73 | Scikit-Learn
75 |
76 | .. |link-to-termcolor| raw:: html
77 |
78 | Termcolor
80 |
81 | For GPU compliance the following packages can also be installed:
82 |
83 | * |link-to-cupy|
84 | * |link-to-torch|
85 | * |link-to-tf|
86 |
87 | .. |link-to-cupy| raw:: html
88 |
89 | CuPy
91 |
92 | .. |link-to-torch| raw:: html
93 |
94 | Torch
96 |
97 | .. |link-to-tf| raw:: html
98 |
99 | TensorFlow
101 |
102 | .. note::
103 |
104 | Note that none of these are required for running on a CPU.
105 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | pytest.xml
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 | docs/source/fortuna.*
75 | docs/source/scripts.*
76 | docs/source/auto_examples/
77 | docs/source/*.nblink
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # Cython debug symbols
143 | cython_debug/
144 |
--------------------------------------------------------------------------------
/src/modopt/signal/positivity.py:
--------------------------------------------------------------------------------
1 | """POSITIVITY.
2 |
3 | This module contains a function that retains only positive coefficients in
4 | an array.
5 |
6 | :Author: Samuel Farrens
7 |
8 | """
9 |
10 | import numpy as np
11 |
12 |
def pos_thresh(input_data):
    """Positive Threshold.

    Keep only positive coefficients from input data.

    Parameters
    ----------
    input_data : int, float, list, tuple or numpy.ndarray
        Input data

    Returns
    -------
    int, float, or numpy.ndarray
        Positive coefficients

    """
    # Multiplying by the boolean mask zeroes out non-positive entries while
    # preserving the input's numeric type.
    positive_mask = input_data > 0
    return input_data * positive_mask
30 |
31 |
def pos_recursive(input_data):
    """Positive Recursive.

    Run pos_thresh on input array or recursively for ragged nested arrays.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data

    Returns
    -------
    numpy.ndarray
        Positive coefficients

    """
    # Object-dtype arrays are ragged: descend into each element.
    if input_data.dtype == "O":
        return np.array(
            [pos_recursive(sub_array) for sub_array in input_data],
            dtype="object",
        )

    # Regular numeric array: zero out non-positive entries directly.
    return input_data * (input_data > 0)
55 |
56 |
def positive(input_data, ragged=False):
    """Positivity operator.

    This method preserves only the positive coefficients of the input data,
    all negative coefficients are set to zero.

    Parameters
    ----------
    input_data : int, float, or numpy.ndarray
        Input data
    ragged : bool, optional
        Specify if the input_data is a ragged nested array
        (default is ``False``)

    Returns
    -------
    int or float, or numpy.ndarray
        Positive coefficients

    Raises
    ------
    TypeError
        For invalid input type.

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.positivity import positive
    >>> a = np.arange(9).reshape(3, 3) - 5
    >>> a
    array([[-5, -4, -3],
           [-2, -1, 0],
           [ 1, 2, 3]])
    >>> positive(a)
    array([[0, 0, 0],
           [0, 0, 0],
           [1, 2, 3]])

    """
    if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
        raise TypeError(
            "Invalid data type, input must be `int`, `float`, `list`, "
            "`tuple` or `np.ndarray`.",
        )

    # Scalars can be thresholded directly.
    if isinstance(input_data, (int, float)):
        return pos_thresh(input_data)

    # Cast to an array, using object dtype for ragged nested input.
    array_dtype = "object" if ragged else None
    array_data = np.array(input_data, dtype=array_dtype)

    return pos_recursive(array_data)
112 |
--------------------------------------------------------------------------------
/src/modopt/signal/filter.py:
--------------------------------------------------------------------------------
1 | """FILTER ROUTINES.
2 |
This module contains methods for filtering signals.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 | from modopt.base.types import check_float
12 |
13 |
def gaussian_filter(data_point, sigma, norm=True):
    """Gaussian filter.

    This method implements a Gaussian filter.

    Parameters
    ----------
    data_point : float
        Input data point
    sigma : float
        Standard deviation (filter scale)
    norm : bool
        Option to return normalised data (default is ``True``)

    Returns
    -------
    float
        Gaussian filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import gaussian_filter
    >>> gaussian_filter(1, 1)
    0.24197072451914337

    >>> gaussian_filter(1, 1, False)
    0.6065306597126334

    """
    data_point = check_float(data_point)
    sigma = check_float(sigma)

    # Unnormalised Gaussian kernel value.
    kernel_val = np.exp(-0.5 * (data_point / sigma) ** 2)

    if not norm:
        return kernel_val

    # Normalise so the kernel integrates to one.
    return kernel_val / (np.sqrt(2 * np.pi) * sigma)
52 |
53 |
def mex_hat(data_point, sigma):
    """Mexican hat.

    This method implements a Mexican hat (or Ricker) wavelet.

    Parameters
    ----------
    data_point : float
        Input data point
    sigma : float
        Standard deviation (filter scale)

    Returns
    -------
    float
        Mexican hat filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import mex_hat
    >>> round(mex_hat(2, 1), 15)
    -0.352139052257134

    """
    data_point = check_float(data_point)
    sigma = check_float(sigma)

    # Squared, scale-normalised coordinate.
    ratio_sq = (data_point / sigma) ** 2

    # Standard Ricker wavelet normalisation constant.
    amplitude = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25

    return amplitude * (1 - ratio_sq) * np.exp(-0.5 * ratio_sq)
85 |
86 |
def mex_hat_dir(data_gauss, data_mex, sigma):
    """Directional Mexican hat.

    This method implements a directional Mexican hat (or Ricker) wavelet.

    Parameters
    ----------
    data_gauss : float
        Input data point for Gaussian
    data_mex : float
        Input data point for Mexican hat
    sigma : float
        Standard deviation (filter scale)

    Returns
    -------
    float
        Directional Mexican hat filtered data point

    Examples
    --------
    >>> from modopt.signal.filter import mex_hat_dir
    >>> round(mex_hat_dir(1, 2, 1), 16)
    0.1760695261285668

    """
    data_gauss = check_float(data_gauss)
    sigma = check_float(sigma)

    # Directional weighting applied on top of the standard Ricker wavelet.
    direction_term = (data_gauss / sigma) ** 2

    return -0.5 * direction_term * mex_hat(data_mex, sigma)
117 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of experience,
9 | education, socio-economic status, nationality, personal appearance, race,
10 | religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at {{ email }}. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
--------------------------------------------------------------------------------
/src/modopt/interface/errors.py:
--------------------------------------------------------------------------------
1 | """ERROR HANDLING ROUTINES.
2 |
This module contains methods for handling warnings and errors.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import os
10 | import sys
11 | import warnings
12 |
13 | try:
14 | from termcolor import colored
15 | except ImportError:
16 | import_fail = True
17 | else:
18 | import_fail = False
19 |
20 |
def warn(warn_string, log=None):
    """Warning.

    This method creates custom warning messages.

    Parameters
    ----------
    warn_string : str
        Warning message string
    log : logging.Logger, optional
        Logging structure instance (default is ``None``)

    """
    # Use a coloured tag only when Termcolor imported successfully.
    if import_fail:
        warn_txt = "WARNING"
    else:
        warn_txt = colored("WARNING", "yellow")

    # Print warning to stderr.
    # (Comment fix: the previous comment said "stdout" but the write has
    # always gone to stderr.)
    sys.stderr.write(f"{warn_txt}: {warn_string}\n")

    # If a logging structure is provided, also emit a Python warning.
    # NOTE(review): ``log`` itself is never written to here — presumably the
    # warning reaches it via ``logging.captureWarnings``; confirm.
    if log is not None:
        warnings.warn(warn_string, stacklevel=2)
45 |
46 |
def catch_error(exception, log=None):
    """Catch error.

    This method catches errors and prints them to the terminal. It also saves
    the errors to a log if provided.

    Parameters
    ----------
    exception : str
        Exception message string
    log : logging.Logger, optional
        Logging structure instance (default is ``None``)

    """
    # Use a coloured tag only when Termcolor imported successfully.
    if import_fail:
        err_txt = "ERROR"
    else:
        err_txt = colored("ERROR", "red")

    # Print exception to stderr.
    # (Comment fix: the previous comment said "stdout" but the write has
    # always gone to stderr.)
    stream_txt = f"{err_txt}: {exception}\n"
    sys.stderr.write(stream_txt)

    # Record the exception (with traceback) in the log if provided.
    if log is not None:
        log_txt = f"ERROR: {exception}\n"
        log.exception(log_txt)
74 |
75 |
def file_name_error(file_name):
    """File name error.

    This method checks if the input file name is valid.

    Parameters
    ----------
    file_name : str
        File name string

    Raises
    ------
    OSError
        If file name not specified or file not found

    """
    # An empty name or an option-like string (starting with "-") is not a
    # valid file name.  (Fix: ``file_name[0][0]`` double-indexed the first
    # character redundantly; ``startswith`` states the intent directly.)
    if file_name == "" or file_name.startswith("-"):
        raise OSError("Input file name not specified.")

    if not os.path.isfile(file_name):
        raise OSError(f"Input file name {file_name} not found!")
97 |
98 |
def is_exe(fpath):
    """Is Executable.

    Check if the input file path corresponds to an executable on the system.

    Parameters
    ----------
    fpath : str
        File path

    Returns
    -------
    bool
        True if file path exists

    """
    # Must be a regular file AND carry the execute permission bit.
    file_exists = os.path.isfile(fpath)
    return file_exists and os.access(fpath, os.X_OK)
116 |
117 |
def is_executable(exe_name):
    """Check if Input is Executable.

    This method checks if the input executable exists.

    Parameters
    ----------
    exe_name : str
        Executable name

    Raises
    ------
    TypeError
        For invalid input type
    IOError
        For invalid system executable

    """
    if not isinstance(exe_name, str):
        raise TypeError("Executable name must be a string.")

    directory, _ = os.path.split(exe_name)

    if directory:
        # Explicit path: test it directly.
        found = is_exe(exe_name)
    else:
        # Bare name: search every directory on the PATH.
        search_dirs = os.environ["PATH"].split(os.pathsep)
        found = any(
            is_exe(os.path.join(candidate, exe_name))
            for candidate in search_dirs
        )

    if not found:
        message = "{0} does not appear to be a valid executable on this system."
        raise OSError(message.format(exe_name))
153 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ModOpt
2 |
3 |
4 |
5 | | Usage | Development | Release |
6 | | ----- | ----------- | ------- |
7 | | [](https://cea-cosmic.github.io/ModOpt/) | [](https://github.com/CEA-COSMIC/modopt/actions?query=workflow%3ACI) | [](https://github.com/CEA-COSMIC/modopt/releases/latest) |
8 | | [](https://github.com/CEA-COSMIC/modopt/blob/master/LICENCE.txt) | [](https://github.com/CEA-COSMIC/modopt/actions?query=workflow%3ACD) | [](https://pypi.org/project/modopt/) |
9 | | [](https://github.com/wemake-services/wemake-python-styleguide) | [](https://codecov.io/gh/CEA-COSMIC/modopt) | [](https://www.python.org/downloads/source/) |
10 | | [](https://github.com/CEA-COSMIC/modopt/blob/master/CONTRIBUTING.md) | [](https://www.codefactor.io/repository/github/CEA-COSMIC/modopt) | |
11 | | [](https://github.com/CEA-COSMIC/modopt/blob/master/CODE_OF_CONDUCT.md) | [](https://pyup.io/repos/github/CEA-COSMIC/ModOpt/) | |
12 |
13 | ModOpt is a series of **Modular Optimisation** tools for solving inverse problems.
14 |
15 | See [documentation](https://CEA-COSMIC.github.io/ModOpt/) for more details.
16 |
17 | ## Installation
18 |
19 | To install using `pip` run the following command:
20 |
21 | ```bash
22 | $ pip install modopt
23 | ```
24 |
25 | To clone the ModOpt repository from GitHub run the following command:
26 |
27 | ```bash
28 | $ git clone https://github.com/CEA-COSMIC/ModOpt.git
29 | ```
30 |
31 | ## Dependencies
32 |
33 | All packages required by ModOpt should be installed automatically. Optional packages, however, will need to be installed manually.
34 |
35 | ### Required Packages
36 |
37 | In order to run the code in this repository the following packages must be
38 | installed:
39 |
40 | * [Python](https://www.python.org/) [> 3.7]
41 | * [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0]
42 | * [Numpy](http://www.numpy.org/) [==1.19.5]
43 | * [Scipy](http://www.scipy.org/) [==1.5.4]
44 | * [tqdm](https://tqdm.github.io/) [>=4.64.0]
45 |
46 | ### Optional Packages
47 |
48 | The following packages can optionally be installed to add extra functionality:
49 |
50 | * [Astropy](http://www.astropy.org/)
51 | * [Matplotlib](http://matplotlib.org/)
52 | * [Scikit-Image](https://scikit-image.org/)
53 | * [Scikit-Learn](https://scikit-learn.org/)
54 | * [Termcolor](https://pypi.python.org/pypi/termcolor)
55 |
56 | For (partial) GPU compliance the following packages can also be installed.
57 | Note that none of these are required for running on a CPU.
58 |
59 | * [CuPy](https://cupy.dev/)
60 | * [Torch](https://pytorch.org/)
61 | * [TensorFlow](https://www.tensorflow.org/)
62 |
63 | ## Citation
64 |
65 | If you use ModOpt in a scientific publication, we would appreciate citations to the following paper:
66 |
67 | [PySAP: Python Sparse Data Analysis Package for multidisciplinary image processing](https://www.sciencedirect.com/science/article/pii/S2213133720300561), S. Farrens et al., Astronomy and Computing 32, 2020
68 |
69 | The BibTeX citation is the following:
70 | ```
71 | @Article{farrens2020pysap,
72 | title={{PySAP: Python Sparse Data Analysis Package for multidisciplinary image processing}},
73 | author={Farrens, S and Grigis, A and El Gueddari, L and Ramzi, Z and Chaithya, GR and Starck, S and Sarthou, B and Cherkaoui, H and Ciuciu, P and Starck, J-L},
74 | journal={Astronomy and Computing},
75 | volume={32},
76 | pages={100402},
77 | year={2020},
78 | publisher={Elsevier}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/src/modopt/base/types.py:
--------------------------------------------------------------------------------
1 | """TYPE HANDLING ROUTINES.
2 |
3 | This module contains methods for handing object types.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 | from modopt.interface.errors import warn
11 |
12 |
def check_callable(input_obj):
    """Check that the input object is callable.

    This method validates that the input operator is a callable function
    and returns it unchanged.

    Parameters
    ----------
    input_obj : callable
        Callable function

    Returns
    -------
    callable
        The validated input object, unchanged

    Raises
    ------
    TypeError
        For invalid input type

    """
    if not callable(input_obj):
        raise TypeError("The input object must be a callable function.")
    return input_obj
33 |
34 |
def check_float(input_obj):
    """Check Float.

    Validate that the input is a float or a ``numpy.ndarray`` of floats,
    converting it when it is not.

    Parameters
    ----------
    input_obj : any
        Input value

    Returns
    -------
    float or numpy.ndarray
        Input value as a float

    Raises
    ------
    TypeError
        For invalid input type

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.types import check_float
    >>> a = np.arange(5)
    >>> a
    array([0, 1, 2, 3, 4])
    >>> check_float(a)
    array([0., 1., 2., 3., 4.])

    See Also
    --------
    check_int : related function

    """
    if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
        raise TypeError("Invalid input type.")

    # Scalars: promote plain ints (and bools) to float.
    if isinstance(input_obj, int):
        return float(input_obj)

    # Sequences: build a float array.
    if isinstance(input_obj, (list, tuple)):
        return np.array(input_obj, dtype=float)

    # Arrays: cast only when the dtype is not already floating point.
    needs_cast = isinstance(input_obj, np.ndarray) and not np.issubdtype(
        input_obj.dtype,
        np.floating,
    )
    return input_obj.astype(float) if needs_cast else input_obj
83 |
84 |
def check_int(input_obj):
    """Check Integer.

    Validate that the input is an int or a ``numpy.ndarray`` of ints,
    converting it when it is not.

    Parameters
    ----------
    input_obj : any
        Input value

    Returns
    -------
    int or numpy.ndarray
        Input value as an integer

    Raises
    ------
    TypeError
        For invalid input type

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.types import check_int
    >>> a = np.arange(5).astype(float)
    >>> a
    array([0., 1., 2., 3., 4.])
    >>> check_int(a)
    array([0, 1, 2, 3, 4])

    See Also
    --------
    check_float : related function

    """
    if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
        raise TypeError("Invalid input type.")

    # Scalars: truncate floats to int.
    if isinstance(input_obj, float):
        return int(input_obj)

    # Sequences: build an integer array.
    if isinstance(input_obj, (list, tuple)):
        return np.array(input_obj, dtype=int)

    # Arrays: cast only when the dtype is not already integral.
    needs_cast = isinstance(input_obj, np.ndarray) and not np.issubdtype(
        input_obj.dtype,
        np.integer,
    )
    return input_obj.astype(int) if needs_cast else input_obj
132 |
133 |
def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
    """Check Numpy ND-Array.

    Check if the input object is a numpy array of the expected dtype and
    set its writeable flag.

    Parameters
    ----------
    input_obj : numpy.ndarray
        Input object
    dtype : type
        Numpy ndarray data type
    writeable : bool
        Option to make array immutable
    verbose : bool
        Verbosity option

    Raises
    ------
    TypeError
        For invalid input type
    TypeError
        For invalid numpy.ndarray dtype

    """
    if not isinstance(input_obj, np.ndarray):
        raise TypeError("Input is not a numpy array.")

    # BUGFIX: the previous code did ``raise (TypeError(...),)`` which raises
    # a tuple, producing "exceptions must derive from BaseException" instead
    # of the intended message. Raise the exception instance directly.
    if dtype is not None and not np.issubdtype(input_obj.dtype, dtype):
        raise TypeError(
            f"The numpy array elements are not of type: {dtype}",
        )

    # Warn before silently making the caller's array read-only.
    if not writeable and verbose and input_obj.flags.writeable:
        warn("Making input data immutable.")

    input_obj.flags.writeable = writeable
174 |
--------------------------------------------------------------------------------
/examples/example_lasso_forward_backward.py:
--------------------------------------------------------------------------------
1 | """
2 | Solving the LASSO Problem with the Forward Backward Algorithm.
3 | ==============================================================
4 |
5 | This an example to show how to solve an example LASSO Problem
6 | using the Forward-Backward Algorithm.
7 |
8 | In this example we are going to use:
9 | - Modopt Operators (Linear, Gradient, Proximal)
10 | - Modopt implementation of solvers
11 | - Modopt Metric API.
12 | TODO: add reference to LASSO paper.
13 | """
14 |
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 |
18 | from modopt.opt.algorithms import ForwardBackward, POGM
19 | from modopt.opt.cost import costObj
20 | from modopt.opt.linear import LinearParent, Identity
21 | from modopt.opt.gradient import GradBasic
22 | from modopt.opt.proximity import SparseThreshold
23 | from modopt.math.matrix import PowerMethod
24 | from modopt.math.stats import mse
25 |
26 | # %%
27 | # Here we create a instance of the LASSO Problem
28 |
29 | BETA_TRUE = np.array(
30 | [3.0, 1.5, 0, 0, 2, 0, 0, 0]
31 | ) # 8 original values from lLASSO Paper
32 | DIM = len(BETA_TRUE)
33 |
34 |
35 | rng = np.random.default_rng()
36 | sigma_noise = 1
37 | obs = 20
38 | # create a measurement matrix with decaying covariance matrix.
39 | cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM))
40 | x = rng.multivariate_normal(np.zeros(DIM), cov, obs)
41 |
42 | y = x @ BETA_TRUE
43 | y_noise = y + (sigma_noise * np.random.standard_normal(obs))
44 |
45 |
46 | # %%
47 | # Next we create Operators for solving the problem.
48 |
49 | # MatrixOperator could also work here.
50 | lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb)
51 | grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op)
52 |
53 | prox_op = SparseThreshold(Identity(), 1, thresh_type="soft")
54 |
55 | # %%
56 | # In order to get the best convergence rate, we first determine the Lipschitz constant of the gradient Operator
57 | #
58 |
59 | calc_lips = PowerMethod(grad_op.trans_op_op, 8, data_type="float32", auto_run=True)
60 | lip = calc_lips.spec_rad
61 | print("lipschitz constant:", lip)
62 |
63 | # %%
64 | # Solving using FISTA algorithm
65 | # -----------------------------
66 | #
67 | # TODO: Add description/Reference of FISTA.
68 |
69 | cost_op_fista = costObj([grad_op, prox_op], verbose=False)
70 |
71 | fb_fista = ForwardBackward(
72 | np.zeros(8),
73 | beta_param=1 / lip,
74 | grad=grad_op,
75 | prox=prox_op,
76 | cost=cost_op_fista,
77 | metric_call_period=1,
78 | auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
79 | )
80 |
81 | fb_fista.iterate()
82 |
83 | # %%
84 | # After the run we can have a look at the results
85 |
86 | print(fb_fista.x_final)
87 | mse_fista = mse(fb_fista.x_final, BETA_TRUE)
88 | plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-")
89 | plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
90 | plt.legend()
91 | plt.title(f"FISTA Estimation MSE={mse_fista:.4f}")
92 |
93 | # sphinx_gallery_start_ignore
94 | assert mse(fb_fista.x_final, BETA_TRUE) < 1
95 | # sphinx_gallery_end_ignore
96 |
97 |
98 | # %%
99 | # Solving Using the POGM Algorithm
100 | # --------------------------------
101 | #
102 | # TODO: Add description/Reference to POGM.
103 |
104 |
105 | cost_op_pogm = costObj([grad_op, prox_op], verbose=False)
106 |
107 | fb_pogm = POGM(
108 | np.zeros(8),
109 | np.zeros(8),
110 | np.zeros(8),
111 | np.zeros(8),
112 | beta_param=1 / lip,
113 | grad=grad_op,
114 | prox=prox_op,
115 | cost=cost_op_pogm,
116 | metric_call_period=1,
117 | auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
118 | )
119 |
120 | fb_pogm.iterate()
121 |
122 | # %%
123 | # After the run we can have a look at the results
124 |
125 | print(fb_pogm.x_final)
126 | mse_pogm = mse(fb_pogm.x_final, BETA_TRUE)
127 |
128 | plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-")
129 | plt.stem(BETA_TRUE, label="reference", linefmt="C1-")
130 | plt.legend()
131 | plt.title(f"FISTA Estimation MSE={mse_pogm:.4f}")
132 | #
133 | # sphinx_gallery_start_ignore
134 | assert mse(fb_pogm.x_final, BETA_TRUE) < 1
135 | # sphinx_gallery_end_ignore
136 |
137 | # %%
138 | # Comparing the Two algorithms
139 | # ----------------------------
140 |
141 | plt.figure()
142 | plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence")
143 | plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence")
144 | plt.xlabel("iterations")
145 | plt.ylabel("Cost Function")
146 | plt.legend()
147 | plt.show()
148 |
149 |
150 | # %%
151 | # We can see that the two algorithm converges quickly, and POGM requires less iterations.
152 | # However the POGM iterations are more costly, so a proper benchmark with time measurement is needed.
153 | # Check the benchopt benchmark for more details.
154 |
--------------------------------------------------------------------------------
/src/modopt/math/convolve.py:
--------------------------------------------------------------------------------
1 | """CONVOLUTION ROUTINES.
2 |
3 | This module contains methods for convolution.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 | import scipy.signal
11 |
12 | from modopt.base.np_adjust import rotate_stack
13 | from modopt.interface.errors import warn
14 |
15 | try:
16 | from astropy.convolution import convolve_fft
17 | except ImportError: # pragma: no cover
18 | import_astropy = False
19 | warn("astropy not found, will default to scipy for convolution")
20 | else:
21 | import_astropy = True
22 | try:
23 | import pyfftw
24 | except ImportError: # pragma: no cover
25 | pass
26 | else: # pragma: no cover
27 | scipy.fftpack = pyfftw.interfaces.scipy_fftpack
28 | warn('Using pyFFTW "monkey patch" for scipy.fftpack')
29 |
30 |
def convolve(input_data, kernel, method="scipy"):
    """Convolve data with kernel.

    Convolve the input data with a given kernel using FFT; this is the
    default convolution used for all routines.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, normally a 2D image
    kernel : numpy.ndarray
        Input kernel array, normally a 2D kernel
    method : {'scipy', 'astropy'}, optional
        Convolution method (default is ``'scipy'``)

    Returns
    -------
    numpy.ndarray
        Convolved data

    Raises
    ------
    ValueError
        If `data` and `kernel` do not have the same number of dimensions
    ValueError
        If `method` is not 'astropy' or 'scipy'

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.convolve import convolve
    >>> a = np.arange(9).reshape(3, 3)
    >>> b = a + 10
    >>> convolve(a, b)
    array([[ 86., 170., 146.],
           [246., 444., 354.],
           [290., 494., 374.]])

    >>> convolve(a, b, method='astropy')
    array([[534., 525., 534.],
           [453., 444., 453.],
           [534., 525., 534.]])

    See Also
    --------
    scipy.signal.fftconvolve : scipy FFT convolution
    astropy.convolution.convolve_fft : astropy FFT convolution

    """
    # Validate inputs before dispatching to either backend.
    if input_data.ndim != kernel.ndim:
        raise ValueError("Data and kernel must have the same dimensions.")

    if method not in {"astropy", "scipy"}:
        raise ValueError('Invalid method. Options are "astropy" or "scipy".')

    # Fall back to scipy when astropy is not installed.
    if not import_astropy:  # pragma: no cover
        method = "scipy"

    if method == "scipy":
        return scipy.signal.fftconvolve(input_data, kernel, mode="same")

    return convolve_fft(
        input_data,
        kernel,
        boundary="wrap",
        crop=False,
        nan_treatment="fill",
        normalize_kernel=False,
    )
101 |
102 |
def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
    """Convolve stack of data with stack of kernels.

    Apply :func:`convolve` pairwise to each (data, kernel) slice of the
    input stacks.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, normally a 2D image
    kernel : numpy.ndarray
        Input kernel array, normally a 2D kernel
    rot_kernel : bool
        Option to rotate kernels by 180 degrees (default is ``False``)
    method : {'astropy', 'scipy'}, optional
        Convolution method (default is ``'scipy'``)

    Returns
    -------
    numpy.ndarray
        Convolved data

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.convolve import convolve_stack
    >>> a = np.arange(18).reshape(2, 3, 3)
    >>> b = a + 10
    >>> convolve_stack(a, b, method='astropy')
    array([[[ 534.,  525.,  534.],
            [ 453.,  444.,  453.],
            [ 534.,  525.,  534.]],
    <BLANKLINE>
           [[2721., 2712., 2721.],
            [2640., 2631., 2640.],
            [2721., 2712., 2721.]]])

    >>> convolve_stack(a, b, method='astropy', rot_kernel=True)
    array([[[ 474.,  483.,  474.],
            [ 555.,  564.,  555.],
            [ 474.,  483.,  474.]],
    <BLANKLINE>
           [[2661., 2670., 2661.],
            [2742., 2751., 2742.],
            [2661., 2670., 2661.]]])

    See Also
    --------
    convolve : The convolution function called by convolve_stack

    """
    kernels = rotate_stack(kernel) if rot_kernel else kernel

    convolved = [
        convolve(data_slice, kernel_slice, method=method)
        for data_slice, kernel_slice in zip(input_data, kernels)
    ]
    return np.array(convolved)
163 |
--------------------------------------------------------------------------------
/src/modopt/signal/noise.py:
--------------------------------------------------------------------------------
1 | """NOISE ROUTINES.
2 |
3 | This module contains methods for adding and removing noise from data.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 | from modopt.base.backend import get_array_module
12 |
13 |
def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
    """Add noise to data.

    This method adds Gaussian or Poisson noise to the input data.

    Parameters
    ----------
    input_data : numpy.ndarray, list or tuple
        Input data array
    sigma : float or list, optional
        Standard deviation of the noise to be added (``'gauss'`` only,
        default is ``1.0``)
    noise_type : {'gauss', 'poisson'}
        Type of noise to be added (default is ``'gauss'``)
    rng: np.random.Generator or int
        A Random number generator or a seed to initialize one.

    Returns
    -------
    numpy.ndarray
        Input data with added noise

    Raises
    ------
    ValueError
        If ``noise_type`` is not ``'gauss'`` or ``'poisson'``
    ValueError
        If number of ``sigma`` values does not match the first dimension of the
        input data

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.noise import add_noise
    >>> x = np.arange(9).reshape(3, 3).astype(float)
    >>> noisy = add_noise(x, noise_type='poisson', rng=1)
    >>> noisy.shape
    (3, 3)

    >>> noisy = add_noise(np.zeros(5), sigma=2.0, rng=np.random.default_rng(1))
    >>> noisy.shape
    (5,)

    """
    # BUGFIX(docs): the previous doctests seeded the legacy global state via
    # ``np.random.seed`` which has no effect on the ``rng`` Generator used
    # here; seed through the ``rng`` argument instead.
    if not isinstance(rng, np.random.Generator):
        rng = np.random.default_rng(rng)

    input_data = np.array(input_data)

    if noise_type not in {"gauss", "poisson"}:
        raise ValueError(
            'Invalid noise type. Options are "gauss" or "poisson"',
        )

    # Per-slice sigma values must align with the leading axis.
    if isinstance(sigma, (list, tuple, np.ndarray)) and (
        len(sigma) != input_data.shape[0]
    ):
        raise ValueError(
            "Number of sigma values must match first dimension of input data",
        )

    if noise_type == "gauss":
        random = rng.standard_normal(input_data.shape)
    else:
        # Poisson noise is drawn with rate |data|; sigma scaling (below)
        # still applies when sigma is a sequence, matching prior behavior.
        random = rng.poisson(np.abs(input_data))

    if isinstance(sigma, (int, float)):
        return input_data + sigma * random

    noise = np.array([sig * rand for sig, rand in zip(sigma, random)])

    return input_data + noise
98 |
99 |
def thresh(input_data, threshold, threshold_type="hard"):
    r"""Threshold data.

    Apply hard or soft thresholding to the input data.

    Parameters
    ----------
    input_data : numpy.ndarray, list or tuple
        Input data array
    threshold : float or numpy.ndarray
        Threshold level(s)
    threshold_type : {'hard', 'soft'}
        Type of thresholding applied (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        If ``threshold_type`` is not ``'hard'`` or ``'soft'``

    Notes
    -----
    Implements one of the following two equations:

    * Hard Threshold
        .. math::
            \mathrm{HT}_\lambda(x) =
            \begin{cases}
            x & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    * Soft Threshold
        .. math::
            \mathrm{ST}_\lambda(x) =
            \begin{cases}
            x-\lambda\text{sign}(x) & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.randint(-9, 9, 10)
    >>> x
    array([-4,  2,  3, -1,  0,  2, -4,  6, -9,  7])
    >>> thresh(x, 4)
    array([-4,  0,  0,  0,  0,  0, -4,  6, -9,  7])

    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.ranf((3, 3))
    >>> x
    array([[4.17022005e-01, 7.20324493e-01, 1.14374817e-04],
           [3.02332573e-01, 1.46755891e-01, 9.23385948e-02],
           [1.86260211e-01, 3.45560727e-01, 3.96767474e-01]])
    >>> thresh(x, 0.2, threshold_type='soft')
    array([[0.217022  , 0.52032449, 0.        ],
           [0.10233257, 0.        , 0.        ],
           [0.        , 0.14556073, 0.19676747]])

    """
    # Resolve numpy/cupy/tensorflow depending on where the data lives.
    xp = get_array_module(input_data)
    input_data = xp.array(input_data)

    if threshold_type not in {"hard", "soft"}:
        raise ValueError(
            'Invalid threshold type. Options are "hard" or "soft"',
        )

    if threshold_type == "hard":
        # Zero every element strictly below the threshold magnitude.
        return input_data * (xp.abs(input_data) >= threshold)

    # Soft threshold: shrink magnitudes toward zero by ``threshold``.
    # The eps floor avoids division by zero for zero-valued entries.
    denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
    shrink_factor = xp.maximum((1.0 - threshold / denominator), 0)

    return xp.around(shrink_factor * input_data, decimals=15)
185 |
--------------------------------------------------------------------------------
/src/modopt/base/backend.py:
--------------------------------------------------------------------------------
1 | """BACKEND MODULE.
2 |
3 | This module contains methods for GPU Compatiblity.
4 |
5 | :Author: Chaithya G R
6 |
7 | """
8 |
9 | from importlib import util
10 |
11 | import numpy as np
12 |
13 | from modopt.interface.errors import warn
14 |
15 | try:
16 | import torch
17 | from torch.utils.dlpack import from_dlpack as torch_from_dlpack
18 | from torch.utils.dlpack import to_dlpack as torch_to_dlpack
19 |
20 | except ImportError: # pragma: no cover
21 | import_torch = False
22 | else:
23 | import_torch = True
24 |
25 | # Handle the compatibility with variable
26 | LIBRARIES = {
27 | "cupy": None,
28 | "tensorflow": None,
29 | "numpy": np,
30 | }
31 |
32 | if util.find_spec("cupy") is not None:
33 | try:
34 | import cupy as cp
35 |
36 | LIBRARIES["cupy"] = cp
37 | except ImportError:
38 | pass
39 |
40 | if util.find_spec("tensorflow") is not None:
41 | try:
42 | from tensorflow.experimental import numpy as tnp
43 |
44 | LIBRARIES["tensorflow"] = tnp
45 | except ImportError:
46 | pass
47 |
48 |
def get_backend(backend):
    """Get backend.

    Returns the backend module for input specified by string.

    Parameters
    ----------
    backend: str
        String holding the backend name. One of ``'tensorflow'``,
        ``'numpy'`` or ``'cupy'``.

    Returns
    -------
    tuple
        Returns the module for carrying out calculations and the actual backend
        that was reverted towards. If the right libraries are not installed,
        the function warns and reverts to the ``'numpy'`` backend.
    """
    # Dicts support membership directly; no need for ``.keys()``.
    if backend not in LIBRARIES or LIBRARIES[backend] is None:
        warn(
            f"{backend} backend not possible, please ensure that "
            "the optional libraries are installed.\n"
            "Reverting to numpy."
        )
        backend = "numpy"
    return LIBRARIES[backend], backend
76 |
77 |
def get_array_module(input_data):
    """Get Array Module.

    Return the array module corresponding to where the input data resides
    (GPU or CPU).

    Parameters
    ----------
    input_data : numpy.ndarray, cupy.ndarray or tf.experimental.numpy.ndarray
        Input data array

    Returns
    -------
    module
        The numpy or cupy module

    """
    # Check the optional backends first; fall through to numpy.
    tnp_lib = LIBRARIES["tensorflow"]
    if tnp_lib is not None and isinstance(input_data, tnp_lib.ndarray):
        return tnp_lib

    cupy_lib = LIBRARIES["cupy"]
    if cupy_lib is not None and isinstance(input_data, cupy_lib.ndarray):
        return cupy_lib

    return np
102 |
103 |
def change_backend(input_data, backend="cupy"):
    """Move data to device.

    Change the backend of an array, copying data to GPU or to CPU as
    required.

    Parameters
    ----------
    input_data : numpy.ndarray, cupy.ndarray or tf.experimental.numpy.ndarray
        Input data array to be moved
    backend: str, optional
        The backend to use, one among ``'tensorflow'``, ``'cupy'`` and
        ``'numpy'``. Default is ``'cupy'``.

    Returns
    -------
    backend.ndarray
        An ndarray of specified backend

    """
    source_module = get_array_module(input_data)
    target_module, _ = get_backend(backend)

    # No copy needed when the data is already on the target backend.
    if source_module == target_module:
        return input_data

    return target_module.array(input_data)
129 |
130 |
def move_to_cpu(input_data):
    """Move data to CPU.

    Move data from GPU to CPU, returning the input unchanged when it
    already resides on CPU.

    Parameters
    ----------
    input_data : cupy.ndarray or tf.experimental.numpy.ndarray
        Input data array to be moved

    Returns
    -------
    numpy.ndarray
        The NumPy array residing on CPU

    Raises
    ------
    ValueError
        if the input does not correspond to any array
    """
    array_module = get_array_module(input_data)

    if array_module == LIBRARIES["numpy"]:
        return input_data
    if array_module == LIBRARIES["cupy"]:
        return input_data.get()
    if array_module == LIBRARIES["tensorflow"]:
        return input_data.data.numpy()

    raise ValueError("Cannot identify the array type.")
161 |
162 |
def convert_to_tensor(input_data):
    """Convert data to a tensor.

    This method converts input data to a torch tensor. Particularly, this
    method is helpful to convert CuPy array to Tensor.

    Parameters
    ----------
    input_data : cupy.ndarray
        Input data array to be converted

    Returns
    -------
    torch.Tensor
        The tensor data

    Raises
    ------
    ImportError
        If Torch package not found

    """
    if not import_torch:
        # BUGFIX: the previous concatenation produced "not foundsee
        # documentation" with no separator.
        raise ImportError(
            "Required version of Torch package not found; see documentation "
            "for details: https://cea-cosmic.github.io/ModOpt/#optional-packages",
        )

    xp = get_array_module(input_data)

    if xp == np:
        return torch.Tensor(input_data)

    # NOTE(review): ``ndarray.toDlpack()`` is the legacy CuPy DLPack API --
    # confirm the supported CuPy versions still provide it.
    return torch_from_dlpack(input_data.toDlpack()).float()
198 |
199 |
def convert_to_cupy_array(input_data):
    """Convert Tensor data to a CuPy Array.

    This method converts input tensor data to a cupy array.

    Parameters
    ----------
    input_data : torch.Tensor
        Input Tensor to be converted

    Returns
    -------
    cupy.ndarray
        The tensor data as a CuPy array

    Raises
    ------
    ImportError
        If Torch package not found

    """
    if not import_torch:
        # BUGFIX: the previous concatenation produced "not foundsee
        # documentation" with no separator.
        raise ImportError(
            "Required version of Torch package not found; see documentation "
            "for details: https://cea-cosmic.github.io/ModOpt/#optional-packages",
        )

    if input_data.is_cuda:
        # NOTE(review): ``cp.fromDlpack`` is the legacy CuPy DLPack API --
        # confirm the supported CuPy versions still provide it.
        return cp.fromDlpack(torch_to_dlpack(input_data))

    return input_data.detach().numpy()
232 |
--------------------------------------------------------------------------------
/src/modopt/math/metrics.py:
--------------------------------------------------------------------------------
1 | """METRICS.
2 |
3 | This module contains classes of different metric functions for optimization.
4 |
5 | :Author: Benoir Sarthou
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 | from modopt.base.backend import move_to_cpu
12 |
13 | try:
14 | from skimage.metrics import structural_similarity as compare_ssim
15 | except ImportError: # pragma: no cover
16 | import_skimage = False
17 | else:
18 | import_skimage = True
19 |
20 |
def min_max_normalize(img):
    """Min-Max Normalize.

    Linearly rescale the input array into the ``[0, 1]`` range.

    Parameters
    ----------
    img : numpy.ndarray
        Input image

    Returns
    -------
    numpy.ndarray
        normalized array

    """
    # NOTE(review): a constant-valued input gives a zero denominator (NaN
    # output) -- confirm callers never pass flat images.
    lower = img.min()
    upper = img.max()

    return (img - lower) / (upper - lower)
41 |
42 |
def _preprocess_input(test, ref, mask=None):
    """Preprocess Input.

    Normalize the magnitudes of the test and reference images and validate
    the optional mask before a metric is computed.

    Parameters
    ----------
    ref : numpy.ndarray
        The reference image
    test : numpy.ndarray
        The tested image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Raises
    ------
    ValueError
        For invalid mask value

    Notes
    -----
    Compute the metric only on magnitude.

    Returns
    -------
    tuple
        The normalized test image, the normalized reference image and the
        mask (``None`` when no mask was given)

    """
    # Work on magnitudes only, normalized to [0, 1].
    test = min_max_normalize(np.abs(np.copy(test)).astype("float64"))
    ref = min_max_normalize(np.abs(np.copy(ref)).astype("float64"))

    if mask is not None and not isinstance(mask, np.ndarray):
        message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
        raise ValueError(message.format(mask))

    return test, ref, mask
85 |
86 |
def ssim(test, ref, mask=None):
    """Structural Similarity (SSIM).

    Calculate the SSIM between a test image and a reference image.

    Parameters
    ----------
    ref : numpy.ndarray
        The reference image
    test : numpy.ndarray
        The tested image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Raises
    ------
    ImportError
        If Scikit-Image package not found

    Notes
    -----
    Compute the metric only on magnitude.

    Returns
    -------
    float
        The mean SSIM when ``mask`` is ``None``, otherwise the mask-weighted
        mean of the local SSIM map

    """
    if not import_skimage:  # pragma: no cover
        # BUGFIX: the previous concatenation produced "not foundsee
        # documentation" with no separator.
        raise ImportError(
            "Required version of Scikit-Image package not found; see "
            "documentation for details: "
            "https://cea-cosmic.github.io/ModOpt/#optional-packages",
        )

    test, ref, mask = _preprocess_input(test, ref, mask)
    test = move_to_cpu(test)
    assim, ssim_value = compare_ssim(test, ref, full=True, data_range=1.0)

    if mask is None:
        return assim

    # Weight the local SSIM map by the ROI mask.
    return (mask * ssim_value).sum() / mask.sum()
131 |
132 |
def snr(test, ref, mask=None):
    """Signal-to-Noise Ratio (SNR).

    Calculate the SNR between a test image and a reference image.

    Parameters
    ----------
    ref: numpy.ndarray
        The reference image
    test: numpy.ndarray
        The tested image
    mask: numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Notes
    -----
    Compute the metric only on magnitude.

    Returns
    -------
    float
        The SNR

    """
    test, ref, mask = _preprocess_input(test, ref, mask)

    if mask is not None:
        test = mask * test

    signal_power = np.mean(np.square(test))
    noise_power = mse(test, ref)

    return 10.0 * np.log10(signal_power / noise_power)
166 |
167 |
def psnr(test, ref, mask=None):
    """Peak Signal-to-Noise Ratio (PSNR).

    Calculate the PSNR between a test image and a reference image.

    Parameters
    ----------
    ref : numpy.ndarray
        The reference image
    test : numpy.ndarray
        The tested image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Notes
    -----
    Compute the metric only on magnitude.

    Returns
    -------
    float
        The PSNR

    """
    # Normalize both images to [0, 1] magnitudes; mask may be None.
    test, ref, mask = _preprocess_input(test, ref, mask)

    if mask is not None:
        test = mask * test
        ref = mask * ref

    # NOTE(review): this computes 10*log10(peak / MSE); the conventional PSNR
    # definition uses the squared peak, 10*log10(peak**2 / MSE). Confirm
    # whether this deviation is intentional before changing it.
    num = np.max(np.abs(test))
    deno = mse(test, ref)

    return 10.0 * np.log10(num / deno)
202 |
203 |
def mse(test, ref, mask=None):
    r"""Mean Squared Error (MSE).

    Calculate the MSE between a test image and a reference image.

    Parameters
    ----------
    test : numpy.ndarray
        The tested image
    ref : numpy.ndarray
        The reference image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Returns
    -------
    float
        The MSE

    Notes
    -----
    Compute the metric only on magnitude.

    .. math::
        1/N * \|ref - test\|_2

    """
    test, ref, mask = _preprocess_input(test, ref, mask)

    # Restrict both images to the region of interest when a mask is given.
    if mask is not None:
        test = mask * test
        ref = mask * ref

    residual = test - ref

    return np.mean(np.square(residual))
238 |
239 |
def nrmse(test, ref, mask=None):
    """Return NRMSE.

    Calculate the normalised root-mean-square error between a test image
    and a reference image.

    Parameters
    ----------
    test : numpy.ndarray
        The tested image
    ref : numpy.ndarray
        The reference image
    mask : numpy.ndarray, optional
        The mask for the ROI (default is ``None``)

    Returns
    -------
    float
        The NRMSE

    Notes
    -----
    Compute the metric only on magnitude.

    """
    test, ref, mask = _preprocess_input(test, ref, mask)

    # Restrict both images to the region of interest when a mask is given.
    if mask is not None:
        test = mask * test
        ref = mask * ref

    rmse_value = np.sqrt(mse(test, ref))
    norm_factor = np.sqrt(np.mean(np.square(test)))

    return rmse_value / norm_factor
272 |
--------------------------------------------------------------------------------
/src/modopt/base/np_adjust.py:
--------------------------------------------------------------------------------
1 | """NUMPY ADJUSTMENT ROUTINES.
2 |
3 | This module contains methods for adjusting the default output for certain
4 | Numpy functions.
5 |
6 | :Author: Samuel Farrens
7 |
8 | """
9 |
10 | import numpy as np
11 |
12 |
def rotate(input_data):
    """Rotate.

    This method rotates an input numpy array by 180 degrees.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array (at least 2D)

    Returns
    -------
    numpy.ndarray
        Rotated data

    Notes
    -----
    Adjustment to numpy.rot90

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.np_adjust import rotate
    >>> x = np.arange(9).reshape((3, 3))
    >>> rotate(x)
    array([[8, 7, 6],
           [5, 4, 3],
           [2, 1, 0]])

    See Also
    --------
    numpy.rot90 : base function

    """
    # Two successive 90-degree rotations give the 180-degree rotation.
    return np.rot90(input_data, k=2)
53 |
54 |
def rotate_stack(input_data):
    """Rotate stack.

    This method rotates each array in a stack of arrays by 180 degrees.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array (at least 3D)

    Returns
    -------
    numpy.ndarray
        Rotated data

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.np_adjust import rotate_stack
    >>> x = np.arange(18).reshape((2, 3, 3))
    >>> rotate_stack(x)
    array([[[ 8,  7,  6],
            [ 5,  4,  3],
            [ 2,  1,  0]],
    <BLANKLINE>
           [[17, 16, 15],
            [14, 13, 12],
            [11, 10,  9]]])

    See Also
    --------
    rotate : looped function

    """
    # Apply the 2D rotation slice by slice along the leading axis.
    rotated_slices = [rotate(image) for image in input_data]
    return np.array(rotated_slices)
98 |
99 |
def pad2d(input_data, padding):
    """Pad array.

    This method pads an input numpy array with zeros in all directions.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array (at least 2D)
    padding : int or tuple
        Amount of padding in x and y directions, respectively

    Returns
    -------
    numpy.ndarray
        Padded data

    Raises
    ------
    ValueError
        For an invalid padding type

    Notes
    -----
    Adjustment to numpy.pad()

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.np_adjust import pad2d
    >>> x = np.arange(9).reshape((3, 3))
    >>> pad2d(x, (1, 1))
    array([[0, 0, 0, 0, 0],
           [0, 0, 1, 2, 0],
           [0, 3, 4, 5, 0],
           [0, 6, 7, 8, 0],
           [0, 0, 0, 0, 0]])

    See Also
    --------
    numpy.pad : base function

    """
    input_data = np.array(input_data)

    # Normalise every accepted padding type to a numpy array.
    if isinstance(padding, int):
        pad_vals = np.array([padding])
    elif isinstance(padding, (tuple, list)):
        pad_vals = np.array(padding)
    elif isinstance(padding, np.ndarray):
        pad_vals = padding
    else:
        raise ValueError(
            "Padding must be an integer or a tuple (or list, np.ndarray) of integers",
        )

    # A single value means the same padding in both directions.
    if pad_vals.size == 1:
        pad_vals = np.repeat(pad_vals, 2)

    pad_widths = ((pad_vals[0], pad_vals[0]), (pad_vals[1], pad_vals[1]))

    return np.pad(input_data, pad_widths, "constant")
165 |
166 |
def ftr(input_data):
    """Fancy transpose right.

    Apply ``fancy_transpose`` to data with ``roll=1``.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array

    Returns
    -------
    numpy.ndarray
        Transposed data

    See Also
    --------
    fancy_transpose : base function

    """
    # The default roll of the base function is 1, stated explicitly here.
    return fancy_transpose(input_data, roll=1)
188 |
189 |
def ftl(input_data):
    """Fancy transpose left.

    Apply ``fancy_transpose`` to data with ``roll=-1``.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array

    Returns
    -------
    numpy.ndarray
        Transposed data

    See Also
    --------
    fancy_transpose : base function

    """
    return fancy_transpose(input_data, roll=-1)
211 |
212 |
def fancy_transpose(input_data, roll=1):
    """Fancy transpose.

    This method transposes a multidimensional matrix by cyclically rolling
    its axis order.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array
    roll : int
        Roll direction and amount (default is ``1``)

    Returns
    -------
    numpy.ndarray
        Transposed data

    Notes
    -----
    Adjustment to numpy.transpose

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.np_adjust import fancy_transpose
    >>> x = np.arange(27).reshape(3, 3, 3)
    >>> fancy_transpose(x)
    array([[[ 0,  3,  6],
            [ 9, 12, 15],
            [18, 21, 24]],
    <BLANKLINE>
           [[ 1,  4,  7],
            [10, 13, 16],
            [19, 22, 25]],
    <BLANKLINE>
           [[ 2,  5,  8],
            [11, 14, 17],
            [20, 23, 26]]])

    See Also
    --------
    numpy.transpose : base function

    """
    n_axes = input_data.ndim

    # Cyclic shift of the axis order by `roll` positions; equivalent to
    # np.roll(np.arange(n_axes), roll).
    shifted_axes = tuple((axis - roll) % n_axes for axis in range(n_axes))

    return np.transpose(input_data, axes=shifted_axes)
284 |
--------------------------------------------------------------------------------
/tests/test_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Test for base module.
3 |
4 | :Authors:
5 | Samuel Farrens
6 | Pierre-Antoine Comby
7 | """
8 |
9 | import numpy as np
10 | import numpy.testing as npt
11 | import pytest
12 | from test_helpers import failparam, skipparam
13 |
14 | from modopt.base import backend, np_adjust, transform, types
15 | from modopt.base.backend import LIBRARIES
16 |
17 |
class TestNpAdjust:
    """Test for npadjust."""

    # 3x3 matrix of the values 0..8, used as the basic 2D fixture.
    array33 = np.arange(9).reshape((3, 3))
    # Stack of two 3x3 matrices (values 0..17), used as the 3D fixture.
    array233 = np.arange(18).reshape((2, 3, 3))
    # Expected result of zero-padding ``array33`` by one pixel on each side.
    arraypad = np.array(
        [
            [0, 0, 0, 0, 0],
            [0, 0, 1, 2, 0],
            [0, 3, 4, 5, 0],
            [0, 6, 7, 8, 0],
            [0, 0, 0, 0, 0],
        ]
    )

    def test_rotate(self):
        """Test rotate."""
        # A 180-degree rotation must equal two successive 90-degree rotations.
        npt.assert_array_equal(
            np_adjust.rotate(self.array33),
            np.rot90(np.rot90(self.array33)),
            err_msg="Incorrect rotation.",
        )

    def test_rotate_stack(self):
        """Test rotate_stack."""
        # Rotating each slice must match numpy's rot90 over the trailing axes.
        npt.assert_array_equal(
            np_adjust.rotate_stack(self.array233),
            np.rot90(self.array233, k=2, axes=(1, 2)),
            err_msg="Incorrect stack rotation.",
        )

    # All accepted padding forms (int, list, ndarray) must give the same
    # result; a string must raise ValueError (wrapped via failparam).
    @pytest.mark.parametrize(
        "padding",
        [
            1,
            [1, 1],
            np.array([1, 1]),
            failparam("1", raises=ValueError),
        ],
    )
    def test_pad2d(self, padding):
        """Test pad2d."""
        npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad)

    def test_fancy_transpose(self):
        """Test fancy transpose."""
        npt.assert_array_equal(
            np_adjust.fancy_transpose(self.array233),
            np.array(
                [
                    [[0, 3, 6], [9, 12, 15]],
                    [[1, 4, 7], [10, 13, 16]],
                    [[2, 5, 8], [11, 14, 17]],
                ]
            ),
            err_msg="Incorrect fancy transpose",
        )

    def test_ftr(self):
        """Test ftr."""
        # ftr is fancy_transpose with roll=1, so the expected value matches
        # the test above.
        npt.assert_array_equal(
            np_adjust.ftr(self.array233),
            np.array(
                [
                    [[0, 3, 6], [9, 12, 15]],
                    [[1, 4, 7], [10, 13, 16]],
                    [[2, 5, 8], [11, 14, 17]],
                ]
            ),
            err_msg="Incorrect fancy transpose: ftr",
        )

    def test_ftl(self):
        """Test fancy transpose left."""
        # ftl is fancy_transpose with roll=-1.
        npt.assert_array_equal(
            np_adjust.ftl(self.array233),
            np.array(
                [
                    [[0, 9], [1, 10], [2, 11]],
                    [[3, 12], [4, 13], [5, 14]],
                    [[6, 15], [7, 16], [8, 17]],
                ]
            ),
            err_msg="Incorrect fancy transpose: ftl",
        )
103 |
104 |
class TestTransforms:
    """Test for the transform module."""

    # Stack of four 2x2 images (values 0..15) and its equivalent 2D map and
    # matrix representations; layout describes the 2x2 tiling of the map.
    cube = np.arange(16).reshape((4, 2, 2))
    # NOTE: attribute name shadows the ``map`` builtin, but only within the
    # class namespace.
    map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]])
    matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]])
    layout = (2, 2)
    # Layout incompatible with the data shape; must trigger ValueError.
    fail_layout = (3, 3)

    @pytest.mark.parametrize(
        ("func", "indata", "layout", "outdata"),
        [
            (transform.cube2map, cube, layout, map),
            failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError),
            (transform.map2cube, map, layout, cube),
            (transform.map2matrix, map, layout, matrix),
            (transform.matrix2map, matrix, matrix.shape, map),
        ],
    )
    def test_map(self, func, indata, layout, outdata):
        """Test cube2map."""
        npt.assert_array_equal(
            func(indata, layout),
            outdata,
        )
        # map2matrix does not validate the layout, so skip the failure check.
        if func.__name__ != "map2matrix":
            npt.assert_raises(ValueError, func, indata, self.fail_layout)

    def test_cube2matrix(self):
        """Test cube2matrix."""
        npt.assert_array_equal(
            transform.cube2matrix(self.cube),
            self.matrix,
        )

    def test_matrix2cube(self):
        """Test matrix2cube."""
        # matrix2cube needs the shape of a single cube slice to invert
        # cube2matrix.
        npt.assert_array_equal(
            transform.matrix2cube(self.matrix, self.cube[0].shape),
            self.cube,
            err_msg="Incorrect transformation: matrix2cube",
        )
147 |
148 |
class TestType:
    """Test for type module."""

    # Equivalent data in list, integer-array and float-array form, used to
    # check the casting behaviour of check_float / check_int.
    data_list = list(range(5))  # noqa: RUF012
    data_int = np.arange(5)
    data_flt = np.arange(5).astype(float)

    @pytest.mark.parametrize(
        ("data", "checked"),
        [
            (1.0, 1.0),
            (1, 1.0),
            (data_list, data_flt),
            (data_int, data_flt),
            failparam("1.0", 1.0, raises=TypeError),
        ],
    )
    def test_check_float(self, data, checked):
        """Test check float."""
        # Scalars, lists and integer arrays are all cast to float; strings
        # must raise TypeError.
        npt.assert_array_equal(types.check_float(data), checked)

    @pytest.mark.parametrize(
        ("data", "checked"),
        [
            (1.0, 1),
            (1, 1),
            (data_list, data_int),
            (data_flt, data_int),
            failparam("1", None, raises=TypeError),
        ],
    )
    def test_check_int(self, data, checked):
        """Test check int."""
        npt.assert_array_equal(types.check_int(data), checked)

    @pytest.mark.parametrize(
        ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)]
    )
    def test_check_npndarray(self, data, dtype):
        """Test check_npndarray."""
        # A dtype mismatch between the array and the expected dtype must
        # raise TypeError.
        npt.assert_raises(
            TypeError,
            types.check_npndarray,
            data,
            dtype=dtype,
        )

    def test_check_callable(self):
        """Test callable."""
        npt.assert_raises(TypeError, types.check_callable, 1)
199 |
200 |
@pytest.mark.parametrize(
    "backend_name",
    [
        skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed")
        for name in LIBRARIES
    ],
)
def test_tf_backend(backend_name):
    """Test Modopt computational backends.

    Checks that ``get_backend`` resolves the requested backend (rather than
    silently falling back) and that ``get_array_module`` recognises arrays
    created by, or converted to, that backend.
    """
    xp, checked_backend_name = backend.get_backend(backend_name)
    # Plain asserts give pytest's introspection on failure, unlike the
    # manual `raise AssertionError` pattern.
    assert checked_backend_name == backend_name, f"{backend_name} get_backend fails!"
    assert xp == LIBRARIES[backend_name], f"{backend_name} get_backend fails!"

    xp_input = backend.change_backend(np.array([10, 10]), backend_name)
    # Arrays created natively by the library must map back to that library.
    assert (
        backend.get_array_module(LIBRARIES[backend_name].ones(1))
        == backend.LIBRARIES[backend_name]
    ), f"{backend_name} backend fails!"
    # Arrays converted via change_backend must map back as well.
    assert (
        backend.get_array_module(xp_input) == LIBRARIES[backend_name]
    ), f"{backend_name} backend fails!"
220 |
--------------------------------------------------------------------------------
/docs/source/plugin_example.rst:
--------------------------------------------------------------------------------
1 | Plugin Example
2 | ==============
3 |
4 | This is a quick example demonstrating the use of ModOpt in the context
5 | of a PySAP plugin. In this example the Forward Backward algorithm from
6 | ModOpt is used to find the best fitting line to a set of data points.
7 |
8 | --------------
9 |
10 | First we import the required packages.
11 |
12 | .. code:: python
13 |
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from modopt import opt
17 |
Then we define a couple of utility functions for the purposes of this
example.
20 |
21 | .. code:: python
22 |
23 | def regression_plot(x1, y1, x2=None, y2=None):
24 | """Plot data points and a proposed fit.
25 | """
26 |
27 | fig = plt.figure(figsize=(12, 8))
28 | plt.plot(x1, y1, 'o', color="#19C3F5", label='Data')
29 | if not all([isinstance(var, type(None)) for var in (x2, y2)]):
30 | plt.plot(x2, y2, '--', color='#FF4F5B', label='Model')
31 | plt.title('Best Fit Line', fontsize=20)
32 | plt.xlabel('x', fontsize=18)
33 | plt.ylabel('y', fontsize=18)
34 | plt.legend()
35 | plt.show()
36 |
37 | def y_func(x, a):
38 | """Equation of a polynomial line.
39 | """
40 |
41 | return sum([(a_i * x ** n) for a_i, n in zip(a, range(a.size))])
42 |
43 | --------------
44 |
45 | The Problem
46 | -----------
47 |
Imagine we have the following set of data points and we want to find the
best fitting 3\ :math:`^{\textrm{rd}}` degree polynomial line.
50 |
51 | .. code:: python
52 |
53 | # A range of positions on the x-axis
54 | x = np.linspace(0.0, 1.0, 21)
55 |
56 | # A set of positions on the y-axis
57 | y = np.array([0.486, 0.866, 0.944, 1.144, 1.103, 1.202, 1.166, 1.191, 1.124, 1.095, 1.122, 1.102, 1.099, 1.017,
58 | 1.111, 1.117, 1.152, 1.265, 1.380, 1.575, 1.857])
59 |
60 | # Plot the points
61 | regression_plot(x, y)
62 |
63 |
64 |
65 | .. image:: output_5_0.png
66 |
67 |
68 | This corresponds to solving the inverse problem
69 |
70 | .. math:: y = Ha
71 |
where :math:`y` are the points on the y-axis, :math:`H` is a matrix
composed of the points on the x-axis
74 |
75 | .. code:: python
76 |
77 | H = np.array([np.ones(x.size), x, x ** 2, x ** 3]).T
78 |
79 | and :math:`a` is a model describing the best fit line that we aim to
80 | recover.
81 |
82 | Note: We could easily solve this problem analytically.
83 |
84 | --------------
85 |
86 | ModOpt solution
87 | ---------------
88 |
89 | Let’s attempt to solve this problem using the Forward Backward algorithm
90 | implemented in ModOpt in a few simple steps.
91 |
92 | Step 1: Set an initial guess for :math:`a`
93 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
94 |
We will start by assuming :math:`a` is a vector of zeros, which is
obviously a very bad fit to the data.
97 |
98 | .. code:: python
99 |
100 | a_0 = np.zeros(4)
101 |
102 | regression_plot(x, y, x, y_func(x, a_0))
103 |
104 |
105 |
106 | .. image:: output_9_0.png
107 |
108 |
109 | Step 2: Define a gradient operator for the problem
110 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
111 |
112 | For this problem we can use the basic gradient provided in ModOpt
113 |
114 | .. math:: \nabla F(x) = H^T(Hx - y)
115 |
116 | we simply need to define the operations of :math:`H` and :math:`H^T`.
117 |
118 | .. code:: python
119 |
120 | grad_op = opt.gradient.GradBasic(y, lambda x: np.dot(H, x), lambda x: np.dot(H.T, x))
121 |
122 |
123 | .. parsed-literal::
124 |
125 | WARNING: Making input data immutable.
126 |
127 |
128 | Step 3: Define a proximity operator for the algorithm
129 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
130 |
131 | Since we don’t need to implement any regularisation we simply set the
132 | identity operator.
133 |
134 | .. code:: python
135 |
136 | prox_op = opt.proximity.IdentityProx()
137 |
138 | Step 4: Pass everything to the algorithm
139 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
140 |
141 | Here we pass the initial guess for :math:`a` along with the gradient and
142 | proximity operators we defined. We also specify a value for the
143 | :math:`\beta` parameter in the Forward Backward algorithm. Finally, we
144 | specify that we want a maximum of 500 iterations.
145 |
146 | .. code:: python
147 |
148 | alg = opt.algorithms.ForwardBackward(a_0, grad_op, prox_op, beta_param=0.01, auto_iterate=False)
149 | alg.iterate(max_iter=500)
150 |
151 |
152 | .. parsed-literal::
153 |
154 | 100% (500 of 500) \|######################\| Elapsed Time: 0:00:00 Time: 0:00:00
155 |
156 |
157 | .. parsed-literal::
158 |
159 | - ITERATION: 1
160 | - DATA FIDELITY (X): 6.931639464572723
161 | - COST: 6.931639464572723
162 |
163 | - ITERATION: 2
164 | - DATA FIDELITY (X): 3.4179113043303198
165 | - COST: 3.4179113043303198
166 |
167 | - ITERATION: 3
168 | - DATA FIDELITY (X): 1.7894732608136656
169 | - COST: 1.7894732608136656
170 |
171 | - ITERATION: 4
172 | - DATA FIDELITY (X): 0.8712577337041495
173 | - COST: 0.8712577337041495
174 |
175 | - CONVERGENCE TEST -
176 | - CHANGE IN COST: 2.8897396205130526
177 |
178 | ...
179 |
180 | - ITERATION: 499
181 | - DATA FIDELITY (X): 0.05534231539479718
182 | - COST: 0.05534231539479718
183 |
184 | - ITERATION: 500
185 | - DATA FIDELITY (X): 0.05498126500595005
186 | - COST: 0.05498126500595005
187 |
188 | - CONVERGENCE TEST -
189 | - CHANGE IN COST: 0.013163437152771431
190 |
191 |
192 |
193 | Step 5: Extract the final result
194 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
195 |
196 | Once the algorithm has finished running we take the final result. We can
197 | see that it’s a pretty good fit!
198 |
199 | .. code:: python
200 |
201 | a = alg.x_final
202 |
203 | regression_plot(x, y, x, y_func(x, a))
204 |
205 |
206 |
207 | .. image:: output_17_0.png
208 |
209 |
210 | --------------
211 |
212 | PySAP Plugin
213 | ------------
214 |
Now imagine we want to implement the above solution as a PySAP plugin.
To do so we would first need to create a new repository using the
PySAP plugin template (https://github.com/CEA-COSMIC/pysap-extplugin).
218 |
219 | Afterwards, we could package our solution as a more user friendly
220 | function.
221 |
222 | .. code:: python
223 |
224 | def poly_fit(x, y, deg=3):
225 |
226 | H = np.array([np.ones(x.size), x, x ** 2, x ** 3]).T
227 | a_0 = np.zeros(4)
228 | grad_op = opt.gradient.GradBasic(y, lambda x: np.dot(H, x), lambda x: np.dot(H.T, x))
229 | prox_op = opt.proximity.IdentityProx()
230 | cost_op = opt.cost.costObj(operators=[grad_op, prox_op], verbose=False)
231 | alg = opt.algorithms.ForwardBackward(a_0, grad_op, prox_op, cost_op, beta_param=0.01, auto_iterate=False,
232 | progress=False)
233 | alg.iterate(max_iter=500)
234 |
235 | return alg.x_final
236 |
237 | Once the plugin (let’s call it foo) was integrated it would be possible
238 | to call this function as follows:
239 |
240 | .. code:: python
241 |
242 | from pysap.plugins.foo import poly_fit
243 |
244 | Then we could use this function to fit our data directly.
245 |
246 | .. code:: python
247 |
248 | a_new = poly_fit(x, y)
249 |
250 | regression_plot(x, y, x, y_func(x, a_new))
251 |
252 |
253 |
254 | .. image:: output_21_0.png
255 |
--------------------------------------------------------------------------------
/src/modopt/math/stats.py:
--------------------------------------------------------------------------------
1 | """STATISTICS ROUTINES.
2 |
3 | This module contains methods for basic statistics.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 | try:
12 | from packaging import version
13 | from astropy import __version__ as astropy_version
14 | from astropy.convolution import Gaussian2DKernel
15 | except ImportError: # pragma: no cover
16 | import_astropy = False
17 | else:
18 | import_astropy = True
19 |
20 |
def gaussian_kernel(data_shape, sigma, norm="max"):
    """Gaussian kernel.

    This method produces a Gaussian kernel of a specified size and dispersion.

    Parameters
    ----------
    data_shape : tuple
        Desired shape of the kernel, as ``(n_rows, n_cols)``
    sigma : float
        Standard deviation of the kernel
    norm : {'max', 'sum'}, optional, default='max'
        Normalisation of the kernel

    Returns
    -------
    numpy.ndarray
        Kernel

    Raises
    ------
    ImportError
        If Astropy package not found
    ValueError
        For invalid norm

    Examples
    --------
    >>> from modopt.math.stats import gaussian_kernel
    >>> gaussian_kernel((3, 3), 1)
    array([[0.36787944, 0.60653066, 0.36787944],
           [0.60653066, 1.        , 0.60653066],
           [0.36787944, 0.60653066, 0.36787944]])

    >>> gaussian_kernel((3, 3), 1, norm='sum')
    array([[0.07511361, 0.1238414 , 0.07511361],
           [0.1238414 , 0.20417996, 0.1238414 ],
           [0.07511361, 0.1238414 , 0.07511361]])

    """
    # Astropy is an optional dependency, imported at module level.
    if not import_astropy:  # pragma: no cover
        raise ImportError("Astropy package not found.")

    # NOTE(review): the error message also mentions "none", but only "max"
    # and "sum" are accepted here.
    if norm not in {"max", "sum"}:
        raise ValueError('Invalid norm, options are "max", "sum" or "none".')

    # data_shape is (rows, cols): x_size takes the column count.
    kernel = np.array(
        Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]),
    )

    if norm == "max":
        return kernel / np.max(kernel)

    # norm == "sum" below. Presumably astropy >= 5.2 already returns a
    # sum-normalised kernel, so only older versions need the explicit
    # division -- TODO confirm against astropy's changelog.
    elif version.parse(astropy_version) < version.parse("5.2"):
        return kernel / np.sum(kernel)

    else:
        return kernel
79 |
80 |
def mad(input_data):
    r"""Median absolute deviation.

    This method calculates the median absolute deviation of the input data.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array

    Returns
    -------
    float
        MAD value

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.stats import mad
    >>> a = np.arange(9).reshape(3, 3)
    >>> mad(a)
    2.0

    Notes
    -----
    The MAD is calculated as follows:

    .. math::

        \mathrm{MAD} = \mathrm{median}\left(|X_i - \mathrm{median}(X)|\right)

    See Also
    --------
    numpy.median : median function used

    """
    # Median of the absolute deviations from the data's own median.
    centre = np.median(input_data)
    deviations = np.abs(input_data - centre)
    return np.median(deviations)
118 |
119 |
def mse(data1, data2):
    """Mean Squared Error.

    This method returns the Mean Squared Error (MSE) between two data sets.

    Parameters
    ----------
    data1 : numpy.ndarray
        First data set
    data2 : numpy.ndarray
        Second data set

    Returns
    -------
    float
        Mean squared error

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.stats import mse
    >>> a = np.arange(9).reshape(3, 3)
    >>> mse(a, a + 2)
    4.0

    """
    residual = data1 - data2
    return np.mean(residual**2)
147 |
148 |
def psnr(data1, data2, method="starck", max_pix=255):
    r"""Peak Signal-to-Noise Ratio.

    This method calculates the Peak Signal-to-Noise Ratio between two data
    sets.

    Parameters
    ----------
    data1 : numpy.ndarray
        First data set
    data2 : numpy.ndarray
        Second data set
    method : {'starck', 'wiki'}, optional
        PSNR implementation (default is ``'starck'``)
    max_pix : int, optional
        Maximum number of pixels (default is ``255``)

    Returns
    -------
    float
        PSNR value

    Raises
    ------
    ValueError
        For invalid PSNR method

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.stats import psnr
    >>> a = np.arange(9).reshape(3, 3)
    >>> psnr(a, a + 2)
    12.041199826559248

    >>> psnr(a, a + 2, method='wiki')
    42.11020369539948

    Notes
    -----
    ``'starck'``:

    Implements eq.3.7 from :cite:`starck2010`

    ``'wiki'``:

    Implements PSNR equation on
    https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    .. math::

        \mathrm{PSNR} = 20\log_{10}(\mathrm{MAX}_I -
        10\log_{10}(\mathrm{MSE}))

    """
    if method == "starck":
        # Dynamic range of the reference data scaled by its first dimension.
        dynamic_range = np.abs(np.max(data1) - np.min(data1))
        residual_norm = np.linalg.norm(data1 - data2)
        return 20 * np.log10(data1.shape[0] * dynamic_range / residual_norm)

    if method == "wiki":
        return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))

    raise ValueError(
        'Invalid PSNR method. Options are "starck" and "wiki"',
    )
216 |
217 |
def psnr_stack(data1, data2, metric=np.mean, method="starck"):
    """Peak Signal-to-Noise for stack of images.

    This method calculates the PSNRs for two stacks of 2D arrays.
    By default the method returns the mean value of the PSNRs, but any other
    metric can be used.

    Parameters
    ----------
    data1 : numpy.ndarray
        Stack of images, 3D array
    data2 : numpy.ndarray
        Stack of recovered images, 3D array
    metric : function
        The desired metric to be applied to the PSNR values (default is
        ``numpy.mean``)
    method : {'starck', 'wiki'}, optional
        PSNR implementation (default is ``'starck'``)

    Returns
    -------
    float
        Metric result of PSNR values

    Raises
    ------
    ValueError
        For invalid input data dimensions

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.stats import psnr_stack
    >>> a = np.arange(18).reshape(2, 3, 3)
    >>> psnr_stack(a, a + 2)
    12.041199826559248

    See Also
    --------
    numpy.mean : default metric

    """
    # Both inputs must be stacks of 2D images.
    if data1.ndim != 3 or data2.ndim != 3:
        raise ValueError("Input data must be a 3D np.ndarray")

    # Compute the per-slice PSNRs, then reduce them with the chosen metric.
    psnr_values = [
        psnr(image, recovered, method=method)
        for image, recovered in zip(data1, data2)
    ]
    return metric(psnr_values)
266 |
267 |
def sigma_mad(input_data):
    r"""MAD Standard Deviation.

    This method calculates the standard deviation of the input data from the
    MAD.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array

    Returns
    -------
    float
        Sigma value

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.math.stats import sigma_mad
    >>> a = np.arange(9).reshape(3, 3)
    >>> sigma_mad(a)
    2.9652

    Notes
    -----
    This function can be used for estimating the standard deviation of the
    noise in images.

    Sigma is calculated as follows:

    .. math::

        \sigma = 1.4826 \mathrm{MAD}(X)

    """
    # 1.4826 is the consistency factor relating the MAD to the standard
    # deviation of normally distributed data.
    scale_factor = 1.4826
    return scale_factor * mad(input_data)
305 |
--------------------------------------------------------------------------------
/src/modopt/opt/gradient.py:
--------------------------------------------------------------------------------
1 | """GRADIENT CLASSES.
2 |
3 | This module contains classses for defining algorithm gradients.
4 | Based on work by Yinghao Ge and Fred Ngole.
5 |
6 | :Author: Samuel Farrens
7 |
8 | """
9 |
10 | import numpy as np
11 |
12 | from modopt.base.types import check_callable, check_float, check_npndarray
13 |
14 |
15 | class GradParent:
16 | """Gradient Parent Class.
17 |
18 | This class defines the basic methods that will be inherited by specific
19 | gradient classes.
20 |
21 | Parameters
22 | ----------
23 | input_data : numpy.ndarray
24 | The observed data
25 | op : callable
26 | The operator
27 | trans_op : callable
28 | The transpose operator
29 | get_grad : callable, optional
30 | Method for calculating the gradient (default is ``None``)
31 | cost: callable, optional
32 | Method for calculating the cost (default is ``None``)
33 | data_type : type, optional
34 | Expected data type of the input data (default is ``None``)
35 | input_data_writeable: bool, optional
36 | Option to make the observed data writeable (default is ``False``)
37 | verbose : bool, optional
38 | Option for verbose output (default is ``True``)
39 |
40 | Examples
41 | --------
42 | >>> import numpy as np
43 | >>> from modopt.opt.gradient import GradParent
44 | >>> y = np.arange(9).reshape(3, 3).astype(float)
45 | >>> g = GradParent(y, lambda x: x ** 2, lambda x: x ** 3)
46 | >>> g.op(y)
47 | array([[ 0., 1., 4.],
48 | [ 9., 16., 25.],
49 | [36., 49., 64.]])
50 | >>> g.trans_op(y)
51 | array([[ 0., 1., 8.],
52 | [ 27., 64., 125.],
53 | [216., 343., 512.]])
54 | >>> g.trans_op_op(y)
55 | array([[0.00000e+00, 1.00000e+00, 6.40000e+01],
56 | [7.29000e+02, 4.09600e+03, 1.56250e+04],
57 | [4.66560e+04, 1.17649e+05, 2.62144e+05]])
58 |
59 | """
60 |
61 | def __init__(
62 | self,
63 | input_data,
64 | op,
65 | trans_op,
66 | get_grad=None,
67 | cost=None,
68 | data_type=None,
69 | input_data_writeable=False,
70 | verbose=True,
71 | ):
72 | self.verbose = verbose
73 | self._input_data_writeable = input_data_writeable
74 | self._grad_data_type = data_type
75 | self.obs_data = input_data
76 | self.op = op
77 | self.trans_op = trans_op
78 |
79 | if not isinstance(get_grad, type(None)):
80 | self.get_grad = get_grad
81 | if not isinstance(cost, type(None)):
82 | self.cost = cost
83 |
84 | @property
85 | def obs_data(self):
86 | r"""Observed Data.
87 |
88 | The observed data :math:`\mathbf{y}`.
89 |
90 | Returns
91 | -------
92 | numpy.ndarray
93 | The observed data
94 |
95 | """
96 | return self._obs_data
97 |
98 | @obs_data.setter
99 | def obs_data(self, input_data):
100 | if self._grad_data_type in {float, np.floating}:
101 | input_data = check_float(input_data)
102 | check_npndarray(
103 | input_data,
104 | dtype=self._grad_data_type,
105 | writeable=self._input_data_writeable,
106 | verbose=self.verbose,
107 | )
108 |
109 | self._obs_data = input_data
110 |
111 | @property
112 | def op(self):
113 | r"""Operator.
114 |
115 | The operator :math:`\mathbf{H}`.
116 |
117 | Returns
118 | -------
119 | callable
120 | The operator function
121 |
122 | """
123 | return self._op
124 |
125 | @op.setter
126 | def op(self, operator):
127 | self._op = check_callable(operator)
128 |
129 | @property
130 | def trans_op(self):
131 | r"""Transpose operator.
132 |
133 | The transpose operator :math:`\mathbf{H}^T`.
134 |
135 | Returns
136 | -------
137 | callable
138 | The transpose operator function
139 |
140 | """
141 | return self._trans_op
142 |
143 | @trans_op.setter
144 | def trans_op(self, operator):
145 | self._trans_op = check_callable(operator)
146 |
    @property
    def get_grad(self):
        """Get gradient value."""
        return self._get_grad

    @get_grad.setter
    def get_grad(self, method):
        # Ensure the gradient computation method is callable before
        # binding it.
        self._get_grad = check_callable(method)
155 |
    @property
    def grad(self):
        """Gradient value."""
        return self._grad

    @grad.setter
    def grad(self, input_value):
        # Cast to float when a floating point data type is enforced,
        # mirroring the obs_data setter.
        if self._grad_data_type in {float, np.floating}:
            input_value = check_float(input_value)
        self._grad = input_value
166 |
    @property
    def cost(self):
        """Cost contribution."""
        return self._cost

    @cost.setter
    def cost(self, method):
        # Ensure the cost method is callable before binding it.
        self._cost = check_callable(method)
175 |
176 | def trans_op_op(self, input_data):
177 | r"""Transpose Operation of the Operator.
178 |
179 | This method calculates the action of the transpose operator on
180 | the action of the operator on the data.
181 |
182 | Parameters
183 | ----------
184 | input_data : numpy.ndarray
185 | Input data array
186 |
187 | Returns
188 | -------
189 | numpy.ndarray
190 | Result
191 |
192 | Notes
193 | -----
194 | Implements the following equation:
195 |
196 | .. math::
197 | \mathbf{H}^T(\mathbf{H}\mathbf{x})
198 |
199 | where :math:`\mathbf{x}` is the ``input_data``.
200 |
201 | """
202 | return self.trans_op(self.op(input_data))
203 |
204 |
class GradBasic(GradParent):
    """Basic Gradient Class.

    This class defines the gradient calculation and cost methods for
    common inverse problems.

    Parameters
    ----------
    *args : tuple
        Positional arguments
    **kwargs : dict
        Keyword arguments

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.opt.gradient import GradBasic
    >>> y = np.arange(9).reshape(3, 3).astype(float)
    >>> g = GradBasic(y, lambda x: x ** 2, lambda x: x ** 3)
    >>> g.get_grad(y)
    >>> g.grad
    array([[0.00000e+00, 0.00000e+00, 8.00000e+00],
           [2.16000e+02, 1.72800e+03, 8.00000e+03],
           [2.70000e+04, 7.40880e+04, 1.75616e+05]])

    See Also
    --------
    GradParent : parent class

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bind the concrete gradient and cost implementations through the
        # parent class property setters.
        self.get_grad = self._get_grad_method
        self.cost = self._cost_method

    def _get_grad_method(self, input_data):
        r"""Get the gradient.

        Compute the gradient step from the input data and store it on
        ``self.grad``.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array

        Notes
        -----
        Implements the following equation:

        .. math::
            \nabla F(x) = \mathbf{H}^T(\mathbf{H}\mathbf{x} - \mathbf{y})

        """
        residual = self.op(input_data) - self.obs_data
        self.grad = self.trans_op(residual)

    def _cost_method(self, *args, **kwargs):
        """Calculate gradient component of the cost.

        Return the scaled squared l2 norm of the difference between the
        observed data and the forward-mapped optimisation variable.

        Parameters
        ----------
        *args : tuple
            Positional arguments; ``args[0]`` is the current estimate
        **kwargs : dict
            Keyword arguments; ``verbose`` triggers printing

        Returns
        -------
        float
            Gradient cost component

        """
        residual = self.obs_data - self.op(args[0])
        cost_val = 0.5 * np.linalg.norm(residual) ** 2

        if kwargs.get("verbose"):
            print(" - DATA FIDELITY (X):", cost_val)

        return cost_val
286 |
--------------------------------------------------------------------------------
/src/modopt/opt/linear/base.py:
--------------------------------------------------------------------------------
1 | """Base classes for linear operators."""
2 |
3 | import numpy as np
4 |
5 | from modopt.base.types import check_callable
6 | from modopt.base.backend import get_array_module
7 |
8 |
class LinearParent:
    """Linear Operator Parent Class.

    This class sets the structure for defining linear operator instances.

    Parameters
    ----------
    op : callable
        Callable function that implements the linear operation
    adj_op : callable
        Callable function that implements the linear adjoint operation

    Examples
    --------
    >>> from modopt.opt.linear import LinearParent
    >>> a = LinearParent(lambda x: x * 2, lambda x: x ** 3)
    >>> a.op(2)
    4
    >>> a.adj_op(2)
    8

    """

    def __init__(self, op, adj_op):
        # Both assignments go through the property setters below, which
        # validate that the supplied objects are callable.
        self.op = op
        self.adj_op = adj_op

    @property
    def op(self):
        """Linear Operator."""
        return self._op

    @op.setter
    def op(self, operator):
        self._op = check_callable(operator)

    @property
    def adj_op(self):
        """Linear Adjoint Operator."""
        return self._adj_op

    @adj_op.setter
    def adj_op(self, operator):
        self._adj_op = check_callable(operator)
53 |
54 |
class Identity(LinearParent):
    """Identity Operator Class.

    This is a dummy class that can be used in the optimisation classes.

    See Also
    --------
    LinearParent : parent class

    """

    def __init__(self):
        def _pass_through(input_data):
            return input_data

        # The forward and adjoint of the identity are the same mapping.
        self.op = _pass_through
        self.adj_op = _pass_through
        # The identity contributes nothing to the cost.
        self.cost = lambda *args, **kwargs: 0
70 |
71 |
class MatrixOperator(LinearParent):
    """
    Matrix Operator class.

    This class transforms an array into a suitable linear operator.

    Parameters
    ----------
    array : array-like
        Matrix defining the forward operation ``array @ x``. The adjoint
        uses the conjugate transpose when the matrix has any entry with a
        nonzero imaginary part, otherwise the plain transpose.
    """

    def __init__(self, array):
        # get_array_module presumably returns the backend module
        # (numpy/cupy) for ``array`` — confirm against modopt.base.backend.
        xp = get_array_module(array)
        self.op = lambda x: array @ x

        # Use the Hermitian transpose only when the matrix actually has
        # complex-valued entries; a plain transpose suffices otherwise.
        if xp.any(xp.iscomplex(array)):
            self.adj_op = lambda x: array.T.conjugate() @ x
        else:
            self.adj_op = lambda x: array.T @ x
87 |
88 |
class LinearCombo(LinearParent):
    """Linear Combination Class.

    This class defines a combination of linear transform operators.

    Parameters
    ----------
    operators : list, tuple or numpy.ndarray
        List of linear operator class instances
    weights : list, tuple or numpy.ndarray, optional
        List of weights for combining the linear adjoint operator results

    Examples
    --------
    >>> from modopt.opt.linear import LinearCombo, LinearParent
    >>> a = LinearParent(lambda x: x * 2, lambda x: x ** 3)
    >>> b = LinearParent(lambda x: x * 4, lambda x: x ** 5)
    >>> c = LinearCombo([a, b])
    >>> a.op(2)
    4
    >>> b.op(2)
    8
    >>> c.op(2)
    array([4, 8], dtype=object)
    >>> a.adj_op(2)
    8
    >>> b.adj_op(2)
    32
    >>> c.adj_op([2, 2])
    20.0

    See Also
    --------
    LinearParent : parent class
    """

    def __init__(self, operators, weights=None):
        operators, weights = self._check_inputs(operators, weights)
        self.operators = operators
        self.weights = weights
        self.op = self._op_method
        self.adj_op = self._adj_op_method

    def _check_type(self, input_val):
        """Check input type.

        This method checks if the input is a list, tuple or a numpy array
        and converts the input to a numpy array.

        Parameters
        ----------
        input_val : any
            Any input object

        Returns
        -------
        numpy.ndarray
            Numpy array of inputs

        Raises
        ------
        TypeError
            For invalid input type
        ValueError
            If input list is empty

        """
        if not isinstance(input_val, (list, tuple, np.ndarray)):
            raise TypeError(
                "Invalid input type, input must be a list, tuple or numpy array.",
            )

        input_val = np.array(input_val)

        if not input_val.size:
            raise ValueError("Input list is empty.")

        return input_val

    def _check_inputs(self, operators, weights):
        """Check inputs.

        This method checks that the input operators and weights are
        correctly formatted.

        Parameters
        ----------
        operators : list, tuple or numpy.ndarray
            List of linear operator class instances
        weights : list, tuple or numpy.ndarray
            List of weights for combining the linear adjoint operator results

        Returns
        -------
        tuple
            Operators and weights

        Raises
        ------
        ValueError
            If an operator is missing an ``op`` or ``adj_op`` method, or
            if the number of weights does not match the number of operators
        TypeError
            If the individual weight values are not floats

        """
        operators = self._check_type(operators)

        for operator in operators:
            if not hasattr(operator, "op"):
                raise ValueError('Operators must contain "op" method.')

            if not hasattr(operator, "adj_op"):
                raise ValueError('Operators must contain "adj_op" method.')

            # Normalise both methods through the callable check.
            operator.op = check_callable(operator.op)
            operator.adj_op = check_callable(operator.adj_op)

        if weights is not None:
            weights = self._check_type(weights)

            if weights.size != operators.size:
                raise ValueError(
                    "The number of weights must match the number of operators.",
                )

            if not np.issubdtype(weights.dtype, np.floating):
                raise TypeError("The weights must be a list of float values.")

        return operators, weights

    def _op_method(self, input_data):
        """Operator.

        This method returns the input data operated on by all of the
        operators.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array

        Returns
        -------
        numpy.ndarray
            Linear operation results

        """
        # Object dtype: each operator may produce a result of a
        # different shape.
        res = np.empty(len(self.operators), dtype=np.ndarray)

        for index, operator in enumerate(self.operators):
            res[index] = operator.op(input_data)

        return res

    def _adj_op_method(self, input_data):
        """Adjoint operator.

        This method returns the combination of the result of all of the
        adjoint operators. If weights are provided the combination is the
        sum of the weighted results, otherwise the combination is the
        mean.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array

        Returns
        -------
        numpy.ndarray
            Adjoint operation results

        """
        if self.weights is None:
            return np.mean(
                [
                    operator.adj_op(elem)
                    for elem, operator in zip(input_data, self.operators)
                ],
                axis=0,
            )

        return np.sum(
            [
                weight * operator.adj_op(elem)
                for elem, operator, weight in zip(
                    input_data,
                    self.operators,
                    self.weights,
                )
            ],
            axis=0,
        )
280 |
--------------------------------------------------------------------------------
/docs/source/refs.bib:
--------------------------------------------------------------------------------
1 | @book{bauschke2009,
2 | author = {Bauschke, Heinz and Burachik, Regina and Combettes, Patrick and Luke, D.},
3 | year = {2009},
4 | month = {11},
5 | pages = {},
6 | title = {Fixed-Point Algorithms for Inverse Problems in Science and Engineering},
7 | publisher={Springer},
8 | url = {https://www.springer.com/gp/book/9781441995681}
9 | }
10 |
11 | @ARTICLE{candes2007,
12 | author = {Candes, Emmanuel J. and Wakin, Michael B. and Boyd, Stephen P.},
13 | title = {Enhancing Sparsity by Reweighted L1 Minimization},
14 | journal = {arXiv e-prints},
15 | keywords = {Statistics - Methodology, Mathematics - Statistics, 49N30, 49N45, 94A12},
16 | year = "2007",
17 | month = "Nov",
18 | eid = {arXiv:0711.1612},
19 | pages = {arXiv:0711.1612},
20 | archivePrefix = {arXiv},
21 | eprint = {0711.1612},
22 | primaryClass = {stat.ME},
23 | adsurl = {https://ui.adsabs.harvard.edu/abs/2007arXiv0711.1612C},
24 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
25 | }
26 |
27 | @article{chambolle2015,
28 | TITLE = {{On the convergence of the iterates of ``FISTA''}},
29 | AUTHOR = {Chambolle, Antonin and Dossal, Charles},
30 | URL = {https://hal.inria.fr/hal-01060130},
31 | JOURNAL = {{Journal of Optimization Theory and Applications}},
32 | PUBLISHER = {{Springer Verlag}},
  VOLUME = {166},
  NUMBER = {3},
35 | PAGES = {25},
36 | YEAR = {2015},
37 | MONTH = Aug,
38 | KEYWORDS = {proximal ; FISTA ; optimization ; convergence ; forward backward},
39 | PDF = {https://hal.inria.fr/hal-01060130/file/Fista10.pdf},
40 | HAL_ID = {hal-01060130},
41 | HAL_VERSION = {v3},
42 | }
43 |
44 | @article{combettes2005,
45 | author = {Combettes, Patrick L. and Wajs, Valérie R.},
46 | title = {Signal Recovery by Proximal Forward-Backward Splitting},
47 | journal = {Multiscale Modeling \& Simulation},
48 | volume = {4},
49 | number = {4},
50 | pages = {1168-1200},
51 | year = {2005},
52 | doi = {10.1137/050626090},
53 | URL = {https://doi.org/10.1137/050626090},
54 | eprint = {https://doi.org/10.1137/050626090}
55 | }
56 |
57 | @article{condat2013,
58 | author = {Condat, Laurent},
59 | year = {2013},
60 | month = {08},
61 | pages = {},
62 | title = {A Primal–Dual Splitting Method for Convex Optimization Involving Lipschitzian, Proximable and Linear Composite Terms},
63 | volume = {158},
64 | journal = {Journal of Optimization Theory and Applications},
65 | doi = {10.1007/s10957-012-0245-9}
66 | }
67 |
68 | @ARTICLE{defazio2014,
69 | author = {{Defazio}, Aaron and {Bach}, Francis and {Lacoste-Julien}, Simon},
70 | title = "{SAGA: A Fast Incremental Gradient Method With Support for Non-Strongly Convex Composite Objectives}",
71 | journal = {arXiv e-prints},
72 | keywords = {Computer Science - Machine Learning, Mathematics - Optimization and Control, Statistics - Machine Learning},
73 | year = 2014,
74 | month = jul,
75 | eid = {arXiv:1407.0202},
76 | pages = {arXiv:1407.0202},
77 | archivePrefix = {arXiv},
78 | eprint = {1407.0202},
79 | primaryClass = {cs.LG},
80 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1407.0202D},
81 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
82 | }
83 |
84 | @ARTICLE{figueiredo2014,
85 | author = {Figueiredo, Mario A.~T. and Nowak, Robert D.},
86 | title = {Sparse Estimation with Strongly Correlated Variables using Ordered Weighted L1 Regularization},
87 | journal = {arXiv e-prints},
88 | keywords = {Statistics - Machine Learning},
89 | year = "2014",
90 | month = "Sep",
91 | eid = {arXiv:1409.4005},
92 | pages = {arXiv:1409.4005},
93 | archivePrefix = {arXiv},
94 | eprint = {1409.4005},
95 | primaryClass = {stat.ML},
96 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1409.4005F},
97 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
98 | }
99 |
100 | @ARTICLE{kim2017,
101 | author = {Kim, Donghwan and Fessler, Jeffrey A.},
102 | title = {Adaptive Restart of the Optimized Gradient Method for Convex Optimization},
103 | journal = {arXiv e-prints},
104 | keywords = {Mathematics - Optimization and Control},
105 | year = "2017",
106 | month = "Mar",
107 | eid = {arXiv:1703.04641},
108 | pages = {arXiv:1703.04641},
109 | archivePrefix = {arXiv},
110 | eprint = {1703.04641},
111 | primaryClass = {math.OC},
112 | adsurl = {https://ui.adsabs.harvard.edu/abs/2017arXiv170304641K},
113 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
114 | }
115 |
116 | @ARTICLE{liang2018,
117 | author = {Liang, Jingwei and Luo, Tao and Schonlieb, Carola-Bibiane},
118 | title = {Improving Fast Iterative Shrinkage-Thresholding Algorithm: Faster, Smarter and Greedier},
119 | journal = {arXiv e-prints},
120 | keywords = {Mathematics - Optimization and Control},
121 | year = "2018",
122 | month = "Nov",
123 | eid = {arXiv:1811.01430},
124 | pages = {arXiv:1811.01430},
125 | archivePrefix = {arXiv},
126 | eprint = {1811.01430},
127 | primaryClass = {math.OC},
128 | adsurl = {https://ui.adsabs.harvard.edu/abs/2018arXiv181101430L},
129 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
130 | }
131 |
132 | @ARTICLE{mcdonald2014,
133 | author = {McDonald, Andrew M. and Pontil, Massimiliano and Stamos, Dimitris},
134 | title = {New Perspectives on k-Support and Cluster Norms},
135 | journal = {arXiv e-prints},
136 | keywords = {Statistics - Machine Learning},
137 | year = "2014",
138 | month = "Mar",
139 | eid = {arXiv:1403.1481},
140 | pages = {arXiv:1403.1481},
141 | archivePrefix = {arXiv},
142 | eprint = {1403.1481},
143 | primaryClass = {stat.ML},
144 | adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1403.1481M},
145 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
146 | }
147 |
148 | @ARTICLE{raguet2011,
149 | author = {Raguet, Hugo and Fadili, Jalal and Peyr{\'e}, Gabriel},
150 | title = {Generalized Forward-Backward Splitting},
151 | journal = {arXiv e-prints},
152 | keywords = {Mathematics - Optimization and Control, 65K05},
153 | year = "2011",
154 | month = "Aug",
155 | eid = {arXiv:1108.4404},
156 | pages = {arXiv:1108.4404},
157 | archivePrefix = {arXiv},
158 | eprint = {1108.4404},
159 | primaryClass = {math.OC},
160 | adsurl = {https://ui.adsabs.harvard.edu/abs/2011arXiv1108.4404R},
161 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
162 | }
163 |
164 | @ARTICLE{ruder2017,
165 | author = {{Ruder}, Sebastian},
166 | title = "{An overview of gradient descent optimization algorithms}",
167 | journal = {arXiv e-prints},
168 | keywords = {Computer Science - Machine Learning},
169 | year = 2016,
170 | month = sep,
171 | eid = {arXiv:1609.04747},
172 | pages = {arXiv:1609.04747},
173 | archivePrefix = {arXiv},
174 | eprint = {1609.04747},
175 | primaryClass = {cs.LG},
176 | adsurl = {https://ui.adsabs.harvard.edu/abs/2016arXiv160904747R},
177 | adsnote = {Provided by the SAO/NASA Astrophysics Data System}
178 | }
179 |
180 | @book{starck2010,
181 | place={Cambridge},
182 | title={Sparse Image and Signal Processing: Wavelets, Curvelets, Morphological Diversity},
183 | DOI={10.1017/CBO9780511730344},
184 | publisher={Cambridge University Press},
185 | author={Starck, Jean-Luc and Murtagh, Fionn and Fadili, Jalal M.},
186 | year={2010}
187 | }
188 |
189 | @article{yuan2006,
190 | author = {Yuan, Ming and Lin, Yi},
191 | year = {2006},
192 | month = {02},
193 | pages = {49-67},
194 | title = {Model Selection and Estimation in Regression With Grouped Variables},
195 | volume = {68},
196 | journal = {Journal of the Royal Statistical Society Series B},
197 | doi = {10.1111/j.1467-9868.2005.00532.x}
198 | }
199 |
200 | @article{zou2005,
201 | author = {Zou, Hui and Hastie, Trevor},
202 | year = {2005},
203 | month = {02},
204 | pages = {768-768},
205 | title = {Regularization and variable selection via the elastic net (vol B 67, pg 301, 2005)},
206 | volume = {67},
207 | journal = {Journal of the Royal Statistical Society Series B},
208 | doi = {10.1111/j.1467-9868.2005.00527.x}
209 | }
210 |
211 | @article{Goldstein2014,
212 | author={Goldstein, Tom and O’Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard},
213 | year={2014},
214 | month={Jan},
215 | pages={1588–1623},
216 | title={Fast Alternating Direction Optimization Methods},
217 | journal={SIAM Journal on Imaging Sciences},
218 | volume={7},
219 | ISSN={1936-4954},
220 | doi={10/gdwr49},
221 | }
222 |
--------------------------------------------------------------------------------
/src/modopt/base/observable.py:
--------------------------------------------------------------------------------
1 | """Observable.
2 |
3 | This module contains observable classes
4 |
:Author: Benoit Sarthou
6 |
7 | """
8 |
9 | import time
10 |
11 | import numpy as np
12 |
13 |
class SignalObject:
    """Lightweight container whose attributes carry signal data."""
18 |
19 |
class Observable:
    """Base class for observable classes.

    This class defines a simple interface to add or remove observers
    on an object.

    Parameters
    ----------
    signals : list
        The allowed signals

    """

    def __init__(self, signals):
        # Define class parameters
        self._allowed_signals = []
        self._observers = {}

        # Set allowed signals
        for signal in signals:
            self._allowed_signals.append(signal)
            self._observers[signal] = []

        # Set a lock option to avoid multiple observer notifications
        self._locked = False

    def add_observer(self, signal, observer):
        """Add an observer to the object.

        Raise an exception if the signal is not allowed.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            A function that will be called when the signal is emitted

        Raises
        ------
        ValueError
            If the signal is not allowed

        """
        self._is_allowed_signal(signal)
        self._add_observer(signal, observer)

    def remove_observer(self, signal, observer):
        """Remove an observer from the object.

        Raise an exception if the signal is not allowed.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function to be removed

        Raises
        ------
        ValueError
            If the signal is not allowed

        """
        # BUG FIX: previously called the non-existent ``_is_allowed_event``,
        # which raised AttributeError for every call to this method.
        self._is_allowed_signal(signal)
        self._remove_observer(signal, observer)

    def notify_observers(self, signal, **kwargs):
        """Notify observers of a given signal.

        Parameters
        ----------
        signal : str
            A valid signal
        **kwargs : dict
            The parameters that will be sent to the observers

        Returns
        -------
        bool
            ``False`` if a notification is in progress, otherwise ``True``

        """
        # Check if a notification is in progress
        if self._locked:
            return False
        # Set the lock
        self._locked = True

        # Create a signal object carrying the keyword parameters as
        # attributes
        signal_to_be_notified = SignalObject()
        signal_to_be_notified.object = self
        signal_to_be_notified.signal = signal

        for name, key_value in kwargs.items():
            setattr(signal_to_be_notified, name, key_value)
        # Notify all the observers
        for observer in self._observers[signal]:
            observer(signal_to_be_notified)
        # Unlock the notification process
        self._locked = False

        # BUG FIX: return ``True`` on success as documented (previously
        # fell through and implicitly returned ``None``).
        return True

    def _get_allowed_signals(self):
        """Get allowed signals.

        Events allowed for the current object.

        Returns
        -------
        list
            List of allowed signals

        """
        return self._allowed_signals

    allowed_signals = property(_get_allowed_signals)

    def _is_allowed_signal(self, signal):
        """Check if a signal is valid.

        Raise an exception if the signal is not allowed.

        Parameters
        ----------
        signal: str
            A signal

        Raises
        ------
        ValueError
            For invalid signal

        """
        if signal not in self._allowed_signals:
            message = 'Signal "{0}" is not allowed for "{1}"'
            raise ValueError(message.format(signal, type(self)))

    def _add_observer(self, signal, observer):
        """Associate an observer to a valid signal.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function

        """
        # Avoid registering the same observer twice for one signal.
        if observer not in self._observers[signal]:
            self._observers[signal].append(observer)

    def _remove_observer(self, signal, observer):
        """Remove an observer from a valid signal.

        Parameters
        ----------
        signal : str
            A valid signal
        observer : callable
            An observation function to be removed

        """
        if observer in self._observers[signal]:
            self._observers[signal].remove(observer)
175 |
176 |
class MetricObserver:
    """Metric observer.

    Wrapper of the metric to the observer object notify by the Observable
    class.

    Parameters
    ----------
    name : str
        The name of the metric
    metric : callable
        Metric function with this precise signature func(test, ref)
    mapping : dict
        Define the mapping between the iterate variable and the metric
        keyword: ``{'x_new':'name_var_1', 'y_new':'name_var_2'}``. To cancel
        the need of a variable, the dict value should be None:
        ``'y_new': None``.
    cst_kwargs : dict
        Keywords arguments of constant argument for the metric computation
    early_stopping : bool
        If True it will compute the convergence flag (default is ``False``)
    wind : int
        Window on with the convergence criteria is compute (default is ``6``)
    eps : float
        The level of criteria of convergence (default is ``1.0e-3``)

    """

    def __init__(
        self,
        name,
        metric,
        mapping,
        cst_kwargs,
        early_stopping=False,
        wind=6,
        eps=1.0e-3,
    ):
        self.name = name
        self.metric = metric
        self.mapping = mapping
        self.cst_kwargs = cst_kwargs
        self.early_stopping = early_stopping
        self.wind = wind
        self.eps = eps
        # History of metric values, iteration indices and timestamps.
        self.list_cv_values = []
        self.list_iters = []
        self.list_dates = []
        self.converge_flag = False

    def __call__(self, signal):
        """Call Method.

        Translate the observer call signature into the metric call
        signature and record the metric value for this iteration.

        Parameters
        ----------
        signal : str
            A valid signal

        """
        # Build the metric keyword arguments from the signal attributes,
        # skipping mapping entries explicitly cancelled with None.
        metric_kwargs = {
            key_value: getattr(signal, key)
            for key, key_value in self.mapping.items()
            if key_value is not None
        }
        metric_kwargs.update(self.cst_kwargs)

        self.list_iters.append(signal.idx)
        self.list_dates.append(time.time())
        self.list_cv_values.append(self.metric(**metric_kwargs))

        if self.early_stopping:
            self.is_converge()

    def is_converge(self):
        """Check convergence.

        Set ``converge_flag`` to ``True`` if the convergence criteria is
        matched.

        """
        if len(self.list_cv_values) < self.wind:
            return
        # Compare the mean of the older half of the window with the mean
        # of the most recent half.
        mid_idx = -(self.wind // 2)
        previous = np.array(self.list_cv_values[-self.wind : mid_idx])
        recent = np.array(self.list_cv_values[mid_idx:])
        normalized_residual = np.abs(previous.mean() - recent.mean()) / np.abs(
            previous.mean()
        )
        self.converge_flag = normalized_residual < self.eps

    def retrieve_metrics(self):
        """Retrieve metrics.

        Return the convergence metrics saved with the corresponding
        iterations.

        Returns
        -------
        dict
            Convergence metrics

        """
        elapsed = np.array(self.list_dates)

        # Report times relative to the first observation.
        if elapsed.size >= 1:
            elapsed -= elapsed[0]

        return {
            "time": elapsed,
            "index": self.list_iters,
            "values": self.list_cv_values,
        }
288 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to ModOpt
2 |
3 | ModOpt is a series of Modular Optimisation tools for solving inverse problems.
4 | This package has been developed in collaboration between [CosmoStat](http://www.cosmostat.org/) and [NeuroSpin](http://joliot.cea.fr/drf/joliot/Pages/Entites_de_recherche/NeuroSpin.aspx) via the [COSMIC](http://cosmic.cosmostat.org/) project.
5 |
6 | ## Contents
7 |
8 | 1. [Introduction](#introduction)
9 | 2. [Issues](#issues)
10 | a. [Asking Questions](#asking-questions)
11 | b. [Installation Issues](#installation-issues)
12 | c. [Reporting Bugs](#reporting-bugs)
13 | d. [Requesting Features](#requesting-features)
14 | 3. [Pull Requests](#pull-requests)
15 | a. [Before Making a PR](#before-making-a-pr)
16 | b. [Making a PR](#making-a-pr)
17 | c. [After Making a PR](#after-making-a-pr)
18 | d. [Content](#content)
19 | e. [CI Tests](#ci-tests)
20 | f. [Coverage](#coverage)
21 | g. [Style Guide](#style-guide)
22 |
23 | ## Introduction
24 |
25 | ModOpt is fully open-source and as such users are welcome to fork, clone and/or reuse the software freely. Users wishing to contribute to the development of this package, however, are kindly requested to adhere to the following guidelines and the [code of conduct](./CODE_OF_CONDUCT.md).
26 |
27 | ## Issues
28 |
29 | The easiest way to contribute to ModOpt is by raising a "New issue". This will give you the opportunity to ask questions, report bugs or even request new features.
30 |
31 | Remember to use clear and descriptive titles for issues. This will help other users that encounter similar problems find quick solutions. We also ask that you read the available documentation and browse existing issues on similar topics before raising a new issue in order to avoid repetition.
32 |
33 | ### Asking Questions
34 |
35 | Users are of course welcome to ask any question relating to ModOpt and we will endeavour to reply as soon as possible.
36 |
37 | These issues should include the `help wanted` label.
38 |
39 | ### Installation Issues
40 |
41 | If you encounter difficulties installing ModOpt be sure to re-read the installation instructions provided. If you are still unable to install the package please remember to include the following details in the issue you raise:
42 |
43 | * your operating system and the corresponding version (*e.g.* macOS v10.14.1, Ubuntu v16.04.1, *etc.*),
44 | * the version of Python you are using (*e.g* v3.6.7, *etc.*),
45 | * the python environment you are using (if any) and the corresponding version (*e.g.* virtualenv v16.1.0, conda v4.5.11, *etc.*),
46 | * the exact steps followed while attempting to install ModOpt
47 | * and the error message printed or a screen capture of the terminal output.
48 |
49 | These issues should include the `installation` label.
50 |
51 | ### Reporting Bugs
52 |
If you discover a bug while using ModOpt please provide the same information requested for installation issues. Be sure to list the exact steps you followed that led to the bug you encountered so that we can attempt to recreate the conditions.
54 |
55 | If you are aware of the source of the bug we would very much appreciate if you could provide the module(s) and line number(s) affected. This will enable us to more rapidly fix the problem.
56 |
57 | These issues should include the `bug` label.
58 |
59 | ### Requesting Features
60 |
61 | If you believe ModOpt could be improved with the addition of extra functionality or features feel free to let us know. We cannot guarantee that we will include these features, but we will certainly take your suggestions into consideration.
62 |
63 | In order to increase your chances of having a feature included, be sure to be as clear and specific as possible as to the properties this feature should have.
64 |
65 | These issues should include the `enhancement` label.
66 |
67 | ## Pull Requests
68 |
If you would like to take a more active role in the development of ModOpt you can do so by submitting a "Pull request". A Pull Request (PR) is a way by which a user can submit modifications or additions to the ModOpt package directly. PRs need to be reviewed by the package moderators and if accepted are merged into the master branch of the repository.
70 |
71 | Before making a PR, be sure to carefully read the following guidelines.
72 |
73 | ### Before Making a PR
74 |
75 | The following steps should be followed before making a pull request:
76 |
77 | 1. Log into your GitHub account or create an account if you do not already have one.
78 |
79 | 1. Go to the main ModOpt repository page: [https://github.com/CEA-COSMIC/ModOpt](https://github.com/CEA-COSMIC/ModOpt)
80 |
81 | 1. Fork the repository, *i.e.* press the button on the top right with this symbol
. This will create an independent copy of the repository on your account.
82 |
83 | 1. Clone your fork of ModOpt.
84 |
85 | ```bash
86 | git clone https://github.com/YOUR_USERNAME/ModOpt
87 | ```
88 |
89 | 5. Add the original repository (*upstream*) to remote.
90 |
91 | ```bash
92 | git remote add upstream https://github.com/CEA-COSMIC/ModOpt
93 | ```
94 |
95 | ### Making a PR
96 |
97 | The following steps should be followed to make a pull request:
98 |
99 | 1. Pull the latest updates to the original repository.
100 |
101 | ```bash
102 | git pull upstream master
103 | ```
104 |
105 | 2. Create a new branch for your modifications.
106 |
107 | ```bash
108 | git checkout -b BRANCH_NAME
109 | ```
110 |
111 | 3. Make the desired modifications to the relevant modules.
112 |
113 | 4. Add the modified files to the staging area.
114 |
115 | ```bash
116 | git add .
117 | ```
118 |
119 | 5. Make sure all of the appropriate files have been staged. Note that all files listed in green will be included in the following commit.
120 |
121 | ```bash
122 | git status
123 | ```
124 |
125 | 6. Commit the changes with an appropriate description.
126 |
127 | ```bash
128 | git commit -m "Description of commit"
129 | ```
130 |
131 | 7. Push the commits to a branch on your fork of ModOpt.
132 |
133 | ```bash
134 | git push origin BRANCH_NAME
135 | ```
136 |
137 | 8. Make a pull request for your branch with a clear description of what has been done, why and what issues this relates to.
138 |
139 | 9. Wait for feedback and repeat steps 3 through 7 if necessary.
140 |
141 | ### After Making a PR
142 |
143 | If your PR is accepted and merged it is recommended that the following steps be followed to keep your fork up to date.
144 |
145 | 1. Make sure you switch back to your local master branch.
146 |
147 | ```bash
148 | git checkout master
149 | ```
150 |
151 | 2. Delete the local branch you used for the PR.
152 |
153 | ```bash
154 | git branch -d BRANCH_NAME
155 | ```
156 |
157 | 3. Pull the latest updates to the original repository, which include your PR changes.
158 |
159 | ```bash
160 | git pull upstream master
161 | ```
162 |
163 | 4. Push the commits to your fork.
164 |
165 | ```bash
166 | git push origin master
167 | ```
168 |
169 | ### Content
170 |
Every PR should correspond to a bug fix or new feature issue that has already been raised. When you make a PR be sure to tag the issue that it resolves (*e.g.* this PR relates to issue #1). This way the issue can be closed once the PR has been merged.
172 |
173 | The content of a given PR should be as concise as possible. To that end, aim to restrict modifications to those needed to resolve a single issue. Additional bug fixes or features should be made as separate PRs.
174 |
175 | ### CI Tests
176 |
Continuous Integration (CI) tests are implemented via [Travis CI](https://travis-ci.org/). All PRs must pass the CI tests before being merged. Your PR may not be reviewed by a moderator until all CI tests have passed. Therefore, try to resolve any issues in your PR that may cause the tests to fail.
178 |
179 | In some cases it may be necessary to modify the unit tests, but this should be clearly justified in the PR description.
180 |
181 | ### Coverage
182 |
183 | Coverage tests are implemented via [Coveralls](https://coveralls.io/). These tests will fail if the coverage, *i.e.* the number of lines of code covered by unit tests, decreases. When submitting new code in a PR, contributors should aim to write appropriate unit tests. If the coverage drops significantly moderators may request unit tests be added before the PR is merged.
184 |
185 | ### Style Guide
186 |
187 | All contributions should adhere to the following style guides currently implemented in ModOpt:
188 |
189 | 1. All code should be compatible with the Python versions listed in `README.rst`.
190 |
191 | 1. All code should adhere to [PEP8](https://www.python.org/dev/peps/pep-0008/) standards.
192 |
193 | 1. Docstrings need to be provided for all new modules, methods and classes. These should adhere to [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) standards.
194 |
195 | 1. When in doubt look at the existing code for inspiration.
196 |
--------------------------------------------------------------------------------
/src/modopt/base/transform.py:
--------------------------------------------------------------------------------
1 | """DATA TRANSFORM ROUTINES.
2 |
3 | This module contains methods for transforming data.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 |
11 |
def cube2map(data_cube, layout):
    """Cube to Map.

    This method transforms the input data from a 3D cube to a 2D map with a
    specified layout.

    Parameters
    ----------
    data_cube : numpy.ndarray
        Input data cube, 3D array of 2D images
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        2D map

    Raises
    ------
    ValueError
        For invalid data dimensions
    ValueError
        For invalid layout

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import cube2map
    >>> a = np.arange(16).reshape((4, 2, 2))
    >>> cube2map(a, (2, 2))
    array([[ 0, 1, 4, 5],
           [ 2, 3, 6, 7],
           [ 8, 9, 12, 13],
           [10, 11, 14, 15]])

    See Also
    --------
    map2cube : complementary function

    """
    if data_cube.ndim != 3:
        raise ValueError("The input data must have 3 dimensions.")

    if data_cube.shape[0] != np.prod(layout):
        raise ValueError(
            "The desired layout must match the number of input " + "data layers.",
        )

    n_rows, n_cols = layout

    # Build each row of the map by horizontally stacking ``n_cols``
    # consecutive images, then stack the rows vertically.
    map_rows = []
    for row_idx in range(n_rows):
        start = n_cols * row_idx
        map_rows.append(np.hstack(data_cube[start : start + n_cols]))

    return np.vstack(map_rows)
67 |
68 |
def map2cube(data_map, layout):
    """Map to cube.

    This method transforms the input data from a 2D map with given layout to
    a 3D cube.

    Parameters
    ----------
    data_map : numpy.ndarray
        Input data map, 2D array
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        3D cube

    Raises
    ------
    ValueError
        For invalid layout

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import map2cube
    >>> a = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13],
    ... [10, 11, 14, 15]])
    >>> map2cube(a, (2, 2))
    array([[[ 0, 1],
            [ 2, 3]],

           [[ 4, 5],
            [ 6, 7]],

           [[ 8, 9],
            [10, 11]],

           [[12, 13],
            [14, 15]]])

    See Also
    --------
    cube2map : complementary function

    """
    # Every map dimension must divide evenly by the corresponding layout
    # dimension, otherwise the images cannot be extracted cleanly.
    # BUGFIX: the previous ``np.all`` only rejected maps where *every*
    # dimension had a remainder; ``np.any`` rejects any mismatched dimension.
    if np.any(np.array(data_map.shape) % np.array(layout)):
        raise ValueError(
            "The desired layout must be a multiple of the number "
            + "pixels in the data map.",
        )

    # Shape of each individual image in the map.
    d_shape = np.array(data_map.shape) // np.array(layout)

    return np.array(
        [
            data_map[
                (
                    slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
                    slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
                )
            ]
            for i_elem in range(layout[0])
            for j_elem in range(layout[1])
        ]
    )
136 |
137 |
def map2matrix(data_map, layout):
    """Map to Matrix.

    This method transforms a 2D map to a 2D matrix.

    Parameters
    ----------
    data_map : numpy.ndarray
        Input data map, 2D array
    layout : tuple
        2D layout of 2D images

    Returns
    -------
    numpy.ndarray
        2D matrix

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import map2matrix
    >>> a = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13],
    ... [10, 11, 14, 15]])
    >>> map2matrix(a, (2, 2))
    array([[ 0, 4, 8, 12],
           [ 1, 5, 9, 13],
           [ 2, 6, 10, 14],
           [ 3, 7, 11, 15]])

    See Also
    --------
    matrix2map : complementary function

    """
    layout = np.array(layout)

    # Side length of each image; only the first dimension is used, so the
    # images are assumed to be square.
    img_size = (np.array(data_map.shape) // layout)[0]

    # Extract each image (row-major order) and flatten it into one column.
    columns = []
    for idx in range(np.prod(layout)):
        row_pos, col_pos = divmod(idx, layout[1])
        patch = data_map[
            img_size * row_pos : img_size * (row_pos + 1),
            img_size * col_pos : img_size * (col_pos + 1),
        ]
        columns.append(patch.reshape(img_size**2))

    return np.array(columns).T
196 |
197 |
def matrix2map(data_matrix, map_shape):
    """Matrix to Map.

    This method transforms a 2D matrix to a 2D map.

    Parameters
    ----------
    data_matrix : numpy.ndarray
        Input data matrix, 2D array
    map_shape : tuple
        2D shape of the output map

    Returns
    -------
    numpy.ndarray
        2D map

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import matrix2map
    >>> a = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14],
    ... [3, 7, 11, 15]])
    >>> matrix2map(a, (4, 4))
    array([[ 0, 1, 4, 5],
           [ 2, 3, 6, 7],
           [ 8, 9, 12, 13],
           [10, 11, 14, 15]])

    See Also
    --------
    map2matrix : complementary function

    """
    map_shape = np.array(map_shape)

    # Recover the (square) image size from the number of matrix rows and
    # the grid layout from the map shape.
    img_size = np.sqrt(data_matrix.shape[0]).astype(int)
    layout = np.array(map_shape // np.repeat(img_size, 2), dtype="int")

    # View each matrix column as a 2D image: (img_size, img_size, n_images).
    image_stack = data_matrix.reshape(img_size, img_size, data_matrix.shape[1])

    output_map = np.zeros(map_shape)

    # Paste each image into its row-major position in the map.
    for idx in range(data_matrix.shape[1]):
        row_pos, col_pos = divmod(idx, layout[1])
        output_map[
            img_size * row_pos : img_size * (row_pos + 1),
            img_size * col_pos : img_size * (col_pos + 1),
        ] = image_stack[:, :, idx]

    return output_map.astype(int)
255 |
256 |
def cube2matrix(data_cube):
    """Cube to Matrix.

    This method transforms a 3D cube to a 2D matrix.

    Parameters
    ----------
    data_cube : numpy.ndarray
        Input data cube, 3D array

    Returns
    -------
    numpy.ndarray
        2D matrix

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import cube2matrix
    >>> a = np.arange(16).reshape((4, 2, 2))
    >>> cube2matrix(a)
    array([[ 0, 4, 8, 12],
           [ 1, 5, 9, 13],
           [ 2, 6, 10, 14],
           [ 3, 7, 11, 15]])

    See Also
    --------
    matrix2cube : complementary function

    """
    # Flatten each image to a row, then transpose so that each image
    # becomes one column of the output matrix.
    n_images = data_cube.shape[0]
    n_pixels = np.prod(data_cube.shape[1:])
    return data_cube.reshape([n_images, n_pixels]).T
291 |
292 |
def matrix2cube(data_matrix, im_shape):
    """Matrix to Cube.

    This method transforms a 2D matrix to a 3D cube.

    Parameters
    ----------
    data_matrix : numpy.ndarray
        Input data cube, 2D array
    im_shape : tuple
        2D shape of the individual images

    Returns
    -------
    numpy.ndarray
        3D cube

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.base.transform import matrix2cube
    >>> a = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14],
    ... [3, 7, 11, 15]])
    >>> matrix2cube(a, (2, 2))
    array([[[ 0, 1],
            [ 2, 3]],

           [[ 4, 5],
            [ 6, 7]],

           [[ 8, 9],
            [10, 11]],

           [[12, 13],
            [14, 15]]])

    See Also
    --------
    cube2matrix : complementary function

    """
    # Each column of the matrix is one flattened image; transpose so the
    # image index comes first, then restore the 2D image shape.
    cube_shape = [data_matrix.shape[1], *im_shape]
    return data_matrix.T.reshape(cube_shape)
335 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Python Template sphinx config

# Import relevant modules
import sys
import os
from importlib_metadata import metadata

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath("../.."))

# -- General configuration ------------------------------------------------

# General information about the project.
project = "modopt"

# Read the package metadata (e.g. version) from the installed distribution.
mdata = metadata(project)
author = "Samuel Farrens, Pierre-Antoine Comby, Chaithya GR, Philippe Ciuciu"
version = mdata["Version"]
copyright = f"2020, {author}"
gh_user = "sfarrens"

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx.ext.ifconfig",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.todo",
    "sphinx.ext.viewcode",
    "sphinxawesome_theme.highlighting",
    "sphinxcontrib.bibtex",
    "myst_parser",
    "nbsphinx",
    "nbsphinx_link",
    "numpydoc",
    "sphinx_gallery.gen_gallery",
]

# Do not prefix documented objects with their module names.
add_module_names = False

# Set class documentation standard.
autoclass_content = "class"

# Autodoc options
autodoc_default_options = {
    "member-order": "bysource",
    "private-members": True,
    "show-inheritance": True,
}

# Generate summaries
autosummary_generate = True

# Suppress class members in toctree.
numpydoc_show_class_members = False

# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
source_suffix = [".rst", ".md"]

# The master toctree document.
master_doc = "index"

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "default"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True

# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    "nav_include_hidden": True,
    "show_nav": True,
    "show_breadcrumbs": True,
    "breadcrumbs_separator": "/",
    "show_prev_next": True,
    "show_scrolltop": True,
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
html_logo = "modopt_logo.png"
104 | html_permalinks_icon = (
105 | ''
113 | )
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
html_title = f"{project} v{version}"

# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
# html_logo = None

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
# html_favicon = None

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
html_last_updated_fmt = "%d %b, %Y"

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
html_use_smartypants = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
html_show_copyright = True


# -- Options for Sphinx Gallery ----------------------------------------------

# Gallery pages are built from scripts matching "example_*" in the top-level
# examples/ directory; __init__.py and conftest.py are excluded.
sphinx_gallery_conf = {
    "examples_dirs": ["../../examples/"],
    "filename_pattern": "/example_",
    "ignore_pattern": r"/(__init__|conftest)\.py",
}
152 |
153 |
154 | # -- Options for nbshpinx output ------------------------------------------
155 |
156 |
# Custom function to find notebooks, create .nblink files and update the
# notebooks.rst file
def add_notebooks(nb_path="../../notebooks"):
    """Find notebooks, write ``.nblink`` files and register new entries.

    Scans ``nb_path`` for ``.ipynb`` files, writes an ``.nblink`` file for
    each one (in the current working directory) and appends any notebook not
    already listed to ``notebooks.rst``.

    Parameters
    ----------
    nb_path : str, optional
        Path to the directory containing the notebooks

    Returns
    -------
    list
        Sorted list of notebook file names found in ``nb_path``
    """
    print("Looking for notebooks")
    nb_ext = ".ipynb"
    nb_rst_file_name = "notebooks.rst"
    nb_link_format = '{{\n "path": "{0}/{1}"\n}}'

    nbs = sorted(nb for nb in os.listdir(nb_path) if nb.endswith(nb_ext))

    for list_pos, nb in enumerate(nbs):

        # BUGFIX: strip the extension by length. ``str.rstrip(nb_ext)``
        # removes any trailing characters from the set ".ipynb" and mangles
        # names such as "demo_nb.ipynb" -> "demo_".
        nb_name = nb[: -len(nb_ext)]

        nb_link_file_name = nb_name + ".nblink"
        print(f"Writing {nb_link_file_name}")
        with open(nb_link_file_name, "w") as nb_link_file:
            nb_link_file.write(nb_link_format.format(nb_path, nb))

        print(f"Looking for {nb_name} in {nb_rst_file_name}")
        with open(nb_rst_file_name) as nb_rst_file:
            check_name = nb_name not in nb_rst_file.read()

        if check_name:
            print(f"Adding {nb_name} to {nb_rst_file_name}")
            with open(nb_rst_file_name, "a") as nb_rst_file:
                if list_pos == 0:
                    nb_rst_file.write("\n")
                nb_rst_file.write(f" {nb_name}\n")

    return nbs
189 |
190 |
# Add notebooks
add_notebooks()

# Base URLs for the Binder and GitHub badges/links.
binder = "https://mybinder.org/v2/gh"
binder_badge = "https://mybinder.org/badge_logo.svg"
github = "https://github.com/"
github_badge = "https://badgen.net/badge/icon/github?icon=github&label"
198 |
199 | # Remove promts and add binder badge
200 | nb_header_pt1 = r"""
201 | {% if env.metadata[env.docname]['nbsphinx-link-target'] %}
202 | {% set docpath = env.metadata[env.docname]['nbsphinx-link-target'] %}
203 | {% else %}
204 | {% set docpath = env.doc2path(env.docname, base='docs/source/') %}
205 | {% endif %}
206 |
207 | .. raw:: html
208 |
209 |
215 |
216 | """
217 | nb_header_pt2 = (
218 | r""" """
223 | r"""
"""
228 | )
229 |
# Prepend the badge header defined above to every rendered notebook.
nbsphinx_prolog = nb_header_pt1 + nb_header_pt2

# -- Intersphinx Mapping ----------------------------------------------

# Refer to the package libraries for type definitions
intersphinx_mapping = {
    "python": ("http://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
    "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
    "matplotlib": ("https://matplotlib.org", None),
    "astropy": ("http://docs.astropy.org/en/latest/", None),
    "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
    "sklearn": (
        "http://scikit-learn.org/stable",
        (None, "./_intersphinx/sklearn-objects.inv"),
    ),
    "tensorflow": (
        "https://www.tensorflow.org/api_docs/python",
        (
            "https://github.com/GPflow/tensorflow-intersphinx/"
            + "raw/master/tf2_py_objects.inv"
        ),
    ),
}

# -- BibTeX Setting ----------------------------------------------

bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
bibtex_default_style = "alpha"
261 |
--------------------------------------------------------------------------------
/src/modopt/signal/svd.py:
--------------------------------------------------------------------------------
1 | """SVD ROUTINES.
2 |
3 | This module contains methods for thresholding singular values.
4 |
5 | :Author: Samuel Farrens
6 |
7 | """
8 |
9 | import numpy as np
10 | from scipy.linalg import svd
11 | from scipy.sparse.linalg import svds
12 |
13 | from modopt.base.transform import matrix2cube
14 | from modopt.interface.errors import warn
15 | from modopt.math.convolve import convolve
16 | from modopt.signal.noise import thresh
17 |
18 |
def find_n_pc(u_vec, factor=0.5):
    """Find number of principal components.

    This method finds the minimum number of principal components required.

    Parameters
    ----------
    u_vec : numpy.ndarray
        Left singular vector of the original data
    factor : float, optional
        Factor for testing the auto correlation (default is ``0.5``)

    Returns
    -------
    int
        Number of principal components

    Raises
    ------
    ValueError
        Invalid left singular vector

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import svd
    >>> from modopt.signal.svd import find_n_pc
    >>> x = np.arange(18).reshape(9, 2).astype(float)
    >>> find_n_pc(svd(x)[0])
    1

    """
    if np.sqrt(u_vec.shape[0]) % 1:
        raise ValueError(
            "Invalid left singular vector. The size of the first "
            + "dimenion of ``u_vec`` must be perfect square.",
        )

    # Each singular vector is reshaped to a square image of this size.
    side = int(np.sqrt(u_vec.shape[0]))
    array_shape = np.repeat(side, 2)

    # Auto-correlate each left singular vector by convolving the reshaped
    # image with its 180-degree rotation.
    u_auto = []
    for column in u_vec.T:
        image = column.reshape(array_shape)
        u_auto.append(convolve(image, np.rot90(image, 2)))

    # Count the vectors whose squared central auto-correlation value does
    # not exceed ``factor`` times the total power.
    flags = [
        u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2)
        for u_val in u_auto
    ]

    return np.sum(flags)
76 |
77 |
def calculate_svd(input_data):
    """Calculate Singular Value Decomposition.

    This method calculates the Singular Value Decomposition (SVD) of the input
    data using SciPy.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix

    Returns
    -------
    tuple
        Left singular vector, singular values and right singular vector

    Raises
    ------
    TypeError
        For invalid data type

    """
    # Only 2D NumPy arrays are accepted.
    is_valid_matrix = isinstance(input_data, np.ndarray) and input_data.ndim == 2
    if not is_valid_matrix:
        raise TypeError("Input data must be a 2D np.ndarray.")

    return svd(
        input_data,
        check_finite=False,
        lapack_driver="gesvd",
        full_matrices=False,
    )
109 |
110 |
def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
    """Threshold the singular values.

    This method thresholds the input data using singular value decomposition.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    threshold : float or numpy.ndarray, optional
        Threshold value(s) (default is ``None``)
    n_pc : int or str, optional
        Number of principal components, specify an integer value or ``'all'``
        (default is ``None``)
    thresh_type : {'hard', 'soft'}, optional
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        For invalid n_pc value

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.svd import svd_thresh
    >>> x = np.arange(18).reshape(9, 2).astype(float)
    >>> svd_thresh(x, n_pc=1)
    array([[ 0.49815487, 0.54291537],
           [ 2.40863386, 2.62505584],
           [ 4.31911286, 4.70719631],
           [ 6.22959185, 6.78933678],
           [ 8.14007085, 8.87147725],
           [10.05054985, 10.95361772],
           [11.96102884, 13.03575819],
           [13.87150784, 15.11789866],
           [15.78198684, 17.20003913]])

    """
    less_than_zero = isinstance(n_pc, int) and n_pc <= 0
    str_not_all = isinstance(n_pc, str) and n_pc != "all"

    if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
        raise ValueError(
            'Invalid value for "n_pc", specify a positive integer value or ' + '"all"',
        )

    # Get SVD of input data.
    u_vec, s_values, v_vec = calculate_svd(input_data)

    # Find the threshold if not provided.
    if threshold is None:
        # Find the required number of principal components if not specified.
        # BUGFIX: removed a leftover debug ``print`` statement here.
        if n_pc is None:
            n_pc = find_n_pc(u_vec, factor=0.1)

        # If the number of PCs is too large use all of the singular values.
        if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
            isinstance(n_pc, str) and n_pc == "all"
        ):
            n_pc = s_values.size
            warn("Using all singular values.")

        threshold = s_values[n_pc - 1]

    # Threshold the singular values.
    s_new = thresh(s_values, threshold, thresh_type)

    if np.all(s_new == s_values):
        warn("No change to singular values.")

    # Diagonalize the svd
    s_new = np.diag(s_new)

    # Return the thresholded data.
    return np.dot(u_vec, np.dot(s_new, v_vec))
193 |
194 |
def svd_thresh_coef_fast(
    input_data,
    threshold,
    n_vals=-1,
    extra_vals=5,
    thresh_type="hard",
):
    """Threshold the singular values coefficients.

    This method thresholds the input data by using singular value
    decomposition, but only computing the greatest ``n_vals``
    values.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    threshold : float or numpy.ndarray
        Threshold value(s)
    n_vals: int, optional
        Number of singular values to compute.
        If ``-1``, compute ``min(input_data.shape) - 1`` singular values.
    extra_vals: int, optional
        If the number of values computed is not enough to perform thresholding,
        recompute by using ``n_vals + extra_vals`` (default is ``5``)
    thresh_type : {'hard', 'soft'}
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    tuple
        The thresholded data (numpy.ndarray) and the estimated rank after
        thresholding (int)
    """
    if n_vals == -1:
        # Default: request the maximum number of values svds can return.
        n_vals = min(input_data.shape) - 1
    ok = False
    while not ok:
        # svds returns singular values in ascending order, so s_values[0]
        # is the smallest of the computed values.
        (u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
        # Stop once the smallest computed value falls at or below the
        # threshold (all values above the threshold have been captured),
        # or once no more values can be computed; otherwise retry with
        # ``extra_vals`` additional values.
        ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
        n_vals = min(n_vals + extra_vals, *input_data.shape)

    # NOTE(review): other call sites in this module pass the threshold type
    # positionally -- confirm ``thresh`` accepts a ``threshold_type`` keyword.
    s_values = thresh(
        s_values,
        threshold,
        threshold_type=thresh_type,
    )
    # Rank after thresholding is the number of surviving (non-zero) singular
    # values; keep only the corresponding largest singular triplets.
    rank = np.count_nonzero(s_values)
    return (
        np.dot(
            u_vec[:, -rank:] * s_values[-rank:],
            v_vec[-rank:, :],
        ),
        rank,
    )
251 |
252 |
def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
    """Threshold the singular values coefficients.

    This method thresholds the input data using singular value decomposition.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    operator : class
        Operator class instance
    threshold : float or numpy.ndarray
        Threshold value(s)
    thresh_type : {'hard', 'soft'}
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    TypeError
        If operator not callable

    """
    if not callable(operator):
        raise TypeError("Operator must be a callable function.")

    # Get SVD of data matrix
    u_vec, s_values, v_vec = calculate_svd(input_data)

    # Diagonalise s
    s_values = np.diag(s_values)

    # Compute coefficients
    a_matrix = np.dot(s_values, v_vec)

    # Get the shape of the array
    # (assumes the rows of ``u_vec`` form square images -- TODO confirm)
    array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)

    # Compute threshold matrix.
    ti = np.array(
        [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
    )
    # NOTE(review): if ``threshold`` is a numpy array, this in-place multiply
    # mutates the caller's array -- confirm this is intended.
    threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)

    # Threshold coefficients.
    a_new = thresh(a_matrix, threshold, thresh_type)

    # Return the thresholded image.
    return np.dot(u_vec, a_new)
306 |
--------------------------------------------------------------------------------
/tests/test_algorithms.py:
--------------------------------------------------------------------------------
1 | """UNIT TESTS FOR Algorithms.
2 |
3 | This module contains unit tests for the modopt.opt module.
4 |
5 | :Authors:
6 | Samuel Farrens
7 | Pierre-Antoine Comby
8 | """
9 |
10 | import numpy as np
11 | import numpy.testing as npt
12 | from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
13 | from pytest_cases import (
14 | fixture,
15 | parametrize,
16 | parametrize_with_cases,
17 | )
18 |
19 |
# Flag recording whether scikit-learn is installed; tests that depend on it
# can check this at runtime.
SKLEARN_AVAILABLE = True
try:
    import sklearn
except ImportError:
    SKLEARN_AVAILABLE = False


# Shared random generator used to add noise to the test data.
rng = np.random.default_rng()
28 |
29 |
@fixture
def idty():
    """Return the identity function."""

    def _identity(value):
        return value

    return _identity
34 |
35 |
@fixture
def reweight_op():
    """Return a reweight operator built from a small test matrix."""
    weights = np.arange(9).reshape(3, 3).astype(float) + 1
    return reweight.cwbReweight(weights)
41 |
42 |
def build_kwargs(kwargs, use_metrics):
    """Build the kwargs for each algorithm, replacing placeholders by true values.

    This function has to be called for each test, as direct parameterization
    somehow is not working with pytest-xdist and pytest-cases.
    It also adds dummy metric measurement to validate the metric api.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments whose values may be placeholder strings
    use_metrics : bool
        If True, attach a dummy metric configuration

    Returns
    -------
    dict
        Keyword arguments with placeholders replaced by real objects
    """
    update_value = {
        "idty": lambda x: x,
        "lin_idty": linear.Identity(),
        "reweight_op": reweight.cwbReweight(
            np.arange(9).reshape(3, 3).astype(float) + 1
        ),
    }
    # Replace placeholder values where possible; pass other values through.
    # BUGFIX: removed a leftover debug ``print(kwargs)`` statement.
    new_kwargs = {
        key: update_value.get(kwargs[key], kwargs[key]) for key in kwargs
    }

    if use_metrics:
        new_kwargs["linear"] = linear.Identity()
        new_kwargs["metrics"] = {
            "diff": {
                "metric": lambda test, ref: np.sum(test - ref),
                "mapping": {"x_new": "test"},
                "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))},
                "early_stopping": False,
            }
        }

    return new_kwargs
75 |
76 |
@parametrize(use_metrics=[True, False])
class AlgoCases:
    r"""Cases for algorithms.

    Most of the tests solve the trivial problem

    .. math::
        \\min_x \\frac{1}{2} \\| y - x \\|_2^2 \\quad\\text{s.t.} x \\geq 0

    More complex and concrete use cases are shown in examples.

    Each ``case_*`` method returns ``(algo, update_kwargs)`` so that
    ``test_algo`` can inspect both the run algorithm and the options it
    was configured with.
    """

    # Ground-truth data; also the fixed point of the trivial problem above.
    data1 = np.arange(9).reshape(3, 3).astype(float)
    # Slightly perturbed observation of ``data1``.
    data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
    # Iteration budget used when ``auto_iterate`` is disabled.
    max_iter = 20

    @parametrize(
        kwargs=[
            {"beta_update": "idty", "auto_iterate": False, "cost": None},
            {"beta_update": "idty"},
            {"cost": None, "lambda_update": None},
            {"beta_update": "idty", "a_cd": 3},
            {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7},
            {"restart_strategy": "adaptive", "xi_restart": 0.9},
            {
                "restart_strategy": "greedy",
                "xi_restart": 0.9,
                "min_beta": 1.0,
                "s_greedy": 1.1,
            },
        ]
    )
    def case_forward_backward(self, kwargs, idty, use_metrics):
        """Forward Backward case."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        algo = algorithms.ForwardBackward(
            self.data1,
            grad=gradient.GradBasic(self.data1, idty, idty),
            prox=proximity.Positivity(),
            **update_kwargs,
        )
        # auto_iterate disabled: run the iterations manually.
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        kwargs=[
            {
                "cost": None,
                "auto_iterate": False,
                "gamma_update": "idty",
                "beta_update": "idty",
            },
            {"gamma_update": "idty", "lambda_update": "idty"},
            {"cost": True},
            {"cost": True, "step_size": 2},
        ]
    )
    def case_gen_forward_backward(self, kwargs, use_metrics, idty):
        """General FB setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        prox_dual_inst = proximity.IdentityProx()
        # ``cost=True`` is a placeholder meaning "build a real cost object".
        if update_kwargs.get("cost", None) is True:
            update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
        algo = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            **update_kwargs,
        )
        # auto_iterate disabled: run the iterations manually.
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        kwargs=[
            {
                "sigma_dual": "idty",
                "tau_update": "idty",
                "rho_update": "idty",
                "auto_iterate": False,
            },
            {
                "sigma_dual": "idty",
                "tau_update": "idty",
                "rho_update": "idty",
            },
            {
                "linear": "lin_idty",
                "cost": True,
                "reweight": "reweight_op",
            },
        ]
    )
    def case_condat(self, kwargs, use_metrics, idty):
        """Condat Vu Algorithm setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        prox_dual_inst = proximity.IdentityProx()
        # ``cost=True`` is a placeholder meaning "build a real cost object".
        if update_kwargs.get("cost", None) is True:
            update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst])

        algo = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            prox_dual=prox_dual_inst,
            **update_kwargs,
        )
        # auto_iterate disabled: run the iterations manually.
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}])
    def case_pogm(self, kwargs, use_metrics, idty):
        """POGM setup."""
        update_kwargs = build_kwargs(kwargs, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        # POGM tracks four sequences (u, x, y, z); all start from data1.
        algo = algorithms.POGM(
            u=self.data1,
            x=self.data1,
            y=self.data1,
            z=self.data1,
            grad=grad_inst,
            prox=prox_inst,
            **update_kwargs,
        )

        # auto_iterate disabled: run the iterations manually.
        if update_kwargs.get("auto_iterate", None) is False:
            algo.iterate(self.max_iter)
        return algo, update_kwargs

    @parametrize(
        GradDescent=[
            algorithms.VanillaGenericGradOpt,
            algorithms.AdaGenericGradOpt,
            algorithms.ADAMGradOpt,
            algorithms.MomentumGradOpt,
            algorithms.RMSpropGradOpt,
            algorithms.SAGAOptGradOpt,
        ]
    )
    def case_grad(self, GradDescent, use_metrics, idty):
        """Gradient Descent algorithm test."""
        update_kwargs = build_kwargs({}, use_metrics)
        grad_inst = gradient.GradBasic(self.data1, idty, idty)
        prox_inst = proximity.Positivity()
        cost_inst = cost.costObj([grad_inst, prox_inst])

        algo = GradDescent(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
            **update_kwargs,
        )
        algo.iterate()
        return algo, update_kwargs

    @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
    def case_admm(self, admm, use_metrics, idty):
        """ADMM setup."""

        # Sub-problem solvers: both simply return the observed data.
        def optim1(init, obs):
            return obs

        def optim2(init, obs):
            return obs

        update_kwargs = build_kwargs({}, use_metrics)
        algo = admm(
            u=self.data1,
            v=self.data1,
            mu=np.zeros_like(self.data1),
            A=linear.Identity(),
            B=linear.Identity(),
            b=self.data1,
            optimizers=(optim1, optim2),
            **update_kwargs,
        )
        algo.iterate()
        return algo, update_kwargs
264 |
265 |
@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
def test_algo(algo, kwargs):
    """Test algorithms."""
    iterated_manually = kwargs.get("auto_iterate") is False
    if iterated_manually:
        # algo already run
        npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1)
    else:
        npt.assert_almost_equal(algo.x_final, AlgoCases.data1)

    if kwargs.get("metrics"):
        print(algo.metrics)
        npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3)
278 |
--------------------------------------------------------------------------------
/src/modopt/opt/algorithms/primal_dual.py:
--------------------------------------------------------------------------------
1 | """Primal-Dual Algorithms."""
2 |
3 | from modopt.opt.algorithms.base import SetUp
4 | from modopt.opt.cost import costObj
5 | from modopt.opt.linear import Identity
6 |
7 |
class Condat(SetUp):
    r"""Condat optimisation.

    This class implements algorithm 3.1 from :cite:`condat2013`.

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    y : numpy.ndarray
        Initial guess for the dual variable
    grad
        Gradient operator class instance
    prox
        Proximity primal operator class instance
    prox_dual
        Proximity dual operator class instance
    linear : class instance, optional
        Linear operator class instance (default is ``None``)
    cost : class instance or str, optional
        Cost function class instance (default is ``'auto'``); Use ``'auto'`` to
        automatically generate a ``costObj`` instance
    reweight : class instance, optional
        Reweighting class instance
    rho : float, optional
        Relaxation parameter, :math:`\rho` (default is ``0.5``)
    sigma : float, optional
        Proximal dual parameter, :math:`\sigma` (default is ``1.0``)
    tau : float, optional
        Proximal primal parameter, :math:`\tau` (default is ``1.0``)
    rho_update : callable, optional
        Relaxation parameter update method (default is ``None``)
    sigma_update : callable, optional
        Proximal dual parameter update method (default is ``None``)
    tau_update : callable, optional
        Proximal primal parameter update method (default is ``None``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    max_iter : int, optional
        Maximum number of iterations (default is ``150``)
    n_rewightings : int, optional
        Number of reweightings to perform (default is ``1``); the misspelled
        name is kept for backward compatibility with existing callers

    Notes
    -----
    The ``tau_param`` can also be set using the keyword `step_size`, which will
    override the value of ``tau_param``.

    The following state variables are available for metrics measurements at
    each iteration :

    * ``'x_new'`` : new estimate of :math:`x` (primal variable)
    * ``'y_new'`` : new estimate of :math:`y` (dual variable)
    * ``'idx'`` : index of the iteration.

    See Also
    --------
    modopt.opt.algorithms.base.SetUp : parent class
    modopt.opt.cost.costObj : cost object class
    modopt.opt.gradient : gradient operator classes
    modopt.opt.proximity : proximity operator classes
    modopt.opt.linear : linear operator classes
    modopt.opt.reweight : reweighting classes

    """

    def __init__(
        self,
        x,
        y,
        grad,
        prox,
        prox_dual,
        linear=None,
        cost="auto",
        reweight=None,
        rho=0.5,
        sigma=1.0,
        tau=1.0,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=True,
        max_iter=150,
        n_rewightings=1,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):
        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)

        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        # Default to the identity when no linear operator is provided.
        if linear is None:
            self._linear = Identity()
        else:
            self._linear = linear
        # ``'auto'`` builds a composite cost from the three operators.
        if cost == "auto":
            self._cost_func = costObj(
                [
                    self._grad,
                    self._prox,
                    self._prox_dual,
                ]
            )
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param_val in (rho, sigma, tau):
            self._check_param(param_val)

        self._rho = rho
        self._sigma = sigma
        # ``step_size`` (handled by SetUp) overrides ``tau`` when provided.
        self._tau = self.step_size or tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)

        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided.

        """
        # Update relaxation parameter.
        if self._rho_update is not None:
            self._rho = self._rho_update(self._rho)

        # Update proximal dual parameter.
        if self._sigma_update is not None:
            self._sigma = self._sigma_update(self._sigma)

        # Update proximal primal parameter.
        if self._tau_update is not None:
            self._tau = self._tau_update(self._tau)

    def _update(self):
        """Update.

        This method updates the current reconstruction.

        Notes
        -----
        Implements equation 9 (algorithm 3.1) from :cite:`condat2013`.

        - Primal proximity operator set up for positivity constraint.

        """
        # Step 1 from eq.9: gradient step then primal proximity operator.
        self._grad.get_grad(self._x_old)

        x_prox = self._prox.op(
            self._x_old
            - self._tau * self._grad.grad
            - self._tau * self._linear.adj_op(self._y_old),
        )

        # Step 2 from eq.9: dual update via Moreau decomposition.
        y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)

        y_prox = y_temp - self._sigma * self._prox_dual.op(
            y_temp / self._sigma,
            extra_factor=(1.0 / self._sigma),
        )

        # Step 3 from eq.9: relaxation of primal and dual iterates.
        self._x_new = self._rho * x_prox + (1 - self._rho) * self._x_old
        self._y_new = self._rho * y_prox + (1 - self._rho) * self._y_old

        del x_prox, y_prox, y_temp

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
                self._x_new, self._y_new
            )

    def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)
        n_rewightings : int, optional
            Number of reweightings to perform (default is ``1``)
        progbar: tqdm.tqdm
            Progress bar handle (default is ``None``)
        """
        self._run_alg(max_iter, progbar)

        # Each reweighting updates the weights from the current solution,
        # then re-runs the full optimisation with the new weights.
        if self._reweight is not None:
            for _ in range(n_rewightings):
                self._reweight.reweight(self._linear.op(self._x_new))
                if progbar:
                    progbar.reset(total=max_iter)
                self._run_alg(max_iter, progbar)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new
        self.y_final = self._y_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs : dict,
            The mapping between the iterated variables

        """
        return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithms as attributes: ``x_final``,
        ``y_final``, ``metrics``.

        """
        # Collect the metric history recorded by each convergence observer.
        metrics = {}
        for obs in self._observers["cv_metrics"]:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
277 |
--------------------------------------------------------------------------------