├── spyrit ├── hadamard_matrix │ ├── __init__.py │ ├── create_hadamard_matrix_with_sage.py │ └── download_hadamard_matrix.py ├── core │ └── __init__.py ├── external │ └── __init__.py ├── __init__.py ├── misc │ ├── __init__.py │ ├── examples.py │ ├── data_visualisation.py │ ├── matrix_tools.py │ ├── load_data.py │ ├── color.py │ └── metrics.py └── dev │ ├── prep.py │ └── meas.py ├── docs ├── source │ ├── fig │ │ ├── lpgd.png │ │ ├── tuto1.png │ │ ├── tuto2.png │ │ ├── tuto6.png │ │ ├── tuto9.png │ │ ├── drunet.png │ │ ├── pinvnet.png │ │ ├── direct_net.png │ │ ├── tuto3_pinv.png │ │ ├── pinvnet_cnn.png │ │ ├── tuto5_dcnet.png │ │ ├── spi_principle.png │ │ └── tuto4_pinvnet.png │ ├── _templates │ │ ├── spyrit-method-template.rst │ │ ├── spyrit-class-template.rst │ │ └── spyrit-module-template.rst │ ├── gallery │ │ ├── images │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_001.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_002.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_003.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_004.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_005.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_006.png │ │ │ ├── sphx_glr_tuto_a00_connect_deepinv_007.png │ │ │ └── thumb │ │ │ │ └── sphx_glr_tuto_a00_connect_deepinv_thumb.png │ │ └── tuto_a00_connect_deepinv.rst │ ├── _static │ │ └── css │ │ │ └── sg_README.css │ ├── external_libraries.rst │ ├── build-tuto-local.md │ ├── organisation.rst │ ├── sg_execution_times.rst │ ├── index.rst │ ├── single_pixel.rst │ └── conf.py ├── Makefile └── make.bat ├── AUTHORS ├── tutorial ├── images │ └── test │ │ ├── ILSVRC2012_test_00000001.jpeg │ │ ├── ILSVRC2012_test_00000002.jpeg │ │ ├── ILSVRC2012_test_00000003.jpeg │ │ ├── ILSVRC2012_test_00000004.jpeg │ │ ├── ILSVRC2012_test_00000005.jpeg │ │ ├── ILSVRC2012_test_00000006.jpeg │ │ └── ILSVRC2012_test_00000007.jpeg ├── wip │ ├── _tuto_bonus_advanced_methods_colab.py │ ├── tuto_a00_connect_deepinv.py │ ├── _tuto_08_lpgd_split_measurements.py │ ├── 
_tuto_05_recon_hadamSplit.py │ └── _tuto_06_dcnet_split_measurements.py ├── README.txt ├── tuto_02_noise.py ├── tuto_01_a_acquisition_operators.py ├── tuto_01_b_splitting.py ├── tuto_03_pseudoinverse_linear.py ├── tuto_04_pseudoinverse_cnn_linear.py └── tuto_04_b_train_pseudoinverse_cnn_linear.py ├── requirements.txt ├── .pre-commit-config.yaml ├── .gitignore ├── .readthedocs.yml ├── Codemeta.json ├── pyproject.toml ├── README.md ├── .github └── workflows │ └── main.yml ├── LICENSE.md └── CHANGELOG.md /spyrit/hadamard_matrix/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/fig/lpgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/lpgd.png -------------------------------------------------------------------------------- /docs/source/fig/tuto1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto1.png -------------------------------------------------------------------------------- /docs/source/fig/tuto2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto2.png -------------------------------------------------------------------------------- /docs/source/fig/tuto6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto6.png -------------------------------------------------------------------------------- /docs/source/fig/tuto9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto9.png 
-------------------------------------------------------------------------------- /docs/source/fig/drunet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/drunet.png -------------------------------------------------------------------------------- /docs/source/fig/pinvnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/pinvnet.png -------------------------------------------------------------------------------- /spyrit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core module for Spyrit package, containing the main classes and functions.""" 2 | -------------------------------------------------------------------------------- /docs/source/fig/direct_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/direct_net.png -------------------------------------------------------------------------------- /docs/source/fig/tuto3_pinv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto3_pinv.png -------------------------------------------------------------------------------- /docs/source/fig/pinvnet_cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/pinvnet_cnn.png -------------------------------------------------------------------------------- /docs/source/fig/tuto5_dcnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto5_dcnet.png 
-------------------------------------------------------------------------------- /docs/source/fig/spi_principle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/spi_principle.png -------------------------------------------------------------------------------- /docs/source/fig/tuto4_pinvnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/fig/tuto4_pinvnet.png -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Nicolas Ducros 2 | Romain Phan 3 | Thomas Baudier 4 | Juan Abascal 5 | Fadoua Taia-Alaoui 6 | Claire Mouton 7 | Guilherme Beneti 8 | -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000001.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000001.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000002.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000002.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000003.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000003.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000004.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000004.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000005.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000005.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000006.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000006.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000007.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/tutorial/images/test/ILSVRC2012_test_00000007.jpeg -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-method-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
automethod:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png -------------------------------------------------------------------------------- /docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png -------------------------------------------------------------------------------- /spyrit/external/__init__.py: -------------------------------------------------------------------------------- 1 | """This module uses a modified version of the Unet presented in https://github.com/cszn/DPIR/blob/master/models/network_unet.py""" 2 | 3 | # from . import drunet 4 | -------------------------------------------------------------------------------- /docs/source/gallery/images/thumb/sphx_glr_tuto_a00_connect_deepinv_thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/HEAD/docs/source/gallery/images/thumb/sphx_glr_tuto_a00_connect_deepinv_thumb.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | torch 5 | torchvision 6 | Pillow 7 | opencv-python 8 | imutils 9 | PyWavelets 10 | wget 11 | sympy 12 | imageio 13 | astropy 14 | tensorboard 15 | sphinx_gallery 16 | sphinx_rtd_theme 17 | girder-client 18 | gdown==v4.6.3 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: 
https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - repo: https://github.com/psf/black-pre-commit-mirror 8 | rev: 25.9.0 9 | hooks: 10 | - id: black 11 | ci: 12 | autofix_commit_msg: | 13 | [pre-commit.ci] Automatic python formatting 14 | autofix_prs: true 15 | submodules: false 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.py[co] 3 | *.swp 4 | *.pdf 5 | *.png 6 | *.sh 7 | *.ipynb 8 | *.npy 9 | *.npz 10 | 11 | #folders 12 | data 13 | img 14 | models 15 | dist 16 | build 17 | spyrit.egg-info 18 | stats_walsh 19 | **/.ipynb_checkpoints/* 20 | docs/source/_build* 21 | model/ 22 | runs/ 23 | spyrit/drunet/ 24 | !spyrit/images/tuto/*.png 25 | docs/source/html 26 | docs/source/_autosummary 27 | docs/source/_templates 28 | docs/source/api 29 | docs/source/gallery 30 | -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | 8 | {% block methods %} 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. 
autosummary:: 13 | :toctree: 14 | :template: spyrit-method-template.rst 15 | {% for item in methods %} 16 | {%- if item is in members %} 17 | ~{{ name }}.{{ item }} 18 | {%- endif %} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Optionally build your docs in additional formats such as PDF 9 | formats: all # pdf, epub and htlmzip 10 | 11 | # Optionally set the version of Python and requirements required to build your docs 12 | build: 13 | os: ubuntu-22.04 14 | tools: 15 | python: "3.11" 16 | python: 17 | install: 18 | - requirements: requirements.txt 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: docs/source/conf.py 23 | -------------------------------------------------------------------------------- /spyrit/__init__.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | # from __future__ import division, print_function, absolute_import 8 | # from distutils.version import LooseVersion 9 | # 10 | # 11 | # from . import spyritest 12 | # 13 | # 14 | # __all__ = [s for s in dir() if not s.startswith('_')] 15 | 16 | # from . import core 17 | # from . import misc 18 | # from . 
import external 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /spyrit/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | """Contains miscellaneous Numpy / Pytorch functions useful for spyrit.core.""" 8 | 9 | # from . import color 10 | # from . import data_visualisation 11 | # from . import disp 12 | # from . import examples 13 | # from . import matrix_tools 14 | # from . import metrics 15 | # from . import pattern_choice 16 | # from . import sampling 17 | # from . import statistics 18 | # from . 
import walsh_hadamard 19 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_bonus_advanced_methods_colab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | r""" 3 | Bonus. Advanced methods - Colab 4 | =============================== 5 | .. _tuto_advanced_methods_colab: 6 | 7 | We refer to `spyrit-examples/tutorial `_ 8 | for a list of tutorials that can be run directly in colab and present more advanced cases than the main spyrit tutorials, 9 | such as comparison of methods for split measurements, or comparison of different denoising networks. 10 | 11 | """ 12 | 13 | ############################################################################### 14 | # The spyrit-examples repository also includes research contributions based on the SPYRIT toolbox. 15 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/css/sg_README.css: -------------------------------------------------------------------------------- 1 | .sphx-glr-thumbnails { 2 | width: 100%; 3 | margin: 0px 0px 0px 0px; 4 | 5 | /* align thumbnails on a grid */ 6 | justify-content: space-between; 7 | display: grid; 8 | /* each grid column should be at least 160px (this will determine 9 | the actual number of columns) and then take as much of the 10 | remaining width as possible */ 11 | grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)) !important; 12 | gap: 20px; 13 | } 14 | .sphx-glr-thumbcontainer { 15 | width: 100% !important; 16 | min-height: 210px !important; 17 | margin: 0px !important; 18 | } 19 | .sphx-glr-thumbcontainer .figure { 20 | min-width: 100px !important; 21 | height: 100px !important; 22 | } 23 | .sphx-glr-thumbcontainer img { 24 | display: inline !important; 25 | object-fit: cover !important; 26 | max-height: 150px !important; 27 | min-width: 300px !important; 28 | } 29 | -------------------------------------------------------------------------------- /Codemeta.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": "https://doi.org/10.5063/schema/codemeta-2.0", 3 | "type": "SoftwareSourceCode", 4 | "applicationCategory": "Single-pixel imaging", 5 | "codeRepository": "https://github.com/openspyrit/spyrit", 6 | "dateCreated": "2020-12-10", 7 | "datePublished": "2021-03-11", 8 | "description": "SPyRiT is a PyTorch-based deep image reconstruction package primarily designed for single-pixel imaging.", 9 | "keywords": [ 10 | 
"Single-pixel imaging", 11 | "pytorch" 12 | ], 13 | "license": "https://spdx.org/licenses/LGPL-3.0", 14 | "name": "SPyRiT", 15 | "operatingSystem": [ 16 | "Linux", 17 | "Windows", 18 | "MacOS" 19 | ], 20 | "programmingLanguage": "Python 3", 21 | "contIntegration": "https://github.com/openspyrit/spyrit/actions", 22 | "codemeta:continuousIntegration": { 23 | "id": "https://github.com/openspyrit/spyrit/actions" 24 | }, 25 | "issueTracker": "https://github.com/openspyrit/spyrit/issues" 26 | } 27 | -------------------------------------------------------------------------------- /docs/source/external_libraries.rst: -------------------------------------------------------------------------------- 1 | External libraries 2 | ================================== 3 | 4 | Overview of the connection with external librairies. 5 | 6 | Deepinverse 7 | ----------------------------------- 8 | 9 | * `Tutorial appendix 1 `_ introduces the connection from spyrit to `deepinverse `_. 10 | 11 | * Other examples are available in the `spyrit examples repository `_. For example, you can see the use of Linear physics. 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | .. thumbnail-parent-div-open 18 | 19 | .. raw:: html 20 | 21 |
22 | 23 | .. only:: html 24 | 25 | .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_logolarge.png 26 | :alt: 27 | :target: gallery/tuto_a00_connect_deepinv.html 28 | 29 | .. raw:: html 30 | 31 |
01.a. Deepinv
32 |
33 |
34 | .. only:: html 35 | -------------------------------------------------------------------------------- /spyrit/misc/examples.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | import numpy as np 8 | 9 | 10 | def translation_matrix(img_size, nb_pixels): 11 | init_ind = np.reshape(np.arange(img_size**2), (img_size, img_size)) 12 | final_ind = np.zeros((img_size, img_size)) 13 | final_ind[:, : (img_size - nb_pixels)] = init_ind[:, nb_pixels:] 14 | final_ind[:, (img_size - nb_pixels) :] = init_ind[:, :nb_pixels] 15 | 16 | final_ind = np.reshape(final_ind, (img_size**2, 1)) 17 | init_ind = np.reshape(init_ind, (img_size**2, 1)) 18 | F = permutation_matrix(final_ind, init_ind) 19 | return F 20 | 21 | 22 | def permutation_matrix(A, B): 23 | N = A.shape[0] 24 | I = np.eye(N) 25 | P = np.zeros((N, N)) 26 | 27 | for i in range(N): 28 | pat = np.matlib.repmat(A[i, :], N, 1) 29 | ind = np.where(np.sum((pat == B), axis=1)) 30 | P[ind, :] = I[i, :] 31 | 32 | return P 33 | 34 | 35 | def circle(img_size, R, x_max): 36 | x = np.linspace(-x_max, x_max, img_size) 37 | X, Y = np.meshgrid(x, x) 38 | return 1.0 * (X**2 + Y**2 < R) 39 | -------------------------------------------------------------------------------- /spyrit/hadamard_matrix/create_hadamard_matrix_with_sage.py: -------------------------------------------------------------------------------- 1 | from sage.all import * 2 | from sage.combinat.matrices.hadamard_matrix import ( 3 | hadamard_matrix, 4 | skew_hadamard_matrix, 5 | is_hadamard_matrix, 6 | is_skew_hadamard_matrix, 7 | ) 8 | import numpy as np 9 | import glob 10 | 11 | # Get all Hadamard matrices of order 4*n for 
Sage 12 | # https://github.com/sagemath/sage/ 13 | # run in conda env with: 14 | # sage create_hadamard_matrix_with_sage.py 15 | 16 | k = Integer(2000) 17 | for n in range(Integer(1), k + Integer(1)): 18 | try: 19 | H = hadamard_matrix(Integer(4) * n, check=False) 20 | 21 | if is_hadamard_matrix(H): 22 | print(n * 4) 23 | a = np.array(H) 24 | a[a == -1] = 0 25 | a = a.astype(bool) 26 | 27 | # find the files with that order 28 | files = glob.glob("had." + str(n * 4) + "*.npz") 29 | already_saved = False 30 | for file in files: 31 | b = np.load(file) 32 | if a == b: 33 | already_saved = True 34 | if already_saved: 35 | break 36 | 37 | if not already_saved: 38 | name = "had." + str(n * 4) + ".sage.npz" 39 | np.savez_compressed(name, a) 40 | except ValueError as e: 41 | pass 42 | -------------------------------------------------------------------------------- /docs/source/build-tuto-local.md: -------------------------------------------------------------------------------- 1 | 2 | ``` shell 3 | git clone --no-single-branch --depth 50 https://github.com/openspyrit/spyrit . 4 | git checkout --force origin/gallery 5 | git clean -d -f -f 6 | cat .readthedocs.yml 7 | ``` 8 | 9 | # Linux 10 | ``` shell 11 | python3.7 -mvirtualenv $READTHEDOCS_VIRTUALENV_PATH 12 | python -m pip install --upgrade --no-cache-dir pip setuptools 13 | python -m pip install --upgrade --no-cache-dir pillow==5.4.1 mock==1.0.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext<2.3 14 | python -m pip install --exists-action=w --no-cache-dir -r requirements.txt 15 | cat docs/source/conf.py 16 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . 
$READTHEDOCS_OUTPUT/html 17 | ``` 18 | 19 | # Windows using conda 20 | ``` powershell 21 | conda create --name readthedoc 22 | conda activate readthedoc 23 | conda install pip 24 | python.exe -m pip install --upgrade --no-cache-dir pip setuptools 25 | pip install --upgrade --no-cache-dir pillow==10.0.0 mock==1.0.1 alabaster==0.7.13 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext==2.2.2 26 | cd .\myenv\spyrit\ # replace myenv by the environment in which spyrit is installed 27 | pip install --exists-action=w --no-cache-dir -r requirements.txt 28 | cd .\docs\source\ 29 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . html 30 | ``` 31 | -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | :show-inheritance: 5 | 6 | {% block attributes %} 7 | {% if attributes %} 8 | .. rubric:: {{ _('Module Attributes') }} 9 | 10 | .. autosummary:: 11 | :toctree: 12 | {% for item in attributes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :toctree: 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. autosummary:: 35 | :toctree: 36 | :template: spyrit-class-template.rst 37 | {% for item in classes %} 38 | {{ item }} 39 | {%- endfor %} 40 | {% endif %} 41 | {% endblock %} 42 | 43 | {% block exceptions %} 44 | {% if exceptions %} 45 | .. rubric:: {{ _('Exceptions') }} 46 | 47 | .. 
autosummary:: 48 | :toctree: 49 | {% for item in exceptions %} 50 | {{ item }} 51 | {%- endfor %} 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block modules %} 56 | {% if modules %} 57 | .. rubric:: Modules 58 | 59 | .. autosummary:: 60 | :toctree: 61 | :template: spyrit-module-template.rst 62 | :recursive: 63 | {% for item in modules %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=67", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.setuptools] 9 | include-package-data = true 10 | zip-safe = false 11 | script-files = [ 12 | "tutorial/tuto_01_a_acquisition_operators.py", 13 | "tutorial/tuto_01_b_splitting.py", 14 | "tutorial/tuto_01_c_HadamSplit2d.py", 15 | "tutorial/tuto_02_noise.py", 16 | "tutorial/tuto_03_pseudoinverse_linear.py" 17 | ] 18 | 19 | [tool.setuptools.dynamic] 20 | readme = { file = "README.md", content-type = "text/markdown"} 21 | 22 | [tool.setuptools.packages] 23 | find = {} # Scanning implicit namespaces is active by default 24 | 25 | [project] 26 | name = "spyrit" 27 | version = "3.0.2" 28 | dynamic = ["readme"] 29 | authors = [{name = "Nicolas Ducros", email = "Nicolas.Ducros@insa-lyon.fr"}] 30 | description = "Toolbox for deep image reconstruction" 31 | license = {file = "LICENSE.md"} 32 | classifiers = [ 33 | "Programming Language :: Python", 34 | "Programming Language :: Python :: 3.9", 35 | "Programming Language :: Python :: 3.10", 36 | "Programming Language :: Python :: 3.11", 37 | "Programming Language :: Python :: Implementation :: PyPy", 38 | "Operating System :: OS Independent", 39 | ] 40 | dependencies = [ 41 | "numpy", 42 | "matplotlib", 43 | "scipy", 44 | "torch", 45 | "torchvision", 46 | "Pillow", 47 | "PyWavelets", 48 | "wget", 49 | 
"sympy", 50 | "imageio", 51 | "astropy", 52 | "requests", 53 | "tqdm", 54 | "girder-client", 55 | ] 56 | requires-python = ">=3.9" 57 | -------------------------------------------------------------------------------- /docs/source/organisation.rst: -------------------------------------------------------------------------------- 1 | Organisation of the package 2 | ================================== 3 | 4 | .. figure:: fig/direct_net.png 5 | :width: 600 6 | :align: center 7 | 8 | 9 | SPyRiT's typical pipeline. 10 | 11 | SPyRiT allows to simulate measurements and perform image reconstruction using 12 | a full network. A full network includes a measurement operator 13 | :math:`A`, a noise operator :math:`\mathcal{N}`, a preprocessing 14 | operator :math:`B`, a reconstruction operator :math:`\mathcal{R}`, 15 | and a learnable neural network :math:`\mathcal{G}_{\theta}`. All operators 16 | inherit from :class:`torch.nn.Module`. 17 | 18 | 19 | Submodules 20 | ----------------------------------- 21 | 22 | SPyRiT has a modular structure with the core functionality organised in the 8 submodules of 23 | :mod:`spyrit.core`. 24 | 25 | 1. :mod:`spyrit.core.meas` provides measurement operators that compute linear measurements corresponding to :math:`A` in Eq. :eq:`eq_acquisition`. It also provides the adjoint and the pseudoinverse of :math:`A`, which are the basis of any reconstruction algorithm. 26 | 27 | 2. :mod:`spyrit.core.noise` provides noise operators corresponding to :math:`\mathcal{N}` in Eq. :eq:`eq_acquisition`. 28 | 29 | 3. :mod:`spyrit.core.prep` provides preprocessing operators for the operator :math:`B` introduced in Eq. :eq:`eq_prep`. 30 | 31 | 4. :mod:`spyrit.core.nnet` provides known neural networks corresponding to :math:`\mathcal{G}` in Eq. :eq:`eq_recon_direct` or Eq. :eq:`eq_pgd_no_Gamma`. 32 | 33 | 5. :mod:`spyrit.core.recon` returns the reconstruction operator corresponding to :math:`\mathcal{R}`. 34 | 35 | 6. 
:mod:`spyrit.core.train` provides the functionality to solve the minimisation problem of Eq. :eq:`eq_train`. 36 | 37 | 7. :mod:`spyrit.core.warp` contains the operators used for dynamic acquisitions. 38 | 39 | 8. :mod:`spyrit.core.torch` contains utility functions. 40 | 41 | In addition, :mod:`spyrit.misc` contains various utility functions for Numpy / PyTorch that can be used independently of the core functions. 42 | 43 | Finally, :mod:`spyrit.external` provides access to `DR-UNet `_. 44 | -------------------------------------------------------------------------------- /spyrit/misc/data_visualisation.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | #!/usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | """ 10 | Created on Thu Jan 16 08:56:13 2020 11 | 12 | @author: crombez 13 | """ 14 | 15 | from astropy.io import fits 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | # Show basic information of a fits image acquired 20 | # with the andor zyla and plot the image 21 | def show_image_and_infos(path, file): 22 | hdul = fits.open(path + file) 23 | show_images_infos(path, file) 24 | plt.figure() 25 | plt.imshow(hdul[0].data[0]) 26 | plt.show() 27 | 28 | 29 | def show_images_infos(path, file): # Show basic information of a fits image acquired 30 | hdul = fits.open(path + file) 31 | print("***** Name file : " + file + " *****") 32 | print("Type de données : " + hdul[0].header["DATATYPE"]) 33 | print("Mode d'acquisition : " + hdul[0].header["ACQMODE"]) 34 | print("Temps d'exposition : " + str(hdul[0].header["EXPOSURE"])) 35 | print("Temps de lecture : " + str(hdul[0].header["READTIME"])) 36 | print("Longeur d'onde de 
Rayleigh : " + str(hdul[0].header["RAYWAVE"])) 37 | print("Longeur d'onde détectée : " + str(hdul[0].header["DTNWLGTH"])) 38 | print("***********************************" + "\n") 39 | 40 | 41 | # Plot the resulting fuction of to set of 1D data with the same dimension 42 | def simple_plot_2D( 43 | Lx, Ly, fig=None, title=None, xlabel=None, ylabel=None, style_color="b" 44 | ): 45 | plt.figure(fig) 46 | plt.clf() 47 | plt.title(title) 48 | plt.xlabel(xlabel) 49 | plt.ylabel(ylabel) 50 | plt.plot(Lx, Ly, style_color) 51 | plt.show() 52 | 53 | 54 | # Plot a 2D matrix 55 | def plot_im2D(Im, fig=None, title=None, xlabel=None, ylabel=None, cmap="viridis"): 56 | plt.figure(fig) 57 | plt.clf() 58 | plt.title(title) 59 | plt.xlabel(xlabel) 60 | plt.ylabel(ylabel) 61 | plt.imshow(Im, cmap=cmap) 62 | plt.colorbar() 63 | plt.show() 64 | -------------------------------------------------------------------------------- /docs/source/sg_execution_times.rst: -------------------------------------------------------------------------------- 1 | 2 | :orphan: 3 | 4 | .. _sphx_glr_sg_execution_times: 5 | 6 | 7 | Computation times 8 | ================= 9 | **00:00.000** total execution time for 7 files **from all galleries**: 10 | 11 | .. container:: 12 | 13 | .. raw:: html 14 | 15 | 19 | 20 | 21 | 22 | 27 | 28 | .. 
list-table:: 29 | :header-rows: 1 30 | :class: table table-striped sg-datatable 31 | 32 | * - Example 33 | - Time 34 | - Mem (MB) 35 | * - :ref:`sphx_glr_gallery_tuto_01_acquisition_operators.py` (``..\..\tutorial\tuto_01_acquisition_operators.py``) 36 | - 00:00.000 37 | - 0.0 38 | * - :ref:`sphx_glr_gallery_tuto_02_pseudoinverse_linear.py` (``..\..\tutorial\tuto_02_pseudoinverse_linear.py``) 39 | - 00:00.000 40 | - 0.0 41 | * - :ref:`sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_03_pseudoinverse_cnn_linear.py``) 42 | - 00:00.000 43 | - 0.0 44 | * - :ref:`sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_04_train_pseudoinverse_cnn_linear.py``) 45 | - 00:00.000 46 | - 0.0 47 | * - :ref:`sphx_glr_gallery_tuto_05_acquisition_split_measurements.py` (``..\..\tutorial\tuto_05_acquisition_split_measurements.py``) 48 | - 00:00.000 49 | - 0.0 50 | * - :ref:`sphx_glr_gallery_tuto_06_dcnet_split_measurements.py` (``..\..\tutorial\tuto_06_dcnet_split_measurements.py``) 51 | - 00:00.000 52 | - 0.0 53 | * - :ref:`sphx_glr_gallery_tuto_bonus_advanced_methods_colab.py` (``..\..\tutorial\tuto_bonus_advanced_methods_colab.py``) 54 | - 00:00.000 55 | - 0.0 56 | -------------------------------------------------------------------------------- /tutorial/README.txt: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | This series of tutorials should guide you through the use of the SPyRiT pipeline. 5 | 6 | .. figure:: ../fig/direct_net.png 7 | :width: 600 8 | :align: center 9 | :alt: SPyRiT pipeline 10 | 11 | | 12 | 13 | Each tutorial focuses on a specific submodule of the full pipeline. 14 | 15 | * :ref:`Tutorial 1 `.a introduces the basics of measurement operators. 16 | 17 | * :ref:`Tutorial 1 `.b introduces the splitting of measurement operators. 18 | 19 | * :ref:`Tutorial 1 `.c introduces the 2d Hadamard transform with subsampling. 
20 | 21 | * :ref:`Tutorial 2 ` introduces the noise operators. 22 | 23 | * :ref:`Tutorial 3 ` demonstrates pseudo-inverse reconstructions from Hadamard measurements. 24 | 25 | * :ref:`Tutorial 4 `.a introduces data-driven post-processing reconstruction. 26 | 27 | * :ref:`Tutorial 4 `.b trains the post-processing CNN used in :ref:`Tutorial 4 `.a. 28 | 29 | * :ref:`Tutorial 5 ` introduces the denoised completion network for the reconstruction of Poisson-corrupted subsampled measurements. 30 | 31 | .. note:: 32 | 33 | The Python script (*.py*) or Jupyter notebook (*.ipynb*) corresponding to each tutorial can be downloaded at the bottom of the page. The images used in these files can be found on `GitHub`_. 34 | 35 | The tutorials below will gradually be updated to be compatible with SPyRiT 3 (work in progress, in the meantime see SPyRiT `2.4.0`_). 36 | 37 | 38 | * :ref:`Tutorial 6 ` uses a Denoised Completion Network with a trainable image denoiser to improve the results obtained in Tutorial 5 39 | 40 | * :ref:`Tutorial 7 ` shows how to perform image reconstruction using a pretrained plug-and-play denoising network. 41 | 42 | * :ref:`Tutorial 8 ` shows how to perform image reconstruction using a learnt proximal gradient descent. 43 | 44 | * :ref:`Tutorial 9 ` explains motion simulation from an image, dynamic measurements and reconstruction. 45 | 46 | .. _GitHub: https://github.com/openspyrit/spyrit/tree/3895b5e61fb6d522cff5e8b32a36da89b807b081/tutorial/images/test 47 | 48 | .. _2.4.0: https://spyrit.readthedocs.io/en/2.4.0/gallery/index.html 49 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. spyrit documentation master file, created by 2 | sphinx-quickstart on Fri Mar 12 11:04:59 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | SPyRiT 7 | ##################################################################### 8 | 9 | SPyRiT is a `PyTorch `_-based image reconstruction 10 | package designed for `single-pixel imaging `_. SPyRiT has a `modular organisation `_ and may be useful for other inverse problems. 11 | 12 | Github repository: `openspyrit/spyrit `_ 13 | 14 | 15 | Installation 16 | ================================== 17 | 18 | SPyRiT is available for Linux, macOS and Windows:: 19 | 20 | pip install spyrit 21 | 22 | See `here `_ for advanced installation guidelines. 23 | 24 | 25 | Getting started 26 | ================================== 27 | 28 | Please check our `tutorials `_ as well as the `examples <https://github.com/openspyrit/spyrit-examples/tree/master/2025_spyrit_v3>`_ on GitHub. 29 | 30 | 31 | External libraries 32 | ================================== 33 | 34 | You can connect spyrit to different packages like deepinv. Check `tutorials `_. 35 | 36 | 37 | Cite us 38 | ================================== 39 | 40 | When using SPyRiT in scientific publications, please cite [v3]_ for SPyRiT v3, [v2]_ for SPyRiT v2, and [v1]_ for DC-Net. 41 | 42 | .. [v3] JFJP Abascal, T Baudier, R Phan, A Repetti, N Ducros, "SPyRiT 3.0: an open source package for single-pixel imaging based on deep learning," Vol. 33, Issue 13, pp. 27988-28005 (2025). `DOI `_. 43 | .. [v2] G Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," *Optics Express*, Vol. 31, Issue 10, (2023). `DOI `_. 44 | .. [v1] A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," *Opt. Express*, Vol. 29, Issue 11, 17097-17110 (2021). `DOI `_. 45 | 46 | 47 | Join the project 48 | ================================== 49 | 50 | The list of contributors can be found `here `_. Feel free to contact us by `e-mail `_ for any question.
Direct contributions via pull requests (PRs) are welcome. 51 | 52 | .. toctree:: 53 | :maxdepth: 2 54 | :hidden: 55 | 56 | single_pixel 57 | organisation 58 | 59 | 60 | .. toctree:: 61 | :maxdepth: 2 62 | :caption: Tutorials 63 | :hidden: 64 | 65 | gallery/index 66 | 67 | .. toctree:: 68 | :maxdepth: 2 69 | :caption: External libraries 70 | :hidden: 71 | 72 | external_libraries 73 | 74 | Contents 75 | ======== 76 | 77 | .. autosummary:: 78 | :toctree: _autosummary 79 | :template: spyrit-module-template.rst 80 | :recursive: 81 | :caption: Contents 82 | 83 | spyrit.core 84 | spyrit.misc 85 | spyrit.external 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/openspyrit/spyrit?logo=github) 2 | [![GitHub](https://img.shields.io/github/license/openspyrit/spyrit?style=plastic)](https://github.com/openspyrit/spyrit/blob/master/LICENSE.md) 3 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/spyrit.svg)](https://pypi.python.org/pypi/spyrit/) 4 | [![Docs](https://readthedocs.org/projects/spyrit/badge/?version=master&style=flat)](https://spyrit.readthedocs.io/en/master/) 5 | 6 | # SPyRiT 7 | SPyRiT is a [PyTorch]()-based deep image reconstruction package primarily designed for single-pixel imaging. 8 | 9 | # Installation 10 | The spyrit package is available for Linux, MacOs and Windows. We recommend to use a virtual environment. 11 | ## Linux and MacOs 12 | (user mode) 13 | ``` 14 | pip install spyrit 15 | ``` 16 | (developper mode) 17 | ``` 18 | git clone https://github.com/openspyrit/spyrit.git 19 | cd spyrit 20 | pip install -e . 21 | ``` 22 | 23 | ## Windows 24 | On Windows you may need to install PyTorch first. It may also be necessary to run the following commands using administrator rights (e.g., starting your Python environment with administrator rights). 
25 | 26 | Adapt the two examples below to your configuration (see [here](https://pytorch.org/get-started/locally/) for the latest instructions) 27 | 28 | (CPU version using `pip`) 29 | 30 | ``` 31 | pip3 install torch torchvision torchaudio 32 | ``` 33 | 34 | (GPU version using `conda`) 35 | 36 | ``` shell 37 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia 38 | ``` 39 | 40 | Then, install SPyRiT using `pip`: 41 | 42 | (user mode) 43 | ``` 44 | pip install spyrit 45 | ``` 46 | (developer mode) 47 | ``` 48 | git clone https://github.com/openspyrit/spyrit.git 49 | cd spyrit 50 | pip install -e . 51 | ``` 52 | 53 | 54 | ## Test 55 | To check the installation, run in your python terminal: 56 | ``` 57 | import spyrit 58 | ``` 59 | 60 | ## Get started - Examples 61 | To start, check the [documentation tutorials](https://spyrit.readthedocs.io/en/master/gallery/index.html). These tutorials must be run from the `tutorial` folder (they load image samples from `spyrit/images/`): 62 | ``` 63 | cd spyrit/tutorial/ 64 | ``` 65 | 66 | More advanced reconstruction examples can be found in [spyrit-examples/tutorial](https://github.com/openspyrit/spyrit-examples/tree/master/tutorial). Run advanced tutorial in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_core_2d_drunet.ipynb) 67 | 68 | 69 | # API Documentation 70 | https://spyrit.readthedocs.io/ 71 | 72 | # Contributors (alphabetical order) 73 | * Juan Abascal - [Website](https://juanabascal78.wixsite.com/juan-abascal-webpage) 74 | * Thomas Baudier 75 | * Sebastien Crombez 76 | * Nicolas Ducros - [Website](https://www.creatis.insa-lyon.fr/~ducros/WebPage/index.html) 77 | * Antonio Tomas Lorente Mur - [Website]( https://sites.google.com/view/antonio-lorente-mur/) 78 | * Romain Phan 79 | * Fadoua Taia-Alaoui 80 | 81 | # How to cite?
82 | When using SPyRiT in scientific publications, please cite the following paper: 83 | 84 | * G. Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," Optics Express, Vol. 31, No. 10, (2023). https://doi.org/10.1364/OE.483937. 85 | 86 | When using SPyRiT specifically for the denoised completion network, please cite the following paper: 87 | 88 | * A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," Opt. Express 29, 17097-17110 (2021). https://doi.org/10.1364/OE.424228. 89 | 90 | # License 91 | This project is licensed under the LGPL-3.0 license - see the [LICENSE.md](LICENSE.md) file for details 92 | 93 | # Acknowledgments 94 | * [Jin LI](https://github.com/happyjin/ConvGRU-pytorch) for his implementation of Convolutional Gated Recurrent Units for PyTorch 95 | * [Erik Lindernoren](https://github.com/eriklindernoren/Action-Recognition) for his processing of the UCF-101 Dataset. 96 | -------------------------------------------------------------------------------- /docs/source/single_pixel.rst: -------------------------------------------------------------------------------- 1 | Single-pixel imaging 2 | ================================== 3 | .. _principle: 4 | .. figure:: fig/spi_principle.png 5 | :width: 800 6 | :align: center 7 | 8 | Overview of the principle of single-pixel imaging. 9 | 10 | 11 | Simulation of the measurements 12 | ----------------------------------- 13 | Single-pixel imaging aims to recover an unknown image :math:`x\in\mathbb{R}^N` from a few noisy observations 14 | 15 | .. math:: 16 | m \approx Hx, 17 | 18 | where :math:`H\colon \mathbb{R}^{M\times N}` is a linear measurement operator, :math:`M` is the number of measurements and :math:`N` is the number of pixels in the image. 
19 | 20 | In practice, measurements are obtained by uploading a set of light patterns onto a spatial light modulator (e.g., a digital micromirror device (DMD), see :ref:`principle`). Therefore, only positive patterns can be implemented. We model the actual acquisition process as 21 | 22 | 23 | .. math:: 24 | :label: eq_acquisition 25 | 26 | y = \mathcal{N}(Ax) 27 | 28 | where :math:`\mathcal{N} \colon \mathbb{R}^J \to \mathbb{R}^J` represents a noise operator (e.g., Poisson or Poisson-Gaussian), :math:`A \in \mathbb{R}_+^{J\times N}` is the actual acquisition operator that models the (positive) DMD patterns, and :math:`J` is the number of DMD patterns. 29 | 30 | Handling non negativity with pre-processing 31 | ---------------------------------------------------------------------- 32 | We may preprocess the measurements before reconstruction to transform the actual measurements into the target measurements 33 | 34 | .. math:: 35 | :label: eq_prep 36 | 37 | m = By \approx Hx 38 | 39 | 40 | where :math:`B\colon\mathbb{R}^{J}\to \mathbb{R}^{M}` is the preprocessing operator chosen such that :math:`BA=H`. Note that the noise of the preprocessed measurements :math:`m=By` is not the same as that of the actual measurements :math:`y`. 41 | 42 | Data-driven image reconstruction 43 | ----------------------------------- 44 | Data-driven methods based on deep learning aim to find an estimate :math:`x^*\in \mathbb{R}^N` of the unknown image :math:`x` from the preprocessed measurements :math:`By`, using a reconstruction operator :math:`\mathcal{R}_{\theta^*} \colon \mathbb{R}^M \to \mathbb{R}^N` 45 | 46 | .. math:: 47 | \mathcal{R}_{\theta^*}(m) = x^* \approx x, 48 | 49 | where :math:`\theta^*` represents the parameters learned during a training procedure. 
50 | 51 | Learning phase 52 | ----------------------------------- 53 | In the case of supervised learning, it is assumed that a training dataset :math:`\{x^{(i)},y^{(i)}\}_{1 \le i \le I}` of :math:`I` pairs of ground truth images in :math:`\mathbb{R}^N` and measurements in :math:`\mathbb{R}^M` is available. :math:`\theta^*` is then obtained by solving 54 | 55 | .. math:: 56 | :label: eq_train 57 | 58 | \min_{\theta}\,{\sum_{i=1}^I \mathcal{L}\left(x^{(i)},\mathcal{R}_\theta(By^{(i)})\right)}, 59 | 60 | 61 | where :math:`\mathcal{L}` is the training loss (e.g., squared error). In the case where only ground truth images :math:`\{x^{(i)}\}_{1 \le i \le I}` are available, the associated measurements are simulated as :math:`y^{(i)} = \mathcal{N}(Ax^{(i)})`, :math:`1 \le i \le I`. 62 | 63 | 64 | Reconstruction operator 65 | ----------------------------------- 66 | A simple yet efficient method consists in correcting a traditional (e.g. linear) reconstruction by a data-driven nonlinear step 67 | 68 | .. math:: 69 | :label: eq_recon_direct 70 | 71 | \mathcal{R}_\theta = \mathcal{G}_\theta \circ \mathcal{R}, 72 | 73 | where :math:`\mathcal{R}\colon\mathbb{R}^{M}\to\mathbb{R}^N` is a traditional hand-crafted (e.g., regularized) reconstruction operator and :math:`\mathcal{G}_\theta\colon\mathbb{R}^{N}\to\mathbb{R}^N` is a nonlinear neural network that acts in the image domain. 74 | 75 | Algorithm unfolding consists in defining :math:`\mathcal{R}_\theta` from an iterative scheme 76 | 77 | .. math:: 78 | :label: eq_pgd_no_Gamma 79 | 80 | \mathcal{R}_\theta = \mathcal{R}_{\theta_K} \circ ... \circ \mathcal{R}_{\theta_1}, 81 | 82 | where :math:`\mathcal{R}_{\theta_k}` can be interpreted as the computation of the :math:`k`-th iteration of the iterative scheme and :math:`\theta = \bigcup_{k} \theta_k`.
83 | -------------------------------------------------------------------------------- /spyrit/misc/matrix_tools.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | #!/usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | """ 10 | Created on Wed Jan 15 16:37:27 2020 11 | 12 | @author: crombez 13 | """ 14 | import warnings 15 | 16 | warnings.simplefilter("always", DeprecationWarning) 17 | 18 | import numpy as np 19 | 20 | import spyrit.misc.sampling as samp 21 | 22 | 23 | def Permutation_Matrix(mat): 24 | r""" 25 | Returns permutation matrix from sampling matrix 26 | 27 | Args: 28 | Mat (np.ndarray): 29 | N-by-N sampling matrix, where high values indicate high significance. 30 | 31 | Returns: 32 | P (np.ndarray): N^2-by-N^2 permutation matrix (boolean) 33 | 34 | .. warning:: 35 | This function is a duplicate of 36 | :func:`spyrit.misc.sampling.Permutation_Matrix` and will be removed 37 | in a future release. 38 | 39 | .. note:: 40 | Consider using :func:`sort_by_significance` for increased 41 | computational performance if using :func:`Permutation_Matrix` to 42 | reorder a matrix as follows: 43 | ``y = Permutation_Matrix(Ord) @ Mat`` 44 | """ 45 | warnings.warn( 46 | "\nspyrit.misc.matrix_tools.Permutation_Matrix is deprecated and will" 47 | + " be removed in a future release. 
Use\n" 48 | + "spyrit.misc.sampling.Permutation_Matrix instead.", 49 | DeprecationWarning, 50 | ) 51 | return samp.Permutation_Matrix(mat) 52 | 53 | 54 | def expend_vect(Vect, N1, N2): # Expand a vector of size N1 to size N2 by repeating each entry N2/N1 times 55 | V_out = np.zeros(N2) 56 | S = int(N2 / N1) 57 | j = 0 58 | ad = 0 59 | for i in range(N1): 60 | for j in range(0, S): 61 | V_out[i + j + ad] = Vect[i] 62 | ad += S - 1 63 | return V_out 64 | 65 | 66 | def data_conv_hadamard(H, Data, N): # Multiply each of the N slices H[:, :, i] elementwise by Data (in place) 67 | for i in range(N): 68 | H[:, :, i] = H[:, :, i] * Data 69 | return H 70 | 71 | 72 | def Sum_coll(Mat, N_lin, N_coll): # Return the sum over the columns of an N_lin x N_coll matrix (one value per row) 73 | Mturn = np.zeros(N_lin) 74 | 75 | for i in range(N_coll): 76 | Mturn += Mat[:, i] 77 | 78 | return Mturn 79 | 80 | 81 | def compression_1D( 82 | H, Nl, Nc, Nh 83 | ): # Compress an Nl x Nc x Nh array into an Nl x Nh matrix by summing each slice over its columns 84 | H_1D = np.zeros((Nl, Nh)) 85 | for i in range(Nh): 86 | H_1D[:, i] = Sum_coll(H[:, :, i], Nl, Nc) 87 | 88 | return H_1D 89 | 90 | 91 | def normalize_mat_2D(Mat): # Normalise a 2D matrix by its maximum value 92 | Max = np.amax(Mat) 93 | return Mat * (1 / Max) 94 | 95 | 96 | def normalize_by_median_mat_2D(Mat): # Normalise a 2D matrix by its median value 97 | Median = np.median(Mat) 98 | return Mat * (1 / Median) 99 | 100 | 101 | def remove_offset_mat_2D(Mat): # Subtract the mean value of the matrix 102 | Mean = np.mean(Mat) 103 | return Mat - Mean 104 | 105 | 106 | def resize(Mat, Nl, Nc, Nh): # Resize an Nl x Nh matrix into an Nl x Nc matrix by expanding each row with expend_vect 107 | Mres = np.zeros((Nl, Nc)) 108 | for i in range(Nl): 109 | Mres[i, :] = expend_vect(Mat[i, :], Nh, Nc) 110 | return Mres 111 | 112 | 113 | def stack_depth_matrice( 114 | Mat, Nl, Nc, Nd 115 | ): # Sum an Nl x Nc x Nd array along its third dimension, yielding an Nl x Nc matrix 116 | M_out = np.zeros((Nl, Nc)) 117 | for i in range(Nd): 118 | M_out += Mat[:, :, i] 119 | return M_out 120 | 121 | 122 | # Functions below need to be better defined 123 | 124 | 125 | def smooth(y, box_pts): #
Smooth a vectors 126 | box = np.ones(box_pts) / box_pts 127 | y_smooth = np.convolve(y, box, mode="same") 128 | return y_smooth 129 | 130 | 131 | def reject_outliers(data, m=2): # Remove 132 | return np.where(abs(data - np.mean(data)) < m * np.std(data), data, 0) 133 | 134 | 135 | def clean_out(Data, Nl, Nc, Nh, m=2): 136 | Mout = np.zeros((Nl, Nc, Nh)) 137 | for i in range(Nh): 138 | Mout[:, :, i] = reject_outliers(Data[:, :, i], m) 139 | return Data 140 | -------------------------------------------------------------------------------- /spyrit/hadamard_matrix/download_hadamard_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import requests 3 | import os 4 | import glob 5 | import importlib.util 6 | import tqdm 7 | 8 | 9 | def download_from_girder(): 10 | """ 11 | Download Hadamard matrices from the Girder repository into hadamard_matrix folder. 12 | """ 13 | 14 | hadamard_matrix_path = os.path.dirname(__file__) 15 | if os.path.isfile( 16 | os.path.join(hadamard_matrix_path, "had.236.sage.cooper-wallis.npz") 17 | ): 18 | return 19 | print("Downloading Hadamard matrices (>2300) from Girder repository...") 20 | print( 21 | "The matrices were downloaded from http://neilsloane.com/hadamard/ Sloane et al." 
22 | ) 23 | import girder_client 24 | 25 | gc = girder_client.GirderClient( 26 | apiUrl="https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 27 | ) 28 | 29 | collection_id = "66796d3cbaa5a90007058946" 30 | folder_id = "6800c6891240141f6aa53845" 31 | limit = 50 # Number of items to retrieve per request 32 | offset = 0 # Starting point 33 | pbar = tqdm.tqdm(total=0) 34 | 35 | while True: 36 | items = gc.get( 37 | "item", 38 | parameters={ 39 | "parentType": "collection", 40 | "parentId": collection_id, 41 | "folderId": folder_id, 42 | "limit": limit, 43 | "offset": offset, 44 | }, 45 | ) 46 | if not items: 47 | break 48 | pbar.total += len(items) 49 | pbar.refresh() 50 | for item in items: 51 | files = gc.get(f'item/{item["_id"]}/files') 52 | for file in files: 53 | pbar.update(1) 54 | gc.downloadFile( 55 | file["_id"], os.path.join(hadamard_matrix_path, file["name"]) 56 | ) 57 | offset += limit 58 | pbar.close() 59 | 60 | 61 | def read_text_file_from_url(url): 62 | response = requests.get(url) 63 | content = response.text 64 | return content 65 | 66 | 67 | def download_from_sloane(): 68 | from selenium import webdriver 69 | from selenium.webdriver.common.by import By 70 | from selenium.webdriver.chrome.service import Service 71 | from webdriver_manager.chrome import ChromeDriverManager 72 | 73 | # Set up the WebDriver 74 | driver = webdriver.Chrome(service=Service(ChromeDriverManager().install())) 75 | 76 | # Open the website 77 | driver.get("http://neilsloane.com/hadamard/") 78 | 79 | # Find all links to Hadamard matrices 80 | links = driver.find_elements(By.XPATH, "//a[contains(@href, 'had.')]") 81 | 82 | # Extract the URLs 83 | hadamard_urls = set([link.get_attribute("href") for link in links]) 84 | 85 | # Print the URLs 86 | for url in hadamard_urls: 87 | print(url) 88 | # Read the text file from the URL 89 | file_content = read_text_file_from_url(url) 90 | # Split the content into lines 91 | lines = file_content.splitlines() 92 | 93 | # Print the content 
of the file 94 | if "+" in file_content or "0" in file_content or "-1" in file_content: 95 | if len(lines) > 1: 96 | size = len(lines[1]) 97 | else: 98 | size = len(lines[0]) 99 | array = [] 100 | for line in lines: 101 | if len(line) == size: 102 | line = line.replace("-1", "0") 103 | tmp = [] 104 | for e in line: 105 | if e == "+" or e == "1": 106 | tmp += [1] 107 | elif e == "-" or e == "0": 108 | tmp += [0] 109 | elif e == " ": 110 | pass 111 | else: 112 | print("Error during reading of " + url) 113 | array += [tmp] 114 | np_array = np.array(array, dtype=bool) 115 | 116 | name = url.split("/")[-1][:-4] 117 | order = int(name.split(".")[1]) 118 | 119 | # Check if the file already exists 120 | files = glob.glob("had." + str(order) + "*.npz") 121 | already_saved = False 122 | for file in files: 123 | b = np.load(file) 124 | if np.all(np_array == b): 125 | already_saved = True 126 | if already_saved: 127 | break 128 | 129 | if not already_saved: 130 | np.savez_compressed(name + ".npz", np_array) 131 | else: 132 | print("no ok for " + url) 133 | # print(file_content) 134 | 135 | # Close the WebDriver 136 | driver.quit() 137 | 138 | 139 | if __name__ == "__main__": 140 | download_from_sloane() 141 | -------------------------------------------------------------------------------- /tutorial/tuto_02_noise.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 02. Noise operators 3 | =================================================== 4 | .. _tuto_noise: 5 | 6 | This tutorial shows how to use noise operators using the :mod:`spyrit.core.noise` submodule. 7 | 8 | .. 
image:: ../fig/tuto2.png 9 | :width: 600 10 | :align: center 11 | :alt: Reconstruction architecture sketch 12 | 13 | | 14 | """ 15 | 16 | # %% 17 | # Load a batch of images 18 | # ----------------------------------------------------------------------------- 19 | 20 | ############################################################################### 21 | # We load a batch of images from the `/images/` folder. Using the 22 | # :func:`transform_gray_norm` function with the :attr:`normalize=False` 23 | # argument returns images with values in (0,1). 24 | import os 25 | 26 | import torch 27 | import torchvision 28 | import matplotlib.pyplot as plt 29 | 30 | from spyrit.misc.disp import imagesc 31 | from spyrit.misc.statistics import transform_gray_norm 32 | 33 | spyritPath = os.getcwd() 34 | imgs_path = os.path.join(spyritPath, "images/") 35 | 36 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1) 37 | transform = transform_gray_norm(img_size=64, normalize=False) 38 | 39 | # Create dataset and loader (expects class folder 'images/test/') 40 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 42 | 43 | x, _ = next(iter(dataloader)) 44 | print(f"Shape of input images: {x.shape}") 45 | 46 | 47 | ############################################################################### 48 | # We select the first image in the batch and plot it. 49 | 50 | i_plot = 1 51 | imagesc(x[i_plot, 0, :, :], r"$x$ in (0, 1)") 52 | 53 | 54 | # %% 55 | # Gaussian noise 56 | # ----------------------------------------------------------------------------- 57 | 58 | ############################################################################### 59 | # We consider additive Gaussiane noise, 60 | # 61 | # .. 
math:: 62 | # y \sim z + \mathcal{N}(0,\sigma^2), 63 | # 64 | # where :math:`\mathcal{N}(\mu, \sigma^2)` is a Gaussian distribution with mean :math:`\mu` and variance :math:`\sigma^2`, and :math:`z` is the noiseless image. The larger :math:`\sigma`, the lower the signal-to-noise ratio. 65 | 66 | ############################################################################### 67 | # To add 10% Gaussian noise, we instantiate a :class:`spyrit.core.noise` 68 | # operator with :attr:`sigma=0.1`. 69 | 70 | from spyrit.core.noise import Gaussian 71 | 72 | noise_op = Gaussian(sigma=0.1) 73 | x_noisy = noise_op(x) 74 | 75 | imagesc(x_noisy[1, 0, :, :], r"10% Gaussian noise") 76 | # sphinx_gallery_thumbnail_number = 2 77 | 78 | ############################################################################### 79 | # To add 2% Gaussian noise, we update the class attribute :attr:`sigma`. 80 | 81 | noise_op.sigma = 0.02 82 | x_noisy = noise_op(x) 83 | 84 | imagesc(x_noisy[1, 0, :, :], r"2% Gaussian noise") 85 | 86 | # %% 87 | # Poisson noise 88 | # ----------------------------------------------------------------------------- 89 | 90 | ############################################################################### 91 | # We now consider Poisson noise, 92 | # 93 | # .. math:: 94 | # y \sim \mathcal{P}(\alpha z), \quad z \ge 0, 95 | # 96 | # where :math:`\alpha \ge 0` is a scalar value that represents the maximum 97 | # image intensity (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio. 98 | 99 | ############################################################################### 100 | # We consider the :class:`spyrit.core.noise.Poisson` class and set :math:`\alpha` 101 | # to 100 photons (which corresponds to the Poisson parameter). 
102 | 103 | from spyrit.core.noise import Poisson 104 | from spyrit.misc.disp import add_colorbar, noaxis 105 | 106 | alpha = 100 # number of photons 107 | noise_op = Poisson(alpha) 108 | 109 | ############################################################################### 110 | # We simulate two noisy versions of the same images 111 | 112 | y1 = noise_op(x) # first sample 113 | y2 = noise_op(x) # another sample 114 | 115 | ############################################################################### 116 | # We now consider the case :math:`\alpha = 1000` photons. 117 | 118 | noise_op.alpha = 1000 119 | y3 = noise_op(x) # noisy measurement vector 120 | 121 | ############################################################################### 122 | # We finally plot the noisy images 123 | 124 | # plot 125 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 126 | axs[0].set_title("100 photons") 127 | im = axs[0].imshow(y1[1, 0].reshape(64, 64), cmap="gray") 128 | add_colorbar(im, "bottom") 129 | 130 | axs[1].set_title("100 photons") 131 | im = axs[1].imshow(y2[1, 0].reshape(64, 64), cmap="gray") 132 | add_colorbar(im, "bottom") 133 | 134 | axs[2].set_title("1000 photons") 135 | im = axs[2].imshow( 136 | y3[ 137 | 1, 138 | 0, 139 | ].reshape(64, 64), 140 | cmap="gray", 141 | ) 142 | add_colorbar(im, "bottom") 143 | 144 | noaxis(axs) 145 | 146 | ############################################################################### 147 | # As expected the signal-to-noise ratio of the measurement vector is higher for 148 | # 1,000 photons than for 100 photons 149 | # 150 | # .. note:: 151 | # Not only the signal-to-noise, but also the scale of the measurements 152 | # depends on :math:`\alpha`, which motivates the introduction of the 153 | # preprocessing operator. 
154 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [ master ] 7 | tags: 8 | - '*' 9 | pull_request: 10 | branches: 11 | - '*' 12 | schedule: 13 | - cron: '0 0 * * 0' 14 | workflow_dispatch: 15 | 16 | jobs: 17 | build_wheel: 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [ubuntu-latest] 23 | python-version: [3.9] 24 | 25 | steps: 26 | - name: Checkout github repo 27 | uses: actions/checkout@v4 28 | - name: Checkout submodules 29 | run: git submodule update --init --recursive 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | architecture: 'x64' 35 | - name: Create Wheel 36 | run: | 37 | pip install build 38 | python -m build 39 | mkdir wheelhouse 40 | cp dist/spyrit-* wheelhouse/ 41 | ls wheelhouse 42 | rm -r dist 43 | mv wheelhouse dist 44 | - name: Upload wheels 45 | uses: actions/upload-artifact@v4 46 | with: 47 | name: dist 48 | path: dist/ 49 | 50 | test_install: 51 | runs-on: ${{ matrix.os }} 52 | strategy: 53 | fail-fast: false 54 | matrix: 55 | os: [ubuntu-latest, windows-latest, macos-13, macos-14] 56 | python-version: [3.9, "3.10", "3.11", "3.12"] 57 | exclude: 58 | - os: macos-13 59 | python-version: '3.10' 60 | - os: macos-13 61 | python-version: '3.11' 62 | - os: macos-13 63 | python-version: '3.12' 64 | - os: macos-14 65 | python-version: 3.9 66 | - os: macos-14 67 | python-version: '3.10' 68 | 69 | steps: 70 | - name: Checkout github repo 71 | uses: actions/checkout@v4 72 | - name: Checkout submodules 73 | run: git submodule update --init --recursive 74 | - name: Set up Python ${{ matrix.python-version }} 75 | uses: actions/setup-python@v5 76 | with: 77 | python-version: ${{ matrix.python-version }} 78 | 
architecture: 'x64' 79 | - name: Run the tests on Mac and Linux 80 | if: matrix.os != 'windows-latest' 81 | run: | 82 | pip install pytest 83 | pip install -e . 84 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit/dev --ignore=spyrit/hadamard_matrix || exit -1 85 | - name: Run the tests on Windows 86 | if: matrix.os == 'windows-latest' 87 | shell: cmd 88 | run: | 89 | pip install pytest 90 | pip install -e . 91 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit\dev --ignore=spyrit\hadamard_matrix || exit /b -1 92 | 93 | test_wheel: 94 | runs-on: ${{ matrix.os }} 95 | needs: [build_wheel] 96 | strategy: 97 | fail-fast: false 98 | matrix: 99 | os: [ubuntu-latest, windows-latest, macos-13] 100 | python-version: [3.9, "3.10", "3.11", "3.12"] 101 | 102 | steps: 103 | - name: Checkout github repo 104 | uses: actions/checkout@v4 105 | - name: Checkout submodules 106 | run: git submodule update --init --recursive 107 | - name: Set up Python ${{ matrix.python-version }} 108 | uses: actions/setup-python@v5 109 | with: 110 | python-version: ${{ matrix.python-version }} 111 | architecture: 'x64' 112 | - uses: actions/download-artifact@v4 113 | with: 114 | pattern: dist* 115 | merge-multiple: true 116 | path: dist/ 117 | - name: Run tests on Mac and Linux 118 | if: matrix.os != 'windows-latest' 119 | run: | 120 | cd dist 121 | pip install spyrit-*.whl 122 | - name: Run the tests on Windows 123 | if: matrix.os == 'windows-latest' 124 | run: | 125 | cd dist 126 | $package=dir -Path . 
-Filter spyrit*.whl | %{$_.FullName} 127 | echo $package 128 | pip install $package 129 | 130 | publish_wheel: 131 | runs-on: ubuntu-latest 132 | needs: [build_wheel, test_wheel, test_install] 133 | steps: 134 | - name: Checkout github repo 135 | uses: actions/checkout@v4 136 | - name: Checkout submodules 137 | run: git submodule update --init --recursive 138 | - uses: actions/download-artifact@v4 139 | with: 140 | pattern: dist* 141 | merge-multiple: true 142 | path: dist/ 143 | - name: Publish to PyPI 144 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') 145 | uses: pypa/gh-action-pypi-publish@release/v1 146 | with: 147 | user: __token__ 148 | password: ${{ secrets.PYPI }} 149 | skip_existing: true 150 | 151 | test_deepinv: 152 | runs-on: ubuntu-latest 153 | steps: 154 | - name: Checkout github repo 155 | uses: actions/checkout@v4 156 | - name: Checkout submodules 157 | run: git submodule update --init --recursive 158 | - name: Set up Python '3.12' 159 | uses: actions/setup-python@v5 160 | with: 161 | python-version: '3.12' 162 | architecture: 'x64' 163 | - name: Run the tests with deepinv 164 | run: | 165 | pip install pytest 166 | pip install deepinv 167 | pip install -e . 168 | cd tutorial 169 | cp wip/tuto_a00_connect_deepinv.py . 170 | python tuto_a00_connect_deepinv.py || exit -1 171 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | from sphinx_gallery.sorting import ExampleTitleSortKey 16 | 17 | # paths relative to this file 18 | sys.path.insert(0, os.path.abspath("../..")) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | project = "spyrit" 22 | copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" 23 | author = "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = "2.4.0" 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.mathjax", 37 | "sphinx.ext.todo", 38 | "sphinx.ext.autosummary", 39 | "sphinx.ext.napoleon", 40 | "sphinx.ext.viewcode", 41 | "sphinx_gallery.gen_gallery", 42 | "sphinx.ext.coverage", 43 | ] 44 | 45 | # Napoleon settings 46 | napoleon_google_docstring = False 47 | napoleon_numpy_docstring = True 48 | napoleon_include_private_with_doc = False 49 | napoleon_include_special_with_doc = False 50 | napoleon_use_admonition_for_examples = False 51 | napoleon_use_admonition_for_notes = False 52 | napoleon_use_admonition_for_references = False 53 | napoleon_use_ivar = True 54 | napoleon_use_param = False 55 | napoleon_use_rtype = False 56 | 57 | autodoc_member_order = "bysource" 58 | autosummary_generate = True 59 | todo_include_todos = True 60 | 61 | # Add any paths that contain templates here, relative to this directory. 62 | templates_path = ["_templates"] 63 | 64 | # List of patterns, relative to source directory, that match files and 65 | # directories to ignore when looking for source files. 66 | # This pattern also affects html_static_path and html_extra_path. 67 | exclude_patterns = [] 68 | 69 | sphinx_gallery_conf = { 70 | # path to your examples scripts 71 | "examples_dirs": [ 72 | "../../tutorial", 73 | ], 74 | # path where to save gallery generated examples 75 | "gallery_dirs": ["gallery"], 76 | "filename_pattern": "/tuto_", 77 | "ignore_pattern": "/_", 78 | # resize the thumbnails, original size = 400x280 79 | "thumbnail_size": (400, 280), 80 | # Remove the "Download all examples" button from the top level gallery 81 | "download_all_examples": False, 82 | # Sort gallery example by file name instead of number of lines (default) 83 | "within_subsection_order": ExampleTitleSortKey, 84 | # directory where function granular galleries are stored 85 | "backreferences_dir": "api/generated/backreferences", 86 | # Modules for which function level galleries are created. 
87 | "doc_module": "spyrit", 88 | # Insert links to documentation of objects in the examples 89 | "reference_url": {"spyrit": None}, 90 | } 91 | 92 | # -- Options for HTML output ------------------------------------------------- 93 | 94 | # The theme to use for HTML and HTML Help pages. See the documentation for 95 | # a list of builtin themes. 96 | html_theme = "sphinx_rtd_theme" 97 | 98 | # directory containing custom CSS file (used to produce bigger thumbnails) 99 | 100 | # on_rtd is whether we are on readthedocs.org 101 | on_rtd = os.environ.get("READTHEDOCS", None) == "True" 102 | 103 | # Add any paths that contain custom static files (such as style sheets) here, 104 | # relative to this directory. They are copied after the builtin static files, 105 | # so a file named "default.css" will overwrite the builtin "default.css". 106 | # By default, this is set to include the _static path. 107 | html_static_path = ["_static"] 108 | html_css_files = ["css/sg_README.css"] 109 | 110 | # The master toctree document. 
def skip_member_handler(app, what, name, obj, skip, options):
    """Tell autodoc to hide members inherited from torch.nn base classes.

    Returning ``True`` forces autodoc to skip the member; returning ``None``
    defers to autodoc's default decision. Members on the allow-list are never
    hidden, because they carry class-specific documentation.
    """
    # Methods that must always appear in the docs, even though they are
    # also defined on torch.nn.Module (extend this list if needed).
    allow_list = ("forward",)
    if name in allow_list:
        return None
    # Hide anything inherited from the torch base classes used across spyrit
    # (Module for most classes, Sequential for FullNet and its children).
    inherited = set(dir(torch.nn.Module)) | set(dir(torch.nn.Sequential))
    return True if name in inherited else None
math:: 19 | m = Hx, 20 | 21 | where :math:`H\in\mathbb{R}^{M\times N}` is the acquisition matrix, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`M` is the number of measurements, and :math:`N` is the dimension of the signal. 22 | 23 | .. important:: 24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g, an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization , i.e., :math:`x = \texttt{vec}(X)`. 25 | 26 | """ 27 | 28 | # %% 29 | # 1D Measurements 30 | # ----------------------------------------------------------------------------- 31 | 32 | ############################################################################### 33 | # We instantiate a measurement operator from a matrix of shape (10, 15). 34 | import torch 35 | from spyrit.core.meas import Linear 36 | 37 | H = torch.randn(10, 15) 38 | meas_op = Linear(H) 39 | 40 | ############################################################################### 41 | # We consider 3 signals of length 15 42 | x = torch.randn(3, 15) 43 | 44 | ############################################################################### 45 | # We apply the operator to the batch of images, which produces 3 measurements 46 | # of length 10 47 | m = meas_op(x) 48 | print(m.shape) 49 | 50 | ############################################################################### 51 | # We now plot the matrix-vector products 52 | 53 | from spyrit.misc.disp import add_colorbar, noaxis 54 | import matplotlib.pyplot as plt 55 | 56 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 57 | axs[0].set_title("Forward matrix H") 58 | im = axs[0].imshow(H, cmap="gray") 59 | add_colorbar(im, "bottom") 60 | 61 | axs[1].set_title("Signals x") 62 | im = axs[1].imshow(x.T, cmap="gray") 63 | add_colorbar(im, "bottom") 64 | 65 | axs[2].set_title("Measurements m") 66 | im = axs[2].imshow(m.T, cmap="gray") 67 | add_colorbar(im, "bottom") 68 | 69 | noaxis(axs) 70 | 
# sphinx_gallery_thumbnail_number = 1 71 | 72 | # %% 73 | # 2D Measurements 74 | # ----------------------------------------------------------------------------- 75 | 76 | ############################################################################### 77 | # We load a batch of images from the :attr:`/images/` folder. Using the 78 | # :func:`transform_gray_norm` function with the :attr:`normalize=False` 79 | # argument returns images with values in (0,1). 80 | import os 81 | import torchvision 82 | from spyrit.misc.statistics import transform_gray_norm 83 | 84 | spyritPath = os.getcwd() 85 | imgs_path = os.path.join(spyritPath, "images/") 86 | 87 | # Grayscale images of size (32, 32), no normalization to keep values in (0,1) 88 | transform = transform_gray_norm(img_size=32, normalize=False) 89 | 90 | # Create dataset and loader (expects class folder :attr:'images/test/') 91 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 92 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 93 | 94 | x, _ = next(iter(dataloader)) 95 | 96 | ############################################################################### 97 | # We crop the batch to get image of shape (9, 25). 98 | x = x[:, :, :9, :25] 99 | print(f"Shape of input images: {x.shape}") 100 | 101 | ############################################################################### 102 | # We plot the second image. 103 | from spyrit.misc.disp import imagesc 104 | 105 | imagesc(x[1, 0, :, :], "Image X") 106 | 107 | 108 | ############################################################################### 109 | # We instantiate a measurement operator from a random matrix with shape (10, 9*25). To indicate that the operator works in 2D, we use the :attr:`meas_shape` argument. 
110 | H = torch.randn(10, 9 * 25) 111 | meas_op = Linear(H, meas_shape=(9, 25)) 112 | 113 | ############################################################################### 114 | # We apply the operator to the batch of images, which produces a batch of measurement vectors of length 10. 115 | m = meas_op(x) 116 | print(m.shape) 117 | 118 | 119 | ############################################################################### 120 | # We now plot the matrix-vector products corresponding to the second image in the batch. 121 | 122 | ############################################################################### 123 | # We first select the second image and the second measurement vector in the batch. 124 | x_plot = x[1, 0, :, :] 125 | m_plot = m[1] 126 | 127 | ############################################################################### 128 | # Then we vectorize the image to get a 1D array of length 9*25. 129 | x_plot = x_plot.reshape(1, -1) 130 | 131 | print(f"Vectorised image with shape: {x_plot.shape}") 132 | 133 | ############################################################################### 134 | # We finally plot the matrix-vector products :math:`m = H x = H \texttt{vec}(X)`. 
def Files_names(Path, name_type):
    """Return the base names of the files matching a glob pattern.

    Args:
        Path (str): Directory path, including a trailing separator (it is
            concatenated, not joined, with ``name_type``).
        name_type (str): Glob pattern, e.g. ``"*.tif"``.

    Returns:
        list[str]: Base names sorted by file modification time (oldest first).
    """
    files = glob.glob(Path + name_type)
    # NOTE: a stray bare `print` statement (a no-op expression) was removed here.
    # Sort chronologically so the acquisition order is preserved.
    files.sort(key=os.path.getmtime)
    return [os.path.basename(x) for x in files]
def download_girder(
    server_url: str,
    hex_ids: Union[str, list[str]],
    local_folder: str,
    file_names: Union[str, list[str]] = None,
):
    """
    Downloads data from a Girder server and saves it locally.

    This function first creates the local folder if it does not exist. Then, it
    connects to the Girder server and gets the file names for the files
    whose name are not provided. For each file, it checks if it already exists
    by checking if the file name is already in the local folder. If not, it
    downloads the file.

    Args:
        server_url (str): The URL of the Girder server.

        hex_ids (str or list[str]): The hexadecimal id of the file(s) to download.
        If a list is provided, the files are downloaded in the same order and
        are saved in the same folder.

        local_folder (str): The path to the local folder where the files will
        be saved. If it does not exist, it will be created.

        file_names (str or list[str], optional): The name of the file(s) to save.
        If a list is provided, it must have the same length as hex_ids. Each
        element equal to `None` will be replaced by the name of the file on the
        server. If None, all the names will be obtained from the server.
        Default is None. All names include the extension.

    Raises:
        ValueError: If the number of file names provided does not match the
        number of files to download.

    Returns:
        str or list[str]: The absolute path to the downloaded file when a
        single file is requested, otherwise the list of absolute paths.
    """
    # leave import in function, so that the module can be used without
    # girder_client
    import girder_client

    # create the local folder if it does not exist yet
    if not os.path.exists(local_folder):
        print("Local folder not found, creating it... ", end="")
        os.makedirs(local_folder)
        print("done.")

    # connect to the server
    gc = girder_client.GirderClient(apiUrl=server_url)

    # normalize both arguments to lists of equal length
    if isinstance(hex_ids, str):
        hex_ids = [hex_ids]
    if file_names is None:
        file_names = [None] * len(hex_ids)
    elif isinstance(file_names, str):
        file_names = [file_names]

    if len(file_names) != len(hex_ids):
        raise ValueError("There must be as many file names as hex ids.")

    abs_paths = []

    # for each file, check if it exists and download if necessary
    for file_id, name in zip(hex_ids, file_names):

        if name is None:
            # get the file name from the server
            name = gc.getFile(file_id)["name"]

        local_path = os.path.join(local_folder, name)

        # download only if the file is not already present locally
        if not os.path.exists(local_path):
            print(f"Downloading {name}... ", end="\r")
            gc.downloadFile(file_id, local_path)
            print(f"Downloading {name}... done.")
        else:
            print("File already exists at", local_path)

        abs_paths.append(os.path.abspath(local_path))

    return abs_paths[0] if len(abs_paths) == 1 else abs_paths
24 | import os 25 | import torchvision 26 | import torch.nn 27 | 28 | import matplotlib.pyplot as plt 29 | 30 | from spyrit.misc.disp import imagesc 31 | from spyrit.misc.statistics import transform_gray_norm 32 | 33 | import deepinv as dinv 34 | 35 | spyritPath = os.getcwd() 36 | imgs_path = os.path.join(spyritPath, "images/") 37 | 38 | device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" 39 | 40 | # Grayscale images of size (32, 32), no normalization to keep values in (0,1) 41 | transform = transform_gray_norm(img_size=32, normalize=False) 42 | 43 | # Create dataset and loader (expects class folder 'images/test/') 44 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 45 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 46 | 47 | x, _ = next(iter(dataloader)) 48 | print(f"Ground-truth images: {x.shape}") 49 | 50 | ############################################################################### 51 | # We select the second image in the batch and plot it. 52 | 53 | i_plot = 1 54 | imagesc(x[i_plot, 0, :, :], r"$32\times 32$ image $X$") 55 | 56 | # %% 57 | # Basic example 58 | # ----------------------------------------------------------------------------- 59 | 60 | ###################################################################### 61 | # We instantiate an HadamSplit2d object and simulate the 2D hadamard transform of the input images. Reshape output is necesary for deepinv. We also add Poisson noise. 
62 | from spyrit.core.meas import HadamSplit2d 63 | import spyrit.core.noise as noise 64 | from spyrit.core.prep import UnsplitRescale 65 | 66 | meas_spyrit = HadamSplit2d(32, 512, device=device, reshape_output=True) 67 | alpha = 50 # image intensity 68 | meas_spyrit.noise_model = noise.Poisson(alpha) 69 | y = meas_spyrit(x) 70 | 71 | # preprocess 72 | prep = UnsplitRescale(alpha) 73 | m_spyrit = prep(y) 74 | 75 | print(y.shape) 76 | 77 | 78 | ###################################################################### 79 | # The norm has to be computed to be passed to deepinv. We need to use the max singular value of the linear operator. 80 | norm = torch.linalg.norm(meas_spyrit.H, ord=2) 81 | print(norm) 82 | 83 | 84 | # %% 85 | # Forward operator 86 | # ---------------------------------------------------------------------- 87 | 88 | ############################################################################### 89 | # You can direcly give the forward operator to deepinv. You can also add noise using deepinv model or spyrit model. 90 | meas_deepinv = dinv.physics.LinearPhysics( 91 | lambda y: meas_spyrit.measure_H(y) / norm, 92 | A_adjoint=lambda y: meas_spyrit.unvectorize(meas_spyrit.adjoint_H(y) / norm), 93 | ) 94 | # meas_deepinv.noise_model = dinv.physics.GaussianNoise(sigma=0.01) 95 | m_deepinv = meas_deepinv(x) 96 | print("diff:", torch.linalg.norm(m_spyrit / norm - m_deepinv)) 97 | 98 | 99 | # %% 100 | # Reconstruction with deepinverse 101 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 102 | 103 | ###################################################################### 104 | # First, use the adjoint and dagger (pseudo-inverse) operators to reconstruct the image. 
105 | x_adj = meas_deepinv.A_adjoint(m_spyrit / norm) 106 | imagesc(x_adj[1, 0, :, :].cpu(), "Adjoint") 107 | 108 | x_pinv = meas_deepinv.A_dagger(m_spyrit / norm) 109 | imagesc(x_pinv[1, 0, :, :].cpu(), "Pinv") 110 | 111 | 112 | ###################################################################### 113 | # You can also use optimization-based methods from deepinv. Here, we use Total Variation (TV) regularization with a projected gradient descent (PGD) algorithm. You can note the use of the custom_init parameter to initialize the algorithm with the dagger operator. 114 | model_tv = dinv.optim.optim_builder( 115 | iteration="PGD", 116 | prior=dinv.optim.TVPrior(), 117 | data_fidelity=dinv.optim.L2(), 118 | params_algo={"stepsize": 1, "lambda": 5e-2}, 119 | max_iter=10, 120 | custom_init=lambda y, Physics: {"est": (Physics.A_dagger(y),)}, 121 | ) 122 | 123 | x_tv, metrics_TV = model_tv(m_spyrit / norm, meas_deepinv, compute_metrics=True, x_gt=x) 124 | dinv.utils.plot_curves(metrics_TV) 125 | imagesc(x_tv[1, 0, :, :].cpu(), "TV recon") 126 | 127 | ###################################################################### 128 | # Deep Plug and Play (DPIR) algorithm can also be used with a pretrained denoiser. Here, we use the DRUNet denoiser. 129 | denoiser = dinv.models.DRUNet(in_channels=1, out_channels=1, device=device) 130 | model_dpir = dinv.optim.DPIR(sigma=1e-1, device=device, denoiser=denoiser) 131 | model_dpir.custom_init = lambda y, Physics: {"est": (Physics.A_dagger(y),)} 132 | with torch.no_grad(): 133 | x_dpir = model_dpir(m_spyrit / norm, meas_deepinv) 134 | imagesc(x_dpir[1, 0, :, :].cpu(), "DIPR recon") 135 | 136 | ###################################################################### 137 | # Reconstruct Anything Model (RAM) can also be used. 
def wavelength_to_rgb(
    wavelength: float, gamma: float = 0.8
) -> Tuple[float, float, float]:
    """Map a wavelength (in nm) to an (R, G, B) triple.

    Based on https://gist.github.com/friendly/67a7df339aa999e2bcfcfec88311abfc.
    Itself based on code by Dan Bruton:
    http://www.physics.sfasu.edu/astro/color/spectra.html

    Args:
        wavelength (float):
            Single wavelength to be converted to RGB.
        gamma (float, optional):
            Gamma correction. Defaults to 0.8.

    Returns:
        Tuple[float, float, float]:
            RGB value; (0, 0, 0) outside the visible range [380-750] nm.
    """
    if np.min(wavelength) < 380 or np.max(wavelength) > 750:
        warnings.warn("Some wavelengths are not in the visible range [380-750] nm")

    # Each band below linearly interpolates one channel; the two outermost
    # bands are additionally faded towards black near the visibility limits.
    if 380 <= wavelength <= 440:
        fade = 0.3 + 0.7 * (wavelength - 380) / (440 - 380)
        return (
            ((-(wavelength - 440) / (440 - 380)) * fade) ** gamma,
            0.0,
            (1.0 * fade) ** gamma,
        )
    if 440 <= wavelength <= 490:
        return 0.0, ((wavelength - 440) / (490 - 440)) ** gamma, 1.0
    if 490 <= wavelength <= 510:
        return 0.0, 1.0, (-(wavelength - 510) / (510 - 490)) ** gamma
    if 510 <= wavelength <= 580:
        return ((wavelength - 510) / (580 - 510)) ** gamma, 1.0, 0.0
    if 580 <= wavelength <= 645:
        return 1.0, (-(wavelength - 645) / (645 - 580)) ** gamma, 0.0
    if 645 <= wavelength <= 750:
        fade = 0.3 + 0.7 * (750 - wavelength) / (750 - 645)
        return (1.0 * fade) ** gamma, 0.0, 0.0
    # Outside the visible spectrum: black.
    return 0.0, 0.0, 0.0
By default, global normalization 105 | across all axes is considered. 106 | 107 | Returns: 108 | M_color (np.ndarray): Color array with an extra dimension. This is an A-by-C-by-3 array. 109 | 110 | """ 111 | 112 | # Normalize to adjust contrast 113 | M_gray_min = M_gray.min(keepdims=True, axis=axis) 114 | M_gray_max = M_gray.max(keepdims=True, axis=axis) 115 | M_gray = (M_gray - M_gray_min) / (M_gray_max - M_gray_min) 116 | 117 | # 118 | rgb_mat = wavelength_to_rgb_mat(wav, gamma=1) 119 | M_red = M_gray @ np.diag(rgb_mat[:, 0]) 120 | M_green = M_gray @ np.diag(rgb_mat[:, 1]) 121 | M_blue = M_gray @ np.diag(rgb_mat[:, 2]) 122 | 123 | M_color = np.stack((M_red, M_green, M_blue), axis=-1) 124 | 125 | return M_color 126 | 127 | 128 | def colorize(im, color, clip_percentile=0.1): 129 | """ 130 | Helper function to create an RGB image from a single-channel image using a 131 | specific color. 132 | """ 133 | # Check that we just have a 2D image 134 | if im.ndim > 2 and im.shape[2] != 1: 135 | raise ValueError("This function expects a single-channel image!") 136 | 137 | # Rescale the image according to how we want to display it 138 | im_scaled = im.astype(np.float32) - np.percentile(im, clip_percentile) 139 | im_scaled = im_scaled / np.percentile(im_scaled, 100 - clip_percentile) 140 | print( 141 | f"Norm: min={np.percentile(im, clip_percentile)}, max={np.percentile(im_scaled, 100 - clip_percentile)}" 142 | ) 143 | print(f"New: min={im_scaled.min()}, max={im_scaled.max()}") 144 | im_scaled = np.clip(im_scaled, 0, 1) 145 | 146 | # Need to make sure we have a channels dimension for the multiplication to work 147 | im_scaled = np.atleast_3d(im_scaled) 148 | 149 | # Reshape the color (here, we assume channels last) 150 | color = np.asarray(color).reshape((1, 1, -1)) 151 | return im_scaled * color 152 | 153 | 154 | def wavelength_to_colormap(wav, gamma=0.6): 155 | """ 156 | Creates a linear Matplotlib colormap that transitions from black to a specific 157 | color corresponding 
to a given electromagnetic wavelength. 158 | 159 | 160 | Args: 161 | wav (float): The wavelength in nanometers (nm) to determine the target 162 | color. Typically, this would be in the visible spectrum range (~380 163 | to 780 nm). 164 | 165 | gamma (float, optional): The gamma correction factor applied when 166 | calculating the RGB color from the wavelength. Defaults to 0.6. 167 | 168 | Returns: 169 | matplotlib.colors.LinearSegmentedColormap: A custom colormap object 170 | named 'DarkToColor' that spans from black at the low end (0.0) to the 171 | calculated wavelength-based color at the high end (1.0). 172 | 173 | Example: 174 | >>> cmap = wavelength_to_colormap(550, gamma=0.8) # Green color at 550nm 175 | >>> print(cmap) 176 | 177 | """ 178 | 179 | # 'dark_color' is the color at 0.0 (start) 180 | dark_color = ( 181 | "black" # You can use any dark color (e.g., '#000033', 'black', 'darkred') 182 | ) 183 | # 'target_color' is the color at 1.0 (end) 184 | 185 | target_color = wavelength_to_rgb(wav, gamma) 186 | 187 | # 2. Create the list of color nodes (tuples of position and color) 188 | # The colormap will transition linearly between these nodes. 189 | color_list = [(0.0, dark_color), (1.0, target_color)] 190 | 191 | custom_cmap = LinearSegmentedColormap.from_list("DarkToColor", color_list) 192 | 193 | return custom_cmap 194 | -------------------------------------------------------------------------------- /tutorial/tuto_01_b_splitting.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 01.b. Acquisition operators (splitting) 3 | ==================================================== 4 | .. _tuto_acquisition_operators_splitting: 5 | 6 | This tutorial shows how to simulate linear measurements by splitting an acquisition matrix :math:`H\in \mathbb{R}^{M\times N}` that contains negative values. It based on the :class:`spyrit.core.meas.LinearSplit` class of the :mod:`spyrit.core.meas` submodule. 7 | 8 | 9 | .. 
image:: ../fig/tuto1.png 10 | :width: 600 11 | :align: center 12 | :alt: Reconstruction architecture sketch 13 | 14 | | 15 | 16 | In practice, only positive values can be implemented using a digital micromirror device (DMD). Therefore, we acquire 17 | 18 | .. math:: 19 | y =Ax, 20 | 21 | where :math:`A \colon\, \mathbb{R}_+^{2M\times N}` is the acquisition matrix that contains positive DMD patterns, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`2M` is the number of DMD patterns, and :math:`N` is the dimension of the signal. 22 | 23 | .. important:: 24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g, an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization , i.e., :math:`x = \texttt{vec}(X)`. 25 | 26 | Given a matrix :math:`H` with negative values, we define the positive DMD patterns :math:`A` from the positive and negative components :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while odd rows of :math:`A` contain the negative components of :math:`H`. 27 | 28 | .. math:: 29 | \begin{cases} 30 | A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\ 31 | A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H). 32 | \end{cases} 33 | 34 | """ 35 | 36 | # %% 37 | # Splitting in 1D 38 | # ----------------------------------------------------------------------------- 39 | 40 | ############################################################################### 41 | # We instantiate a measurement operator from a matrix of shape (10, 15). 42 | import torch 43 | from spyrit.core.meas import LinearSplit 44 | 45 | H = torch.randn(10, 15) 46 | meas_op = LinearSplit(H) 47 | 48 | ############################################################################### 49 | # We consider 3 signals of length 15. 
50 | x = torch.randn(3, 15) 51 | 52 | ############################################################################### 53 | # We apply the operator to the batch of images, which produces 3 measurements 54 | # of length 10*2 = 20. 55 | y = meas_op(x) 56 | print(y.shape) 57 | 58 | ############################################################################### 59 | # .. note:: 60 | # The number of measurements is twice the number of rows of the matrix H that contains negative values. 61 | 62 | # %% 63 | # Illustration 64 | # ----------------------------------------------------------------------------- 65 | 66 | ############################################################################### 67 | # We plot the positive and negative components of H that are concatenated in the matrix A. 68 | 69 | A = meas_op.A 70 | H_pos = meas_op.A[::2, :] # Even rows 71 | H_neg = meas_op.A[1::2, :] # Odd rows 72 | 73 | from spyrit.misc.disp import add_colorbar, noaxis 74 | import matplotlib.pyplot as plt 75 | 76 | fig = plt.figure(figsize=(10, 5)) 77 | gs = fig.add_gridspec(2, 2) 78 | 79 | ax1 = fig.add_subplot(gs[:, 0]) 80 | ax2 = fig.add_subplot(gs[0, 1]) 81 | ax3 = fig.add_subplot(gs[1, 1]) 82 | 83 | ax1.set_title("Forward matrix A") 84 | im = ax1.imshow(A, cmap="gray") 85 | add_colorbar(im) 86 | 87 | ax2.set_title("Forward matrix H_pos") 88 | im = ax2.imshow(H_pos, cmap="gray") 89 | add_colorbar(im) 90 | 91 | ax3.set_title("Measurements H_neg") 92 | im = ax3.imshow(H_neg, cmap="gray") 93 | add_colorbar(im) 94 | 95 | noaxis(ax1) 96 | noaxis(ax2) 97 | noaxis(ax3) 98 | # sphinx_gallery_thumbnail_number = 1 99 | 100 | ############################################################################### 101 | # We can verify numerically that H = H_pos - H_neg 102 | 103 | H = meas_op.H 104 | diff = torch.linalg.norm(H - (H_pos - H_neg)) 105 | 106 | print(f"|| H - (H_pos - H_neg) || = {diff}") 107 | 108 | ############################################################################### 109 | # We 
now plot the matrix-vector products between A and x. 110 | 111 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 112 | axs[0].set_title("Forward matrix A") 113 | im = axs[0].imshow(A, cmap="gray") 114 | add_colorbar(im, "bottom") 115 | 116 | axs[1].set_title("Signals x") 117 | im = axs[1].imshow(x.T, cmap="gray") 118 | add_colorbar(im, "bottom") 119 | 120 | axs[2].set_title("Split measurements y") 121 | im = axs[2].imshow(y.T, cmap="gray") 122 | add_colorbar(im, "bottom") 123 | 124 | noaxis(axs) 125 | 126 | # %% 127 | # Simulations with noise and using the matrix H 128 | # -------------------------------------------------------------------- 129 | 130 | ###################################################################### 131 | # The operators in the :mod:`spyrit.core.meas` submodule allow for simulating noisy measurements 132 | # 133 | # .. math:: 134 | # y =\mathcal{N}\left(Ax\right), 135 | # 136 | # where :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` represents a noise operator (e.g., Gaussian). By default, no noise is applied to the measurement, i.e., :math:`\mathcal{N}` is the identity. We can consider noise by setting the :attr:`noise_model` attribute of the :class:`spyrit.core.meas.LinearSplit` class. 137 | 138 | ##################################################################### 139 | # For instance, we can consider additive Gaussian noise with standard deviation 2. 140 | 141 | from spyrit.core.noise import Gaussian 142 | 143 | meas_op.noise_model = Gaussian(2) 144 | 145 | ##################################################################### 146 | # .. note:: 147 | # To learn more about noise models, please refer to :ref:`tutorial 2 `. 
148 | 149 | ##################################################################### 150 | # We simulate the noisy measurement vectors 151 | y_noise = meas_op(x) 152 | 153 | ##################################################################### 154 | # Noiseless measurements can be simulated using the :meth:`spyrit.core.LinearSplit.measure` method. 155 | y_nonoise = meas_op.measure(x) 156 | 157 | ##################################################################### 158 | # The :meth:`spyrit.core.LinearSplit.measure_H` method simulates noiseless measurements using the matrix H, i.e., :math:`m = Hx`. 159 | m_nonoise = meas_op.measure_H(x) 160 | 161 | ##################################################################### 162 | # We now plot the noisy and noiseless measurements 163 | f, axs = plt.subplots(1, 3, figsize=(8, 5)) 164 | axs[0].set_title("Split measurements y \n with noise") 165 | im = axs[0].imshow(y_noise.mT, cmap="gray") 166 | add_colorbar(im) 167 | 168 | axs[1].set_title("Split measurements y \n without noise") 169 | im = axs[1].imshow(y_nonoise.mT, cmap="gray") 170 | add_colorbar(im) 171 | 172 | axs[2].set_title("Measurements m \n without noise") 173 | im = axs[2].imshow(m_nonoise.mT, cmap="gray") 174 | add_colorbar(im) 175 | 176 | noaxis(axs) 177 | -------------------------------------------------------------------------------- /spyrit/dev/prep.py: -------------------------------------------------------------------------------- 1 | # ================================================================================== 2 | class Preprocess_shift_poisson(nn.Module): # header needs to be updated! 
3 | # ================================================================================== 4 | r"""Preprocess the measurements acquired using shifted patterns corrupted 5 | by Poisson noise 6 | 7 | Computes: 8 | m = (2 m_shift - m_offset)/N_0 9 | var = 4*Diag(m_shift + m_offset)/alpha**2 10 | Warning: dark measurement is assumed to be the 0-th entry of raw measurements 11 | 12 | Args: 13 | - :math:`alpha`: noise level 14 | - :math:`M`: number of measurements 15 | - :math:`N`: number of image pixels 16 | 17 | Shape: 18 | - Input1: scalar 19 | - Input2: scalar 20 | - Input3: scalar 21 | 22 | Example: 23 | >>> PSP = Preprocess_shift_poisson(9, 400, 32*32) 24 | """ 25 | 26 | def __init__(self, alpha, M, N): 27 | super().__init__() 28 | self.alpha = alpha 29 | self.N = N 30 | self.M = M 31 | 32 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor: 33 | r""" 34 | 35 | Warning: 36 | - The offset measurement is the 0-th entry of the raw measurements. 37 | 38 | Args: 39 | - :math:`x`: Batch of images in Hadamard domain shifted by 1 40 | - :math:`meas_op`: Forward_operator 41 | 42 | Shape: 43 | - Input: :math:`(b*c, M+1)` 44 | - Output: :math:`(b*c, M)` 45 | 46 | Example: 47 | >>> Hsub = np.array(np.random.random([400,32*32])) 48 | >>> FO = Forward_operator(Hsub) 49 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float) 50 | >>> y_PSP = PSP(x, FO) 51 | >>> print(y_PSP.shape) 52 | torch.Size([10, 400]) 53 | 54 | """ 55 | y = self.offset(x) 56 | x = 2 * x[:, 1:] - y.expand( 57 | x.shape[0], self.M 58 | ) # Warning: dark measurement is the 0-th entry 59 | x = x / self.alpha 60 | x = 2 * x - meas_op.H( 61 | torch.ones(x.shape[0], self.N).to(x.device) 62 | ) # to shift images in [-1,1]^N 63 | return x 64 | 65 | def sigma(self, x): 66 | r""" 67 | Args: 68 | - :math:`x`: Batch of images in Hadamard domain shifted by 1 69 | 70 | Shape: 71 | - Input: :math:`(b*c, M+1)` 72 | 73 | Example: 74 | >>> x = torch.tensor(np.random.random([10, 400+1]), 
dtype=torch.float) 75 | >>> sigma_PSP = PSP.sigma(x) 76 | >>> print(sigma_PSP.shape) 77 | torch.Size([10, 400]) 78 | """ 79 | # input x is a set of measurement vectors with shape (b*c, M+1) 80 | # output is a set of measurement vectors with shape (b*c,M) 81 | y = self.offset(x) 82 | x = 4 * x[:, 1:] + y.expand(x.shape[0], self.M) 83 | x = x / (self.alpha**2) 84 | x = 4 * x # to shift images in [-1,1]^N 85 | return x 86 | 87 | def cov(self, x): # return a full matrix ? It is such that Diag(a) + b 88 | return x 89 | 90 | def sigma_from_image(self, x, meas_op): # should check this! 91 | # input x is a set of images with shape (b*c, N) 92 | # input meas_op is a Forward_operator 93 | x = meas_op.H(x) 94 | y = self.offset(x) 95 | x = x[:, 1:] + y.expand(x.shape[0], self.M) 96 | x = x / (self.alpha) # here the alpha contribution is not squared. 97 | return x 98 | 99 | def offset(self, x): 100 | r"""Get offset component from bach of shifted images. 101 | 102 | Args: 103 | - :math:`x`: Batch of shifted images 104 | 105 | Shape: 106 | - Input: :math:`(bc, M+1)` 107 | - Output: :math:`(bc, 1)` 108 | 109 | Example: 110 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float) 111 | >>> y = PSP.offset(x) 112 | >>> print(y.shape) 113 | torch.Size([10, 1]) 114 | 115 | """ 116 | y = x[:, 0, None] 117 | return y 118 | 119 | 120 | # ================================================================================== 121 | class Preprocess_pos_poisson(nn.Module): # header needs to be updated! 122 | # ================================================================================== 123 | r"""Preprocess the measurements acquired using positive (shifted) patterns 124 | corrupted by Poisson noise 125 | 126 | The output value of the layer with input size :math:`(B*C, M)` can be 127 | described as: 128 | 129 | .. 
math:: 130 | \text{out}((B*C)_i, M_j}) = 2*\text{input}((B*C)_i, M_j}) - 131 | \sum_{k = 1}^{M-1} \text{input}((B*C)_i, M_k}) 132 | 133 | The output size of the layer is :math:`(B*C, M)`, which is the imput size 134 | 135 | 136 | Warning: 137 | dark measurement is assumed to be the 0-th entry of raw measurements 138 | 139 | Args: 140 | - :math:`alpha`: noise level 141 | - :math:`M`: number of measurements 142 | - :math:`N`: number of image pixels 143 | 144 | Shape: 145 | - Input1: scalar 146 | - Input2: scalar 147 | - Input3: scalar 148 | 149 | Example: 150 | >>> PPP = Preprocess_pos_poisson(9, 400, 32*32) 151 | 152 | """ 153 | 154 | def __init__(self, alpha, M, N): 155 | super().__init__() 156 | self.alpha = alpha 157 | self.N = N 158 | self.M = M 159 | 160 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor: 161 | r""" 162 | Args: 163 | - :math:`x`: noise level 164 | - :math:`meas_op`: Forward_operator 165 | 166 | Shape: 167 | - Input1: :math:`(bc, M)` 168 | - Input2: None 169 | - Output: :math:`(bc, M)` 170 | 171 | Example: 172 | >>> Hsub = np.array(np.random.random([400,32*32])) 173 | >>> meas_op = Forward_operator(Hsub) 174 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float) 175 | >>> y = PPP(x, meas_op) 176 | torch.Size([10, 400]) 177 | 178 | """ 179 | y = self.offset(x) 180 | x = 2 * x - y.expand(-1, self.M) 181 | x = x / self.alpha 182 | x = 2 * x - meas_op.H( 183 | torch.ones(x.shape[0], self.N).to(x.device) 184 | ) # to shift images in [-1,1]^N 185 | return x 186 | 187 | def offset(self, x): 188 | r"""Get offset component from bach of shifted images. 
189 | 190 | Args: 191 | - :math:`x`: Batch of shifted images 192 | 193 | Shape: 194 | - Input: :math:`(bc, M)` 195 | - Output: :math:`(bc, 1)` 196 | 197 | Example: 198 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float) 199 | >>> y = PPP.offset(x) 200 | >>> print(y.shape) 201 | torch.Size([10, 1]) 202 | 203 | """ 204 | y = 2 / (self.M - 2) * x[:, 1:].sum(dim=1, keepdim=True) 205 | return y 206 | -------------------------------------------------------------------------------- /spyrit/dev/meas.py: -------------------------------------------------------------------------------- 1 | # ================================================================================== 2 | class Linear_shift(Linear): 3 | # ================================================================================== 4 | r"""Linear with shifted pattern matrix of size :math:`(M+1,N)` and :math:`Perm` matrix of size :math:`(N,N)`. 5 | 6 | Args: 7 | - Hsub: subsampled Hadamard matrix 8 | - Perm: Permuation matrix 9 | 10 | Shape: 11 | - Input1: :math:`(M, N)` 12 | - Input2: :math:`(N, N)` 13 | 14 | Example: 15 | >>> Hsub = np.array(np.random.random([400,32*32])) 16 | >>> Perm = np.array(np.random.random([32*32,32*32])) 17 | >>> FO_Shift = Linear_shift(Hsub, Perm) 18 | 19 | """ 20 | 21 | def __init__(self, Hsub, Perm): 22 | super().__init__(Hsub) 23 | 24 | # Todo: Use index rather than permutation (see misc.walsh_hadamard) 25 | self.Perm = nn.Linear(self.N, self.N, False) 26 | self.Perm.weight.data = torch.from_numpy(Perm.T) 27 | self.Perm.weight.data = self.Perm.weight.data.float() 28 | self.Perm.weight.requires_grad = False 29 | 30 | H_shift = torch.cat((torch.ones((1, self.N)), (self.Hsub.weight.data + 1) / 2)) 31 | 32 | self.H_shift = nn.Linear(self.N, self.M + 1, False) 33 | self.H_shift.weight.data = H_shift # include the all-one pattern 34 | self.H_shift.weight.data = self.H_shift.weight.data.float() # keep ? 
35 | self.H_shift.weight.requires_grad = False 36 | 37 | def forward(self, x: torch.tensor) -> torch.tensor: 38 | r"""Applies Linear transform such that :math:`y = \begin{bmatrix}{1}\\{H_{sub}}\end{bmatrix}x`. 39 | 40 | Args: 41 | :math:`x`: batch of images. 42 | 43 | Shape: 44 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image. 45 | - Output: :math:`(b*c, M+1)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M+1` the number of measurements + 1. 46 | 47 | Example: 48 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 49 | >>> y = FO_Shift(x) 50 | >>> print(y.shape) 51 | torch.Size([10, 401]) 52 | """ 53 | # input x is a set of images with shape (b*c, N) 54 | # output input is a set of measurement vector with shape (b*c, M+1) 55 | x = self.H_shift(x) 56 | return x 57 | 58 | # x_shift = super().forward(x) - x_dark.expand(x.shape[0],self.M) # (H-1/2)x 59 | 60 | 61 | # ================================================================================== 62 | class Linear_pos(Linear): 63 | # ================================================================================== 64 | r"""Linear with Permutation Matrix :math:`Perm` of size :math:`(N,N)`. 
65 | 66 | Args: 67 | - Hsub: subsampled Hadamard matrix 68 | - Perm: Permuation matrix 69 | 70 | Shape: 71 | - Input1: :math:`(M, N)` 72 | - Input2: :math:`(N, N)` 73 | 74 | Example: 75 | >>> Hsub = np.array(np.random.random([400,32*32])) 76 | >>> Perm = np.array(np.random.random([32*32,32*32])) 77 | >>> meas_op_pos = Linear_pos(Hsub, Perm) 78 | """ 79 | 80 | def __init__(self, Hsub, Perm): 81 | super().__init__(Hsub) 82 | 83 | # Todo: Use index rather than permutation (see misc.walsh_hadamard) 84 | self.Perm = nn.Linear(self.N, self.N, False) 85 | self.Perm.weight.data = torch.from_numpy(Perm.T) 86 | self.Perm.weight.data = self.Perm.weight.data.float() 87 | self.Perm.weight.requires_grad = False 88 | 89 | def forward(self, x: torch.tensor) -> torch.tensor: 90 | r"""Computes :math:`y` according to :math:`y=0.5(H_{sub}x+\sum_{j=1}^{N}x_{j})` where :math:`j` is the pixel (column) index of :math:`x`. 91 | 92 | Args: 93 | :math:`x`: Batch of images. 94 | 95 | Shape: 96 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image. 97 | - Output: :math:`(b*c, M)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M` the number of measurements. 
98 | 99 | Example: 100 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 101 | >>> y = meas_op_pos(x) 102 | >>> print(y.shape) 103 | torch.Size([100, 400]) 104 | """ 105 | # input x is a set of images with shape (b*c, N) 106 | # output is a set of measurement vectors with shape (b*c, M) 107 | 108 | # compute 1/2(H+1)x = 1/2 HX + 1/2 1x 109 | x = super().forward(x) + x.sum(dim=1, keepdim=True).expand(-1, self.M) 110 | x *= 0.5 111 | 112 | return x 113 | 114 | 115 | # ================================================================================== 116 | class Linear_shift_had(Linear_shift): 117 | # ================================================================================== 118 | r"""Linear_shift operator with inverse method. 119 | 120 | Args: 121 | - Hsub: subsampled Hadamard matrix 122 | - Perm: Permuation matrix 123 | 124 | Shape: 125 | - Input1: :math:`(M, N)` 126 | - Input2: :math:`(N, N)`. 127 | 128 | Example: 129 | >>> Hsub = np.array(np.random.random([400,32*32])) 130 | >>> Perm = np.array(np.random.random([32*32,32*32])) 131 | >>> FO_Shift_Had = Linear_shift_had(Hsub, Perm) 132 | """ 133 | 134 | def __init__(self, Hsub, Perm): 135 | super().__init__(Hsub, Perm) 136 | 137 | def inverse(self, x: torch.tensor, n: Union[None, int] = None) -> torch.tensor: 138 | r"""Inverse transform such that :math:`x = \frac{1}{N}H_{sub}y`. 139 | 140 | Args: 141 | :math:`x`: Batch of completed measurements. 142 | 143 | Shape: 144 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of measurements. 145 | - Output: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of reconstructed. pixels. 
146 | 147 | Example: 148 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 149 | >>> x_reconstruct = FO_Shift_Had.inverse(y_pad) 150 | >>> print(x_reconstruct.shape) 151 | torch.Size([10, 1024]) 152 | """ 153 | # rearrange the terms + inverse transform 154 | # maybe needs to be initialised with a permutation matrix as well! 155 | # Permutation matrix may be sparsified when sparse tensors are no longer in 156 | # beta (as of pytorch 1.11, it is still in beta). 157 | 158 | # --> Use index rather than permutation (see misc.walsh_hadamard) 159 | 160 | # input x is a set of **measurements** with shape (b*c, N) 161 | # output is a set of **images** with shape (b*c, N) 162 | bc, N = x.shape 163 | x = self.Perm(x) 164 | 165 | if n is None: 166 | n = int(np.sqrt(N)) 167 | 168 | # Inverse transform 169 | x = x.reshape(bc, 1, n, n) 170 | x = ( 171 | 1 / self.N * walsh2_torch(x) 172 | ) # todo: initialize with 1D transform to speed up 173 | x = x.reshape(bc, N) 174 | return x 175 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 
18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /tutorial/tuto_03_pseudoinverse_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 03. Pseudoinverse solution from linear measurements 3 | =================================================== 4 | .. _tuto_pseudoinverse_linear: 5 | 6 | This tutorial shows how to simulate measurements and perform image reconstruction using the :class:`spyrit.core.inverse.PseudoInverse` class of the :mod:`spyrit.core.inverse` submodule. 7 | 8 | .. image:: ../fig/tuto3_pinv.png 9 | :width: 600 10 | :align: center 11 | :alt: Reconstruction architecture sketch 12 | 13 | | 14 | """ 15 | 16 | # %% 17 | # Loads images 18 | # ----------------------------------------------------------------------------- 19 | 20 | ############################################################################### 21 | # We load a batch of images from the :attr:`/images/` folder. Using the 22 | # :func:`spyrit.misc.statistics.transform_gray_norm` function with the :attr:`normalize=False` 23 | # argument returns images with values in (0,1). 
24 | import os 25 | import torchvision 26 | import torch.nn 27 | from spyrit.misc.statistics import transform_gray_norm 28 | 29 | spyritPath = os.getcwd() 30 | imgs_path = os.path.join(spyritPath, "images/") 31 | 32 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1) 33 | transform = transform_gray_norm(img_size=64, normalize=False) 34 | 35 | # Create dataset and loader (expects class folder 'images/test/') 36 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 38 | 39 | x, _ = next(iter(dataloader)) 40 | print(f"Ground-truth images: {x.shape}") 41 | 42 | 43 | # %% 44 | # Linear measurements without noise 45 | # ----------------------------------------------------------------------------- 46 | 47 | ############################################################################### 48 | # We consider a Hadamard matrix in "2D". The matrix has a shape of (64*64, 64*64) and values in {-1, 1}. 49 | from spyrit.core.torch import walsh_matrix_2d 50 | 51 | H = walsh_matrix_2d(64) 52 | 53 | print(f"Acquisition matrix: {H.shape}", end=" ") 54 | print(rf"with values in {{{H.min()}, {H.max()}}}") 55 | 56 | ############################################################################### 57 | # We instantiate a :class:`spyrit.core.meas.Linear` operator. To indicate that the operator works in 2D, on images with shape (64, 64), we use the :attr:`meas_shape` argument. 58 | from spyrit.core.meas import Linear 59 | 60 | meas_op = Linear(H, (64, 64)) 61 | 62 | ############################################################################### 63 | # We simulate the measurement vectors, which have a shape of (7, 1, 4096). 64 | y = meas_op(x) 65 | 66 | print(f"Measurement vectors: {y.shape}") 67 | 68 | ############################################################################### 69 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64). 
70 | from spyrit.core.inverse import PseudoInverse 71 | 72 | pinv = PseudoInverse(meas_op) 73 | x_rec = pinv(y) 74 | 75 | print(f"Reconstructed images: {x_rec.shape}") 76 | 77 | ############################################################################### 78 | # We plot the reconstruction of the second image in the batch 79 | from spyrit.misc.disp import imagesc, add_colorbar 80 | 81 | imagesc(x_rec[1, 0]) 82 | 83 | ############################################################################### 84 | # .. note:: 85 | # The measurement operator is chosen as a Hadamard matrix but any other matrix can be used as well. 86 | 87 | # %% 88 | # LinearSplit measurements with Gaussian noise 89 | # ----------------------------------------------------------------------------- 90 | 91 | ############################################################################### 92 | # We consider a linear operator where the positive and negative components are split, i.e. acquired separately. To do so, we instantiate a :class:`spyrit.core.meas.LinearSplit` operator. 93 | from spyrit.core.meas import LinearSplit 94 | 95 | meas_op = LinearSplit(H, (64, 64)) 96 | 97 | ############################################################################### 98 | # We consider additive Gaussian noise with standard deviation 2. 99 | from spyrit.core.noise import Gaussian 100 | 101 | meas_op.noise_model = Gaussian(2) 102 | 103 | ############################################################################### 104 | # We simulate the measurement vectors, which have shape (7, 1, 8192). 105 | y = meas_op(x) 106 | 107 | print(f"Measurement vectors: {y.shape}") 108 | 109 | ############################################################################### 110 | # We preprocess measurement vectors by computing the difference of the positive and negative components of the measurement vectors. To do so, we use the :class:`spyrit.core.prep.Unsplit` class. The preprocessed measurements have a shape of (7, 1, 4096). 
111 | 112 | from spyrit.core.prep import Unsplit 113 | 114 | prep = Unsplit() 115 | m = prep(y) 116 | 117 | print(f"Preprocessed measurement vectors: {m.shape}") 118 | 119 | ############################################################################### 120 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64). 121 | from spyrit.core.inverse import PseudoInverse 122 | 123 | pinv = PseudoInverse(meas_op) 124 | x_rec = pinv(m) 125 | 126 | print(f"Reconstructed images: {x_rec.shape}") 127 | 128 | ############################################################################### 129 | # We plot the reconstruction 130 | from spyrit.misc.disp import imagesc 131 | 132 | imagesc(x_rec[1, 0]) 133 | 134 | # %% 135 | # HadamSplit2d with x4 subsampling with Poisson noise 136 | # ----------------------------------------------------------------------------- 137 | 138 | ############################################################################### 139 | # We consider the acquisition of the 2D Hadamard transform of an image, where the positive and negative components of the acquisition matrix are acquired separately. To do so, we use the dedicated :class:`spyrit.core.meas.HadamSplit2d` operator. It also allows for subsampling the rows of the Hadamard matrix, using a sampling map. 140 | 141 | from spyrit.core.meas import HadamSplit2d 142 | 143 | # Sampling map with ones in the top left corner and zeros elsewhere (low-frequency subsampling) 144 | sampling_map = torch.ones((64, 64)) 145 | sampling_map[:, 64 // 2 :] = 0 146 | sampling_map[64 // 2 :, :] = 0 147 | 148 | # Linear operator with HadamSplit2d 149 | meas_op = HadamSplit2d(64, 64**2 // 4, order=sampling_map, reshape_output=True) 150 | 151 | ############################################################################### 152 | # We consider Poisson noise with an intensity of 100 photons. 
153 | from spyrit.core.noise import Poisson 154 | 155 | meas_op.noise_model = Poisson(100) 156 | 157 | 158 | ############################################################################### 159 | # We simulate the measurement vectors, which have a shape of (7, 1, 2048) 160 | 161 | ############################################################################### 162 | # .. note:: 163 | # The :class:`spyrit.core.noise.Poisson` noise class assumes that the images are in the range [0, 1] 164 | y = meas_op(x) 165 | 166 | print(rf"Reference images with values in {{{x.min()}, {x.max()}}}") 167 | print(f"Measurement vectors: {y.shape}") 168 | 169 | ############################################################################### 170 | # We preprocess measurement vectors by i) computing the difference of the positive and negative components, and ii) normalizing the intensity. To do so, we use the :class:`spyrit.core.prep.UnsplitRescale` class. The preprocessed measurements have a shape of (7, 1, 1024). 171 | 172 | from spyrit.core.prep import UnsplitRescale 173 | 174 | prep = UnsplitRescale(100) 175 | 176 | m = prep(y)  # (y+ - y-)/alpha 177 | print(f"Preprocessed measurement vectors: {m.shape}") 178 | 179 | ############################################################################### 180 | # We compute the pseudo inverse solution, which has a shape of (7, 1, 64, 64). 181 | 182 | x_rec = meas_op.fast_pinv(m) 183 | 184 | print(f"Reconstructed images: {x_rec.shape}") 185 | 186 | ############################################################################### 187 | # .. note:: 188 | # There is no need to use the :class:`spyrit.core.inverse.PseudoInverse` class here, as the :class:`spyrit.core.meas.HadamSplit2d` class includes a method that returns the pseudo inverse solution. 
189 | 190 | ############################################################################### 191 | # We plot the reconstruction 192 | 193 | imagesc(x_rec[1, 0]) 194 | # sphinx_gallery_thumbnail_number = 3 195 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ### Notations 4 | 5 | \- removals 6 | / changes 7 | \+ additions 8 | 9 | --- 10 | 11 |
12 | 13 | 14 | ## Changes to come in a future version 15 | 16 | 17 | 18 | 19 |
20 | 21 | --- 22 | 23 |
24 | 25 | ## v2.3.4 26 | 27 | 28 | ### spyrit.core 29 | * #### General changes 30 | * The input and output shapes have been standardized across operators. All still images (i.e. not videos) have shape `(*, h, w)`, where `*` is any batch dimension (e.g. batch size and number of channels), and `h` and `w` are the height and width of the image. All measurements have shape `(*, M)`, where `*` is the same batch dimension than the images they come from. Videos have shape `(*, t, c, h, w)` where `t` is the time dimension, representing the number of frames in the video, `c` is the number of channels. Dynamic measurements from videos will thus have shape `(*, c, M)`. 31 | * The overall use of gpu has been improved. Every class of the `core` module now has a method `self.device` that allows to track the device on which its parameters are. 32 | * #### spyrit.core.meas 33 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization. 34 | * / Fixed .pinv() output shape (it was transposed with some regularisation methods) 35 | * / Fixed some device errors when using cuda with .pinv() 36 | * / The measurement matrix H is now stored with the data type it is given to the constructor (it was previously converted to torch.float32 for memory reasons) 37 | * \+ added in the .pinv() method a diff parameter enabling differentiated reconstructions (subtracting negative patterns/measurements to the positive patterns/measurements), only available for dynamic operators. 38 | * / For HadamSplit, the pinv has been overwritten to use a fast Walsh-Hadamard transform, zero-padding the measurements if necessary (in the case of subsampling). The inverse() method has been deprecated and will be removed in a future release. 
39 | * #### spyrit.core.recon 40 | * \- The class core.recon.Denoise_layer is deprecated and will be removed in a future version 41 | * / The class TikhonovMeasurementPriorDiag no longer uses Denoise_layer and uses instead an internal method to handle the denoising. 42 | * #### spyrit.core.train 43 | * / load_net() uses the weights_only=True parameter in the torch.load() function. Documentation updated. 44 | * #### spyrit.core.warp 45 | * / The warping operation (forward method) now has to be performed on (b,c,h,w) input tensors, and returns (b, time, c, h, w) output tensors. 46 | * / The AffineDeformationField does not store anymore the field as an attribute, but is rather generated on the fly. This allows for more efficient memory management. 47 | * / In AffineDeformationField the image size can be changed. 48 | * \+ It is now possible to use biquintic (5th-order) warping. This uses scikit-image's (skimage) warp function, which relies on numpy arrays. 49 | 50 | ### Tutorials 51 | * Tutorial 2 integrated the change from 'L1' to 'rcond' 52 | * All Tutorials have been updated to include the above mentioned changes. 53 | 54 |
55 | 56 | --- 57 | 58 |
59 | 60 | ## v2.3.3 61 | 62 | 63 | ### spyrit.core 64 | * #### spyrit.core.meas 65 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization. 66 | * #### spyrit.core.recon 67 | * / The documentation for the class core.recon.Denoise_layer has been clarified. 68 | 69 | ### Tutorials 70 | 71 | * Tutorial 2 integrated the change from 'L1' to 'rcond' 72 | 73 |
74 | 75 | --- 76 | 77 |
78 | 79 | ## v2.3.2 80 | 81 | 82 | ### spyrit.core 83 | * #### spyrit.core.meas 84 | * / The method forward_H has been optimized for the HadamSplit class 85 | * #### spyrit.core.torch 86 | * \+ Added spyrit.core.torch.fwht that implements in Pytorch the fast Walsh-Hadamard tranform for natural and Walsh ordered tranforms. 87 | * \+ Added spyrit.core.torch.fwht_2d that implements in Pytorch the fast Walsh-Hadamard tranform in 2 dimensions for natural and Walsh ordered tranforms. 88 | 89 | ### spyrit.misc 90 | 91 | * #### spyrit.misc.statistics 92 | * / The function spyrit.misc.statistics.Cov2Var has been sped up and now supports an output shape for non-square images 93 | * #### spyrit.misc.walsh_hadamard 94 | * / The function spyrit.misc.walsh_hadamard.fwht has been significantly sped up, especially for sequency-ordered walsh-hadamard tranforms. 95 | * \- fwht_torch is now deprecated. Use spyrit.core.torch.fwht instead. 96 | * \- walsh_torch is now deprecated. Use spyrit.core.torch.fwht instead. 97 | * \- walsh2_torch is now deprecated. Use spyrit.core.torch.fwht_2d instead. 98 | * #### spyrit.misc.load_data 99 | * \+ New function download_girder that downloads files identified by their hexadecimal ID from a url server 100 | 101 | ### Tutorials 102 | 103 | * Tutorials 3, 4, 6, 7, 8 now download data from our own servers instead of using google drive and the gdown library. Dependency on gdown library will be fully removed in a future version. 104 | 105 |
106 | 107 | --- 108 | 109 |
110 | 111 | ## v2.3.1 112 | 113 | 114 | ### spyrit.core 115 | 116 | * #### spyrit.core.meas 117 | * \+ For static classes, self.set_H_pinv has been renamed to self.build_H_pinv to match with the dynamic classes. 118 | * \+ The dynamic classes now support bicubic dynamic reconstruction (spyrit.core.meas.DynamicLinear.build_h_dyn()). This uses cubic B-splines. 119 | * #### spyrit.core.train 120 | * load_net() must take the full path, **with** the extension name (xyz.pth). 121 | 122 | ### Tutorials 123 | 124 | * Tutorial 6 has been changed accordingly to the modification of spyrit.core.train.load_net(). 125 | * Tutorial 8 is now available. 126 | 127 |
128 | 129 | --- 130 | 131 |
132 | 133 | ## v2.3.0 134 | 135 | 136 |
137 | 138 | ### spyrit.core 139 | 140 | 141 | * / no longer supports numpy.array as input, must use torch.tensor 142 | * #### spyrit.core.meas 143 | * \- class LinearRowSplit (use LinearSplit instead) 144 | * \+ 3 dynamic classes: DynamicLinear, DynamicLinearSplit, DynamicHadamSplit that allow measurements over time 145 | * spyrit.core.meas.Linear 146 | * \- self.get_H() deprecated (use self.H) 147 | * \- self.H_adjoint (you might want to use self.H.T) 148 | * / constructor argument 'reg' renamed to 'rtol' 149 | * / self.H no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 150 | * / self.H_pinv no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 151 | * \+ self.__init__() has 'Ord' and 'meas_shape' optional arguments 152 | * \+ self.pinv() now supports lstsq image reconstruction if self.H_pinv is not defined 153 | * \+ self.set_H_pinv(), self.reindex() inherited from spyrit.misc.torch 154 | * \+ self.meas_shape, self.indices, self.Ord, self.H_static 155 | * spyrit.core.meas.LinearSplit 156 | * / [includes changes from Linear] 157 | * / self.P no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 158 | * spyrit.core.meas.HadamSplit 159 | * / [includes changes from LinearSplit] 160 | * \- self.__init__() does not need 'meas_shape' argument, it is taken as (h,h) 161 | * \- self.Perm (use self.reindex() instead) 162 | * #### spyrit.core.noise 163 | * spyrit.core.noise.NoNoise 164 | * \+ self.reindex() inherited from spyrit.core.meas.Linear.reindex() 165 | * #### spyrit.core.prep 166 | * \- class SplitRowPoisson (was used with LinearRowSplit) 167 | * #### spyrit.core.recon 168 | * spyrit.core.recon.PseudoInverse 169 | * / self.forward() now has **kwargs that are passed to meas_op.pinv(), useful for lstsq image reconstruction 170 | * #### \+ spyrit.core.torch 171 | contains torch-specific functions that are commonly used in spyrit.core. 
Mirrors some spyrit.misc functions that are numpy-specific 172 | * #### \+ spyrit.core.warp 173 | * \+ class AffineDeformationField 174 | warps an image using an affine transformation matrix 175 | * \+ class DeformationField 176 | warps an image using a deformation field 177 |
178 | 179 |
180 | 181 | ### spyrit.misc 182 | 183 | 184 | * #### spyrit.misc.matrix_tools 185 | * \- Permutation_Matrix() is deprecated (already defined in spyrit.misc.sampling.Permutation_Matrix()) 186 | * #### spyrit.misc.sampling 187 | * \- meas2img2() is deprecated (use meas2img() instead) 188 | * / meas2img() can now handle batch of images 189 | * \+ sort_by_significance() & reindex() to speed up permutation mattrix multiplication 190 |
191 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_08_lpgd_split_measurements.py: -------------------------------------------------------------------------------- 1 | r""" 2 | ====================================================================== 3 | 08. Learned proximal gradient descent (LPGD) for split measurements 4 | ====================================================================== 5 | .. _tuto_lpgd_split_measurements: 6 | 7 | This tutorial shows how to perform image reconstruction with unrolled Learned Proximal Gradient 8 | Descent (LPGD) for split measurements. 9 | 10 | Unfortunately, it has a large memory consumption so it cannot be run interactively. 11 | If you want to run it yourself, please remove all the "if False:" statements at 12 | the beginning of each code block. The figures displayed are the ones that would 13 | be generated if the code was run. 14 | 15 | .. figure:: ../fig/lpgd.png 16 | :width: 600 17 | :align: center 18 | :alt: Sketch of the unrolled Learned Proximal Gradient Descent 19 | 20 | """ 21 | 22 | ############################################################################### 23 | # LPGD is an unrolled method, which can be explained as a recurrent network where 24 | # each block corresponds to an unrolled iteration of the proximal gradient descent. 25 | # At each iteration, the network performs a gradient step and a denoising step. 26 | # 27 | # The update rule for the LPGD network is given by: 28 | # 29 | # .. math:: 30 | # x^{(k+1)} = \mathcal{G}_{\theta}(x^{(k)} - \gamma H^T(H(x^{(k)}-m))). 31 | # 32 | # where :math:`x^{(k)}` is the image estimate at iteration :math:`k`, 33 | # :math:`H` is the forward operator, :math:`\gamma` is the step size, 34 | # and :math:`\mathcal{G}_{\theta}` is a denoising network with 35 | # learnable parameters :math:`\theta`. 
36 | 37 | # %% 38 | # Load a batch of images 39 | # ----------------------------------------------------------------------------- 40 | # 41 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized 42 | # using the :func:`transform_gray_norm` function. 43 | 44 | # sphinx_gallery_thumbnail_path = 'fig/lpgd.png' 45 | 46 | if False: 47 | import os 48 | 49 | import torch 50 | import torchvision 51 | import matplotlib.pyplot as plt 52 | 53 | import spyrit.core.torch as spytorch 54 | from spyrit.misc.disp import imagesc 55 | from spyrit.misc.statistics import transform_gray_norm 56 | 57 | spyritPath = os.getcwd() 58 | imgs_path = os.path.join(spyritPath, "images/") 59 | 60 | ###################################################################### 61 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. 62 | 63 | h = 128 # image is resized to h x h 64 | transform = transform_gray_norm(img_size=h) 65 | 66 | ###################################################################### 67 | # Create a data loader from some dataset (images must be in the folder `images/test/`) 68 | 69 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 70 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 71 | 72 | x, _ = next(iter(dataloader)) 73 | print(f"Shape of input images: {x.shape}") 74 | 75 | ###################################################################### 76 | # Select the `i`-th image in the batch 77 | i = 1 # Image index (modify to change the image) 78 | x = x[i : i + 1, :, :, :] 79 | x = x.detach().clone() 80 | print(f"Shape of selected image: {x.shape}") 81 | b, c, h, w = x.shape 82 | 83 | ###################################################################### 84 | # Plot the selected image 85 | 86 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 87 | 88 | 
############################################################################### 89 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972abaa5a90007058950/download 90 | # :width: 600 91 | # :align: center 92 | # :alt: Ground-truth image x in [-1, 1] 93 | # %% 94 | # Forward operators for split measurements 95 | # ----------------------------------------------------------------------------- 96 | # 97 | # We consider noisy split measurements for a Hadamard operator and a simple 98 | # "rectangular subsampling" strategy 99 | # (for more details, refer to :ref:`Acquisition - split measurements `). 100 | # 101 | # 102 | # We define the measurement, noise and preprocessing operators and then 103 | # simulate a measurement vector :math:`y` corrupted by Poisson noise. As in the previous tutorial, 104 | # we simulate an accelerated acquisition by subsampling the measurement matrix 105 | # by retaining only the first rows of a Hadamard matrix. 106 | 107 | if False: 108 | import math 109 | 110 | from spyrit.core.meas import HadamSplit 111 | from spyrit.core.noise import Poisson 112 | from spyrit.core.prep import SplitPoisson 113 | 114 | # Measurement parameters 115 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels) 116 | alpha = 10.0 # number of photons 117 | 118 | # Sampling: rectangular matrix 119 | Ord_rec = torch.zeros(h, h) 120 | n_sub = math.ceil(M**0.5) 121 | Ord_rec[:n_sub, :n_sub] = 1 122 | 123 | # Measurement and noise operators 124 | meas_op = HadamSplit(M, h, Ord_rec) 125 | noise_op = Poisson(meas_op, alpha) 126 | prep_op = SplitPoisson(alpha, meas_op) 127 | 128 | print(f"Shape of image: {x.shape}") 129 | 130 | # Measurements 131 | y = noise_op(x) # a noisy measurement vector 132 | m = prep_op(y) # preprocessed measurement vector 133 | 134 | m_plot = spytorch.meas2img(m, Ord_rec) 135 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$") 136 | 137 | 
############################################################################### 138 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972bbaa5a90007058953/download 139 | # :width: 600 140 | # :align: center 141 | # :alt: Measurements m 142 | # 143 | # We define the LearnedPGD network by providing the measurement, noise and preprocessing operators, 144 | # the denoiser and other optional parameters to the class :class:`spyrit.core.recon.LearnedPGD`. 145 | # The optional parameters include the number of unrolled iterations (`iter_stop`) 146 | # and the step size decay factor (`step_decay`). 147 | # We choose Unet as the denoiser, as in previous tutorials. 148 | # For the optional parameters, we use three iterations and a step size decay 149 | # factor of 0.9, which worked well on this data (this should match the parameters 150 | # used during training). 151 | # 152 | # .. image:: ../fig/lpgd.png 153 | # :width: 600 154 | # :align: center 155 | # :alt: Sketch of the network architecture for LearnedPGD 156 | 157 | if False: 158 | from spyrit.core.nnet import Unet 159 | from spyrit.core.recon import LearnedPGD 160 | 161 | # use GPU, if available 162 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 163 | print("Using device:", device) 164 | # Define UNet denoiser 165 | denoi = Unet() 166 | # Define the LearnedPGD model 167 | lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9) 168 | 169 | ############################################################################### 170 | # Now, we download the pretrained weights and load them into the LPGD network. 171 | # Unfortunately, the pretrained weights are too heavy (2GB) to be downloaded 172 | # here. The last figure is nonetheless displayed to show the results. 
173 | 174 | if False: 175 | from spyrit.core.train import load_net 176 | from spyrit.misc.load_data import download_girder 177 | 178 | # Download parameters 179 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 180 | dataID = "67221f60f03a54733161e96c" # unique ID of the file 181 | local_folder = "./model/" 182 | data_name = "tuto8_model_lpgd_light.pth" 183 | # Download from Girder 184 | model_abs_path = download_girder(url, dataID, local_folder, data_name) 185 | 186 | # Load pretrained weights to the model 187 | load_net(model_abs_path, lpgd_net, device, strict=False) 188 | 189 | lpgd_net.eval() 190 | lpgd_net.to(device) 191 | 192 | ############################################################################### 193 | # We reconstruct by calling the reconstruct method as in previous tutorials 194 | # and display the results. 195 | 196 | if False: 197 | from spyrit.misc.disp import add_colorbar, noaxis 198 | 199 | with torch.no_grad(): 200 | z_lpgd = lpgd_net.reconstruct(y.to(device)) 201 | 202 | # Plot results 203 | f, axs = plt.subplots(2, 1, figsize=(10, 10)) 204 | 205 | im1 = axs[0].imshow(x.cpu()[0, 0, :, :], cmap="gray") 206 | axs[0].set_title("Ground-truth image", fontsize=16) 207 | noaxis(axs[0]) 208 | add_colorbar(im1, "bottom") 209 | 210 | im2 = axs[1].imshow(z_lpgd.cpu()[0, 0, :, :], cmap="gray") 211 | axs[1].set_title("LPGD", fontsize=16) 212 | noaxis(axs[1]) 213 | add_colorbar(im2, "bottom") 214 | 215 | plt.show() 216 | 217 | ############################################################################### 218 | # .. 
image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679853fbaa5a9000705894b/download 219 | # :width: 400 220 | # :align: center 221 | # :alt: Comparison of ground-truth image and LPGD reconstruction 222 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_05_recon_hadamSplit.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | if False: 4 | 5 | # %% 6 | # Split measurement operator and no noise 7 | # ----------------------------------------------------------------------------- 8 | # .. _split_measurements: 9 | 10 | ############################################################################### 11 | # .. math:: 12 | # y = P\tilde{x}= \begin{bmatrix} H_{+} \\ H_{-} \end{bmatrix} \tilde{x}. 13 | 14 | ############################################################################### 15 | # Hadamard split measurement operator is defined in the :class:`spyrit.core.meas.HadamSplit` class. 16 | # It computes linear measurements from incoming images, where :math:`P` is a 17 | # linear operator (matrix) with positive entries and :math:`\tilde{x}` is an image. 18 | # The class relies on a matrix :math:`H` with 19 | # shape :math:`(M,N)` where :math:`N` represents the number of pixels in the 20 | # image and :math:`M \le N` the number of measurements. The matrix :math:`P` 21 | # is obtained by splitting the matrix :math:`H` as :math:`H = H_{+}-H_{-}` where 22 | # :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. 23 | 24 | # %% 25 | # Measurement and noise operators 26 | # ----------------------------------------------------------------------------- 27 | 28 | ############################################################################### 29 | # We compute the measurement and noise operators and then 30 | # simulate the measurement vector :math:`y`. 
31 | 32 | ############################################################################### 33 | # We consider Poisson noise, i.e., a noisy measurement vector given by 34 | # 35 | # .. math:: 36 | # y \sim \mathcal{P}(\alpha P \tilde{x}), 37 | # 38 | # where :math:`\alpha` is a scalar value that represents the maximum image intensity 39 | # (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio. 40 | 41 | ############################################################################### 42 | # We use the :class:`spyrit.core.noise.Poisson` class, set :math:`\alpha` 43 | # to 100 photons, and simulate a noisy measurement vector for the two sampling 44 | # strategies. Subsampling is handled internally by the :class:`~spyrit.core.meas.HadamSplit` class. 45 | 46 | from spyrit.core.noise import Poisson 47 | from spyrit.core.meas import HadamSplit 48 | 49 | alpha = 100.0 # number of photons 50 | 51 | # "Naive subsampling" 52 | # Measurement and noise operators 53 | meas_nai_op = HadamSplit(M, h, Ord_naive) 54 | noise_nai_op = Poisson(meas_nai_op, alpha) 55 | 56 | # Measurement operator 57 | y_nai = noise_nai_op(x) # a noisy measurement vector 58 | 59 | # "Variance subsampling" 60 | meas_var_op = HadamSplit(M, h, Ord_variance) 61 | noise_var_op = Poisson(meas_var_op, alpha) 62 | y_var = noise_var_op(x) # a noisy measurement vector 63 | 64 | print(f"Shape of image: {x.shape}") 65 | print(f"Shape of simulated measurements y: {y_var.shape}") 66 | 67 | # %% 68 | # The preprocessing operator measurements for split measurements 69 | # ----------------------------------------------------------------------------- 70 | 71 | ############################################################################### 72 | # We compute the preprocessing operators for the three cases considered above, 73 | # using the :mod:`spyrit.core.prep` module. 
As previously introduced, 74 | # a preprocessing operator applies to the noisy measurements in order to 75 | # compensate for the scaling factors that appear in the measurement or noise operators: 76 | # 77 | # .. math:: 78 | # m = \texttt{Prep}(y), 79 | 80 | ############################################################################### 81 | # We consider the :class:`spyrit.core.prep.SplitPoisson` class that intends 82 | # to "undo" the :class:`spyrit.core.noise.Poisson` class, for split measurements, by compensating for 83 | # 84 | # * the scaling that appears when computing Poisson-corrupted measurements 85 | # 86 | # * the affine transformation to get images in [0,1] from images in [-1,1] 87 | # 88 | # For this, it computes 89 | # 90 | # .. math:: 91 | # m = \frac{2(y_+-y_-)}{\alpha} - P\mathbb{1}, 92 | # 93 | # where :math:`y_+=H_+\tilde{x}` and :math:`y_-=H_-\tilde{x}`. 94 | # This is handled internally by the :class:`spyrit.core.prep.SplitPoisson` class. 95 | 96 | ############################################################################### 97 | # We compute the preprocessing operator and the measurements vectors for 98 | # the two sampling strategies. 99 | 100 | from spyrit.core.prep import SplitPoisson 101 | 102 | # "Naive subsampling" 103 | # 104 | # Preprocessing operator 105 | prep_nai_op = SplitPoisson(alpha, meas_nai_op) 106 | 107 | # Preprocessed measurements 108 | m_nai = prep_nai_op(y_nai) 109 | 110 | # "Variance subsampling" 111 | prep_var_op = SplitPoisson(alpha, meas_var_op) 112 | m_var = prep_var_op(y_var) 113 | 114 | # %% 115 | # Noiseless measurements 116 | # ----------------------------------------------------------------------------- 117 | 118 | ############################################################################### 119 | # We consider now noiseless measurements for the "naive subsampling" strategy. 120 | # We compute the required operators and the noiseless measurement vector. 
121 | # For this we use the :class:`spyrit.core.noise.NoNoise` class, which normalizes 122 | # the input image to get an image in [0,1], as explained in 123 | # :ref:`acquisition operators tutorial `. 124 | # For the preprocessing operator, we assign the number of photons equal to one. 125 | 126 | from spyrit.core.noise import NoNoise 127 | 128 | nonoise_nai_op = NoNoise(meas_nai_op) 129 | y_nai_nonoise = nonoise_nai_op(x) # a noiseless measurement vector 130 | 131 | prep_nonoise_op = SplitPoisson(1.0, meas_nai_op) 132 | m_nai_nonoise = prep_nonoise_op(y_nai_nonoise) 133 | 134 | ############################################################################### 135 | # We can now plot the three measurement vectors 136 | 137 | # Plot the three measurement vectors 138 | m_plot = meas2img(m_nai_nonoise, Ord_naive) 139 | m_plot2 = meas2img(m_nai, Ord_naive) 140 | m_plot3 = spytorch.meas2img(m_var, Ord_variance) 141 | 142 | m_plot_max = m_plot[0, 0, :, :].max() 143 | m_plot_min = m_plot[0, 0, :, :].min() 144 | 145 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7)) 146 | im1 = ax1.imshow(m_plot[0, 0, :, :], cmap="gray") 147 | ax1.set_title("Noiseless measurements $m$ \n 'Naive' subsampling", fontsize=20) 148 | noaxis(ax1) 149 | add_colorbar(im1, "bottom", size="20%") 150 | 151 | im2 = ax2.imshow(m_plot2[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max) 152 | ax2.set_title("Measurements $m$ \n 'Naive' subsampling", fontsize=20) 153 | noaxis(ax2) 154 | add_colorbar(im2, "bottom", size="20%") 155 | 156 | im3 = ax3.imshow(m_plot3[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max) 157 | ax3.set_title("Measurements $m$ \n 'Variance' subsampling", fontsize=20) 158 | noaxis(ax3) 159 | add_colorbar(im3, "bottom", size="20%") 160 | 161 | plt.show() 162 | 163 | # %% 164 | # PinvNet network 165 | # ----------------------------------------------------------------------------- 166 | 167 | 
############################################################################### 168 | # We use the :class:`spyrit.core.recon.PinvNet` class where 169 | # the pseudo inverse reconstruction is performed by a neural network 170 | 171 | from spyrit.core.recon import PinvNet 172 | 173 | pinvnet_nai_nonoise = PinvNet(nonoise_nai_op, prep_nonoise_op) 174 | pinvnet_nai = PinvNet(noise_nai_op, prep_nai_op) 175 | pinvnet_var = PinvNet(noise_var_op, prep_var_op) 176 | 177 | # Reconstruction 178 | z_nai_nonoise = pinvnet_nai_nonoise.reconstruct(y_nai_nonoise) 179 | z_nai = pinvnet_nai.reconstruct(y_nai) 180 | z_var = pinvnet_var.reconstruct(y_var) 181 | 182 | ############################################################################### 183 | # We can now plot the three reconstructed images 184 | from spyrit.misc.disp import add_colorbar, noaxis 185 | 186 | # Plot 187 | f, axs = plt.subplots(2, 2, figsize=(10, 10)) 188 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray") 189 | axs[0, 0].set_title("Ground-truth image") 190 | noaxis(axs[0, 0]) 191 | add_colorbar(im1, "bottom") 192 | 193 | im2 = axs[0, 1].imshow(z_nai_nonoise[0, 0, :, :], cmap="gray") 194 | axs[0, 1].set_title("Reconstruction noiseless") 195 | noaxis(axs[0, 1]) 196 | add_colorbar(im2, "bottom") 197 | 198 | im3 = axs[1, 0].imshow(z_nai[0, 0, :, :], cmap="gray") 199 | axs[1, 0].set_title("Reconstruction \n 'Naive' subsampling") 200 | noaxis(axs[1, 0]) 201 | add_colorbar(im3, "bottom") 202 | 203 | im4 = axs[1, 1].imshow(z_var[0, 0, :, :], cmap="gray") 204 | axs[1, 1].set_title("Reconstruction \n 'Variance' subsampling") 205 | noaxis(axs[1, 1]) 206 | add_colorbar(im4, "bottom") 207 | 208 | plt.show() 209 | 210 | ############################################################################### 211 | # .. 
note:: 212 | # 213 | # Note that reconstructed images are pixelized when using the "naive subsampling", 214 | # while they are smoother and more similar to the ground-truth image when using the 215 | # "variance subsampling". 216 | # 217 | # Another way to further improve results is to include a nonlinear post-processing step, 218 | # which we will consider in a future tutorial. 219 | -------------------------------------------------------------------------------- /docs/source/gallery/tuto_a00_connect_deepinv.rst: -------------------------------------------------------------------------------- 1 | 2 | .. DO NOT EDIT. 3 | .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. 4 | .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: 5 | .. "gallery/tuto_a00_connect_deepinv.py" 6 | .. LINE NUMBERS ARE GIVEN BELOW. 7 | 8 | .. rst-class:: sphx-glr-example-title 9 | 10 | .. _sphx_glr_gallery_tuto_a00_connect_deepinv.py: 11 | 12 | 13 | a00. Connect to deepinverse (HadamSplit2d) 14 | ==================================================== 15 | .. _tuto_connect_deepinv: 16 | 17 | This tutorial shows how to use DeepInverse (https://github.com/deepinv/deepinv) algorithms with a HadamSplit2d linear model. It used the :class:`spyrit.core.meas.HadamSplit2d` class of the :mod:`spyrit.core.meas` submodule. 18 | 19 | 20 | .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_logolarge.png 21 | :width: 600 22 | :align: center 23 | :alt: Reconstruction architecture sketch 24 | 25 | | 26 | 27 | .. GENERATED FROM PYTHON SOURCE LINES 19-21 28 | 29 | Loads images 30 | ----------------------------------------------------------------------------- 31 | 32 | .. GENERATED FROM PYTHON SOURCE LINES 23-24 33 | 34 | We load a batch of images from the :attr:`/images/` folder with values in (0,1). 35 | 36 | .. GENERATED FROM PYTHON SOURCE LINES 24-50 37 | 38 | .. 
code-block:: Python 39 | 40 | import os 41 | import torchvision 42 | import torch.nn 43 | 44 | import matplotlib.pyplot as plt 45 | 46 | from spyrit.misc.disp import imagesc 47 | from spyrit.misc.statistics import transform_gray_norm 48 | 49 | import deepinv as dinv 50 | 51 | spyritPath = os.getcwd() 52 | imgs_path = os.path.join(spyritPath, "images/") 53 | 54 | device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" 55 | 56 | # Grayscale images of size (32, 32), no normalization to keep values in (0,1) 57 | transform = transform_gray_norm(img_size=32, normalize=False) 58 | 59 | # Create dataset and loader (expects class folder 'images/test/') 60 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 61 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 62 | 63 | x, _ = next(iter(dataloader)) 64 | print(f"Ground-truth images: {x.shape}") 65 | 66 | 67 | 68 | 69 | 70 | .. rst-class:: sphx-glr-script-out 71 | 72 | .. code-block:: none 73 | 74 | /Users/tbaudier/spyrit/deepinv/deepinv/__about__.py:8: DeprecationWarning: Implicit None on return values is deprecated and will raise KeyErrors. 75 | __license__ = metadata["License"] 76 | Ground-truth images: torch.Size([7, 1, 32, 32]) 77 | 78 | 79 | 80 | 81 | .. GENERATED FROM PYTHON SOURCE LINES 51-52 82 | 83 | We select the second image in the batch and plot it. 84 | 85 | .. GENERATED FROM PYTHON SOURCE LINES 52-56 86 | 87 | .. code-block:: Python 88 | 89 | 90 | i_plot = 1 91 | imagesc(x[i_plot, 0, :, :], r"$32\times 32$ image $X$") 92 | 93 | 94 | 95 | 96 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png 97 | :alt: $32\times 32$ image $X$ 98 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png 99 | :class: sphx-glr-single-img 100 | 101 | 102 | 103 | 104 | 105 | .. 
GENERATED FROM PYTHON SOURCE LINES 57-59 106 | 107 | Basic example 108 | ----------------------------------------------------------------------------- 109 | 110 | .. GENERATED FROM PYTHON SOURCE LINES 61-62 111 | 112 | We instantiate an HadamSplit2d object and simulate the 2D hadamard transform of the input images. Reshape output is necesary for deepinv. We also add Poisson noise. 113 | 114 | .. GENERATED FROM PYTHON SOURCE LINES 62-78 115 | 116 | .. code-block:: Python 117 | 118 | from spyrit.core.meas import HadamSplit2d 119 | import spyrit.core.noise as noise 120 | from spyrit.core.prep import UnsplitRescale 121 | 122 | meas_spyrit = HadamSplit2d(32, 512, device=device, reshape_output=True) 123 | alpha = 50 # image intensity 124 | meas_spyrit.noise_model = noise.Poisson(alpha) 125 | y = meas_spyrit(x) 126 | 127 | # preprocess 128 | prep = UnsplitRescale(alpha) 129 | m_spyrit = prep(y) 130 | 131 | print(y.shape) 132 | 133 | 134 | 135 | 136 | 137 | 138 | .. rst-class:: sphx-glr-script-out 139 | 140 | .. code-block:: none 141 | 142 | torch.Size([7, 1, 1024]) 143 | 144 | 145 | 146 | 147 | .. GENERATED FROM PYTHON SOURCE LINES 79-80 148 | 149 | The norm has to be computed to be passed to deepinv. We need to use the max singular value of the linear operator. 150 | 151 | .. GENERATED FROM PYTHON SOURCE LINES 80-84 152 | 153 | .. code-block:: Python 154 | 155 | norm = torch.linalg.norm(meas_spyrit.H, ord=2) 156 | print(norm) 157 | 158 | 159 | 160 | 161 | 162 | 163 | .. rst-class:: sphx-glr-script-out 164 | 165 | .. code-block:: none 166 | 167 | tensor(32.0000) 168 | 169 | 170 | 171 | 172 | .. GENERATED FROM PYTHON SOURCE LINES 85-87 173 | 174 | Forward operator 175 | ---------------------------------------------------------------------- 176 | 177 | .. GENERATED FROM PYTHON SOURCE LINES 89-90 178 | 179 | You can direcly give the forward operator to deepinv. You can also add noise using deepinv model or spyrit model. 180 | 181 | .. 
GENERATED FROM PYTHON SOURCE LINES 90-99 182 | 183 | .. code-block:: Python 184 | 185 | meas_deepinv = dinv.physics.LinearPhysics( 186 | lambda y: meas_spyrit.measure_H(y) / norm, 187 | A_adjoint=lambda y: meas_spyrit.unvectorize(meas_spyrit.adjoint_H(y) / norm), 188 | ) 189 | # meas_deepinv.noise_model = dinv.physics.GaussianNoise(sigma=0.01) 190 | m_deepinv = meas_deepinv(x) 191 | print("diff:", torch.linalg.norm(m_spyrit / norm - m_deepinv)) 192 | 193 | 194 | 195 | 196 | 197 | 198 | .. rst-class:: sphx-glr-script-out 199 | 200 | .. code-block:: none 201 | 202 | diff: tensor(5.6969) 203 | 204 | 205 | 206 | 207 | .. GENERATED FROM PYTHON SOURCE LINES 100-102 208 | 209 | Reconstruction with deepinverse 210 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 211 | 212 | .. GENERATED FROM PYTHON SOURCE LINES 104-105 213 | 214 | First, use the adjoint and dagger (pseudo-inverse) operators to reconstruct the image. 215 | 216 | .. GENERATED FROM PYTHON SOURCE LINES 105-112 217 | 218 | .. code-block:: Python 219 | 220 | x_adj = meas_deepinv.A_adjoint(m_spyrit / norm) 221 | imagesc(x_adj[1, 0, :, :].cpu(), "Adjoint") 222 | 223 | x_pinv = meas_deepinv.A_dagger(m_spyrit / norm) 224 | imagesc(x_pinv[1, 0, :, :].cpu(), "Pinv") 225 | 226 | 227 | 228 | 229 | 230 | .. rst-class:: sphx-glr-horizontal 231 | 232 | 233 | * 234 | 235 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png 236 | :alt: Adjoint 237 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png 238 | :class: sphx-glr-multi-img 239 | 240 | * 241 | 242 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png 243 | :alt: Pinv 244 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png 245 | :class: sphx-glr-multi-img 246 | 247 | 248 | 249 | 250 | 251 | .. GENERATED FROM PYTHON SOURCE LINES 113-114 252 | 253 | You can also use optimization-based methods from deepinv. 
Here, we use Total Variation (TV) regularization with a projected gradient descent (PGD) algorithm. You can note the use of the custom_init parameter to initialize the algorithm with the dagger operator. 254 | 255 | .. GENERATED FROM PYTHON SOURCE LINES 114-127 256 | 257 | .. code-block:: Python 258 | 259 | model_tv = dinv.optim.optim_builder( 260 | iteration="PGD", 261 | prior=dinv.optim.TVPrior(), 262 | data_fidelity=dinv.optim.L2(), 263 | params_algo={"stepsize": 1, "lambda": 5e-2}, 264 | max_iter=10, 265 | custom_init=lambda y, Physics: {"est": (Physics.A_dagger(y),)}, 266 | ) 267 | 268 | x_tv, metrics_TV = model_tv(m_spyrit / norm, meas_deepinv, compute_metrics=True, x_gt=x) 269 | dinv.utils.plot_curves(metrics_TV) 270 | imagesc(x_tv[1, 0, :, :].cpu(), "TV recon") 271 | 272 | 273 | 274 | 275 | .. rst-class:: sphx-glr-horizontal 276 | 277 | 278 | * 279 | 280 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png 281 | :alt: PSNR, F, residual 282 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png 283 | :class: sphx-glr-multi-img 284 | 285 | * 286 | 287 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png 288 | :alt: TV recon 289 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png 290 | :class: sphx-glr-multi-img 291 | 292 | 293 | 294 | 295 | 296 | .. GENERATED FROM PYTHON SOURCE LINES 128-129 297 | 298 | Deep Plug and Play (DPIR) algorithm can also be used with a pretrained denoiser. Here, we use the DRUNet denoiser. 299 | 300 | .. GENERATED FROM PYTHON SOURCE LINES 129-136 301 | 302 | .. 
code-block:: Python 303 | 304 | denoiser = dinv.models.DRUNet(in_channels=1, out_channels=1, device=device) 305 | model_dpir = dinv.optim.DPIR(sigma=1e-1, device=device, denoiser=denoiser) 306 | model_dpir.custom_init = lambda y, Physics: {"est": (Physics.A_dagger(y),)} 307 | with torch.no_grad(): 308 | x_dpir = model_dpir(m_spyrit / norm, meas_deepinv) 309 | imagesc(x_dpir[1, 0, :, :].cpu(), "DIPR recon") 310 | 311 | 312 | 313 | 314 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png 315 | :alt: DIPR recon 316 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png 317 | :class: sphx-glr-single-img 318 | 319 | 320 | 321 | 322 | 323 | .. GENERATED FROM PYTHON SOURCE LINES 137-138 324 | 325 | Reconstruct Anything Model (RAM) can also be used. 326 | 327 | .. GENERATED FROM PYTHON SOURCE LINES 138-143 328 | 329 | .. code-block:: Python 330 | 331 | model_ram = dinv.models.RAM(pretrained=True, device=device) 332 | model_ram.sigma_threshold = 1e-1 333 | with torch.no_grad(): 334 | x_ram = model_ram(m_spyrit / norm, meas_deepinv) 335 | imagesc(x_ram[1, 0, :, :].cpu(), "RAM recon") 336 | 337 | 338 | 339 | .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png 340 | :alt: RAM recon 341 | :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png 342 | :class: sphx-glr-single-img 343 | 344 | 345 | 346 | 347 | 348 | 349 | .. rst-class:: sphx-glr-timing 350 | 351 | **Total running time of the script:** (0 minutes 11.085 seconds) 352 | -------------------------------------------------------------------------------- /tutorial/tuto_04_pseudoinverse_cnn_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 04.a. Pseudoinverse + CNN (reconstruction) 3 | ========================================== 4 | .. 
_tuto_04_pseudoinverse_cnn_linear: 5 | 6 | This tutorial shows how to simulate measurements and perform image reconstruction using the :class:`spyrit.core.recon.PinvNet` class of the :mod:`spyrit.core.recon` submodule. 7 | 8 | .. image:: ../fig/tuto4_pinvnet.png 9 | :width: 600 10 | :align: center 11 | :alt: Reconstruction architecture sketch 12 | 13 | | 14 | """ 15 | 16 | # %% 17 | # Load a batch of images 18 | # ----------------------------------------------------------------------------- 19 | 20 | ############################################################################### 21 | # We load a batch of images from the :attr:`/images/` folder. Using the 22 | # :func:`spyrit.misc.statistics.transform_gray_norm` function with the :attr:`normalize=False` 23 | # argument returns images with values in (0,1). 24 | import os 25 | import torchvision 26 | import torch.nn 27 | from spyrit.misc.statistics import transform_gray_norm 28 | 29 | spyritPath = os.getcwd() 30 | imgs_path = os.path.join(spyritPath, "images/") 31 | 32 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1) 33 | transform = transform_gray_norm(img_size=64, normalize=False) 34 | 35 | # Create dataset and loader (expects class folder 'images/test/') 36 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 38 | 39 | x, _ = next(iter(dataloader)) 40 | print(f"Ground-truth images: {x.shape}") 41 | 42 | ############################################################################### 43 | # We plot the second image in the batch 44 | from spyrit.misc.disp import imagesc 45 | 46 | imagesc(x[1, 0, :, :], "x[1, 0, :, :]") 47 | 48 | # %% 49 | # Linear measurements (no noise) 50 | # ----------------------------------------------------------------------------- 51 | 52 | ############################################################################### 53 | # We choose the acquisition matrix as the 
positive component of a Hadamard 54 | # matrix in "2D". This is a (0,1) matrix with shape of (64*64, 64*64). 55 | from spyrit.core.torch import walsh_matrix_2d 56 | 57 | H = walsh_matrix_2d(64) 58 | H = torch.where(H > 0, 1.0, 0.0) 59 | 60 | print(f"Acquisition matrix: {H.shape}", end=" ") 61 | print(rf"with values in {{{H.min()}, {H.max()}}}") 62 | 63 | ############################################################################### 64 | # We subsample the measurement operator by a factor four, keeping only the 65 | # low-frequency components 66 | 67 | Sampling_square = torch.zeros(64, 64) 68 | Sampling_square[:32, :32] = 1 69 | 70 | imagesc(Sampling_square, "Sampling map") 71 | 72 | ############################################################################### 73 | # We use spyrit.core.torch.sort_by_significance() to permutate the rows of H. 74 | # Then, we keep the first 1024 rows. 75 | 76 | from spyrit.core.torch import sort_by_significance 77 | 78 | H = sort_by_significance(H, Sampling_square, "rows", False) 79 | H = H[: 32 * 32, :] 80 | 81 | print(f"Shape of the measurement matrix: {H.shape}") 82 | 83 | ############################################################################### 84 | # We instantiate a :class:`spyrit.core.meas.Linear` operator. To indicate that 85 | # the operator works in 2D, on images with shape (64, 64), we use the 86 | # :attr:`meas_shape` argument. 87 | 88 | from spyrit.core.meas import Linear 89 | 90 | meas_op = Linear(H, (64, 64)) 91 | 92 | ############################################################################### 93 | # We simulate the measurement vectors, which has a shape of (7, 1, 1024). 
94 | y = meas_op(x) 95 | 96 | print(f"Measurement vectors: {y.shape}") 97 | 98 | ############################################################################### 99 | # To display the subsampled measurement vector as an image in the transformed 100 | # domain, we use the :func:`spyrit.core.torch.meas2img` function 101 | 102 | # plot 103 | from spyrit.core.torch import meas2img 104 | 105 | m_plot = meas2img(y, Sampling_square) 106 | print(f"Shape of the preprocessed measurement image: {m_plot.shape}") 107 | 108 | imagesc(m_plot[0, 0, :, :], "Measurements (reshaped)") 109 | 110 | 111 | # %% 112 | # Pseudo inverse solution with PinvNet 113 | # ----------------------------------------------------------------------------- 114 | 115 | ############################################################################### 116 | # The :class:`spyrit.core.recon.PinvNet` class reconstructs an 117 | # image by computing the pseudoinverse solution. By default, the 118 | # torch.linalg.lstsq solver is used 119 | 120 | from spyrit.core.recon import PinvNet 121 | 122 | pinv_net = PinvNet(meas_op) 123 | 124 | ############################################################################### 125 | # We use the :func:`~spyrit.core.recon.PinvNet.reconstruct` method to 126 | # reconstruct the images from the measurement vectors :attr:`y` 127 | 128 | x_rec = pinv_net.reconstruct(y) 129 | 130 | imagesc(x_rec[1, 0, :, :], "Pseudo Inverse") 131 | 132 | ############################################################################### 133 | # Alternatively, the pseudo-inverse of the acquisition matrix is computed and 134 | # stored. This option becomes efficient when a large number of reconstructions 135 | # are performed (e.g., during training). To do so, we set 'store_H_pinv' 136 | # to 'True'.
137 | 138 | pinv_net_2 = PinvNet(meas_op, store_H_pinv=True) 139 | x_rec_2 = pinv_net.reconstruct(y) 140 | 141 | imagesc(x_rec_2[1, 0, :, :], "Pseudo Inverse") 142 | 143 | ############################################################################### 144 | # Contrary to pinv_net, pinv_net_2 stores the pseudo inverse matrix with shape 145 | # (4096,1024) 146 | print(f"pinv_net: {hasattr(pinv_net.pinv, 'pinv')}") 147 | print(f"pinv_net_2: {hasattr(pinv_net_2.pinv, 'pinv')}") 148 | print(f"Shape: {pinv_net_2.pinv.pinv.shape}") 149 | 150 | # %% 151 | # CNN post processing with PinvNet 152 | # ----------------------------------------------------------------------------- 153 | 154 | ############################################################################### 155 | # Reconstruction artefacts can be removed by post processing the pseudo inverse 156 | # solution using a denoising neural network. 157 | # In the following, we select a 158 | # small CNN using the :class:`spyrit.core.nnet.ConvNet` class, but it can be 159 | # replaced by any other neural network (e.g., a UNet 160 | # from :class:`spyrit.core.nnet.Unet`). 161 | 162 | ############################################################################### 163 | # We download a ConvNet that has been trained using STL-10 dataset. 164 | 165 | from spyrit.misc.load_data import download_girder 166 | 167 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 168 | dataID = "68639a2af39e1d2884b09abf" # unique ID of the file 169 | model_folder = "./model/" 170 | 171 | model_cnn_path = download_girder(url, dataID, model_folder) 172 | 173 | ############################################################################### 174 | # The CNN should be placed in an ordered dictionary and passed to a 175 | # :class:`nn.Sequential`. 
176 | 177 | from typing import OrderedDict 178 | from spyrit.core.nnet import ConvNet 179 | 180 | denoiser = torch.nn.Sequential(OrderedDict({"denoi": ConvNet()})) 181 | 182 | ############################################################################### 183 | # We load the denoiser and send it to GPU, if available. 184 | 185 | from spyrit.core.train import load_net 186 | 187 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 188 | load_net(model_cnn_path, denoiser, device, False) 189 | 190 | 191 | ############################################################################### 192 | # We create a PinvNet with a postprocessing denoising step 193 | 194 | pinv_net = PinvNet(meas_op, denoi=denoiser, device=device) 195 | 196 | ############################################################################### 197 | # We reconstruct the image using PinvNet 198 | 199 | pinv_net.eval() 200 | y = y.to(device) 201 | 202 | with torch.no_grad(): 203 | x_rec_cnn = pinv_net.reconstruct(y) 204 | 205 | ############################################################################### 206 | # We finally plot the results 207 | 208 | import matplotlib.pyplot as plt 209 | from spyrit.misc.disp import add_colorbar, noaxis 210 | 211 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) 212 | 213 | im1 = ax1.imshow(x[1, 0, :, :], cmap="gray") 214 | ax1.set_title("Ground-truth", fontsize=20) 215 | noaxis(ax1) 216 | add_colorbar(im1, "bottom", size="20%") 217 | 218 | im2 = ax2.imshow(x_rec[1, 0, :, :].cpu(), cmap="gray") 219 | ax2.set_title("Pinv", fontsize=20) 220 | noaxis(ax2) 221 | add_colorbar(im2, "bottom", size="20%") 222 | 223 | im3 = ax3.imshow(x_rec_cnn.cpu()[1, 0, :, :], cmap="gray") 224 | ax3.set_title("Pinv + CNN", fontsize=20) 225 | noaxis(ax3) 226 | add_colorbar(im3, "bottom", size="20%") 227 | 228 | plt.show() 229 | 230 | ############################################################################### 231 | # We show the best result again (tutorial
thumbnail purpose) 232 | # sphinx_gallery_thumbnail_number = 7 233 | 234 | imagesc(x_rec_cnn.cpu()[1, 0, :, :], "Pinv + CNN", title_fontsize=20) 235 | 236 | ############################################################################### 237 | # .. note:: 238 | # 239 | # In the :ref:`next tutorial `, we will 240 | # show how to train PinvNet + CNN denoiser. 241 | 242 | # %% 243 | # Compatibility between spyrit 2 and spyrit 3 244 | # ----------------------------------------------------------------------------- 245 | 246 | ######################################################################### 247 | # SPyRiT 2.4 trains neural networks for images with values in the 248 | # range (-1,1), while SPyRiT 3 assumes images with values in the range (0,1). 249 | # This can be compensated for using :class:`spyrit.core.prep.Rerange`. 250 | 251 | from spyrit.core.prep import Rerange 252 | 253 | rerange = Rerange((0, 1), (-1, 1)) 254 | denoiser = OrderedDict( 255 | {"rerange": rerange, "denoi": ConvNet(), "rerange_inv": rerange.inverse()} 256 | ) 257 | denoiser = torch.nn.Sequential(denoiser) 258 | 259 | 260 | ############################################################################### 261 | # We load a spyrit 2.4 denoiser and show the reconstruction 262 | 263 | dataID = "67221889f03a54733161e963" # unique ID of the file 264 | model_cnn_path = download_girder(url, dataID, model_folder) 265 | load_net(model_cnn_path, denoiser, device, False) 266 | 267 | pinv_net = PinvNet(meas_op, denoi=denoiser, device=device) 268 | 269 | with torch.no_grad(): 270 | x_rec_cnn = pinv_net.reconstruct(y) 271 | 272 | imagesc(x_rec_cnn.cpu()[1, 0, :, :], "Pinv + CNN (v2.4)", title_fontsize=20) 273 | -------------------------------------------------------------------------------- /tutorial/tuto_04_b_train_pseudoinverse_cnn_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 04.b. 
Pseudoinverse + CNN (training) 3 | ================================================ 4 | .. _tuto_4b_train_pseudoinverse_cnn_linear: 5 | 6 | This tutorial trains a post processing CNN used by a 7 | :class:`spyrit.core.recon.PinvNet` (see the 8 | :ref:`previous tutorial `). 9 | 10 | .. image:: ../fig/tuto4_pinvnet.png 11 | :width: 600 12 | :align: center 13 | :alt: Reconstruction architecture sketch 14 | 15 | | 16 | 17 | For post-processing, we consider a small CNN; however, it can be replaced by any other network (e.g., a Unet). Training is performed on the STL-10 dataset, but any other database can be considered. 18 | 19 | You can use Tensorboard for Pytorch for experiment tracking and 20 | for visualizing the training process: losses, network weights, 21 | and intermediate results (reconstructed images at different epochs). 22 | """ 23 | 24 | # %% 25 | # Measurement operator 26 | # ----------------------------------------------------------------------------- 27 | 28 | ############################################################################### 29 | # We choose the acquisition matrix as the positive component of a Hadamard 30 | # matrix in "2D". We subsample it by a factor four, keeping only the 31 | # low-frequency components (see :ref:`Tutorial 4 ` for details). 32 | 33 | ############################################################################ 34 | # Positive component of a Hadamard matrix in "2D".
35 | import torch 36 | from spyrit.core.torch import walsh_matrix_2d 37 | 38 | H = walsh_matrix_2d(64) 39 | H = torch.where(H > 0, 1.0, 0.0) 40 | 41 | ############################################################################ 42 | # Subsampling map 43 | 44 | Sampling_square = torch.zeros(64, 64) 45 | Sampling_square[:32, :32] = 1 46 | 47 | ############################################################################ 48 | # Permutation of the rows and subsampling 49 | 50 | from spyrit.core.torch import sort_by_significance 51 | 52 | H = sort_by_significance(H, Sampling_square, "rows", False) 53 | H = H[: 32 * 32, :] 54 | 55 | ############################################################################### 56 | # Associated :class:`spyrit.core.meas.Linear` operator 57 | 58 | from spyrit.core.meas import Linear 59 | 60 | # Send to GPU if available 61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 62 | 63 | meas_op = Linear(H, (64, 64), device=device) 64 | 65 | ##################################################################### 66 | # .. note:: 67 | # 68 | # The linear measurement operator is chosen as the positive part of a 69 | # subsampled Hadamard matrix, but any other matrix can be used. 70 | 71 | 72 | # %% 73 | # Pseudo inverse solution followed by a CNN 74 | # ----------------------------------------------------------------------------- 75 | 76 | ############################################################################### 77 | # We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs 78 | # an image by computing the pseudoinverse solution and applies a nonlinear 79 | # network denoiser. First, we must define the denoiser. As an example, 80 | # we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class. 81 | # Then, we define the PinvNet network by passing the noise and preprocessing operators and the denoiser. 
82 | 83 | from typing import OrderedDict 84 | from spyrit.core.nnet import ConvNet 85 | 86 | denoiser = torch.nn.Sequential(OrderedDict({"denoi": ConvNet()})) 87 | 88 | 89 | ############################################################################### 90 | # .. note:: 91 | # 92 | # Here, we consider a small CNN; however, it can be replaced by any other 93 | # network (e.g., a Unet). 94 | 95 | 96 | ############################################################################### 97 | # We instantiate a :class:`spyrit.core.recon.PinvNet` with the CNN as an 98 | # image-domain post processing 99 | 100 | from spyrit.core.recon import PinvNet 101 | 102 | pinv_net = PinvNet(meas_op, denoi=denoiser, device=device, store_H_pinv=True) 103 | 104 | 105 | ##################################################################### 106 | # .. important:: 107 | # 108 | # We use :attr:`store_H_pinv=True` to compute and store the pseudo inverse 109 | # matrix. This will be *much* faster than using a solver (default option) when a 110 | # large number of pseudoinverse solutions will have to be computed during training. 111 | 112 | # %% 113 | # Dataloader for training 114 | # ----------------------------------------------------------------------------- 115 | # We now consider the STL10 dataset and use 116 | # the :attr:`normalize=False` argument to keep images with values in (0,1). 117 | # 118 | # Set :attr:`mode_run=True` in the script below to download the STL10 119 | dataset and train the CNN. Otherwise, the CNN parameters will be downloaded.
120 | 121 | # import torch.nn 122 | from spyrit.misc.statistics import data_loaders_stl10 123 | from pathlib import Path 124 | 125 | # Parameters 126 | h = 64 # image size hxh 127 | data_root = Path("./data/") # path to data folder (where the dataset is stored) 128 | batch_size = 700 129 | 130 | # Dataloader for STL-10 dataset 131 | mode_run = False 132 | if mode_run: 133 | dataloaders = data_loaders_stl10( 134 | data_root, 135 | img_size=h, 136 | batch_size=batch_size, 137 | seed=7, 138 | shuffle=True, 139 | download=True, 140 | normalize=False, 141 | ) 142 | 143 | ############################################################################### 144 | # .. note:: 145 | # 146 | # Here, training is performed on the STL-10 dataset, but any other database 147 | # can be considered. 148 | 149 | # %% 150 | # Optimizer 151 | # ----------------------------------------------------------------------------- 152 | 153 | ############################################################################### 154 | # We define a loss function (mean squared error), an optimizer (Adam) 155 | # and a scheduler. The scheduler decreases the learning rate by a factor of 156 | # :attr:`gamma` every :attr:`step_size` epochs. 
157 | 158 | from spyrit.core.train import Weight_Decay_Loss 159 | 160 | # Parameters 161 | lr = 1e-3 162 | step_size = 10 163 | gamma = 0.5 164 | 165 | loss = torch.nn.MSELoss() 166 | criterion = Weight_Decay_Loss(loss) 167 | optimizer = torch.optim.Adam(pinv_net.parameters(), lr=lr) 168 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 169 | 170 | # %% 171 | # Training 172 | # ----------------------------------------------------------------------------- 173 | 174 | ############################################################################### 175 | # We use the :func:`spyrit.core.train.train_model` function, 176 | # which iterates through the dataloader, feeds the STL10 images to the full 177 | # network and optimizes the parameters of the CNN. In addition, it computes 178 | # the loss and desired metrics on the training and validation sets at each 179 | # iteration. The training process can be monitored using Tensorboard. 180 | 181 | 182 | ############################################################################### 183 | # Set :attr:`mode_run=True` to train the CNN (e.g., around 60 min for 20 epochs on my laptop equipped with a NVIDIA Quadro P1000). 184 | # Otherwise, download the CNN parameters. 
185 | 186 | from spyrit.core.train import train_model 187 | from datetime import datetime 188 | 189 | # Parameters 190 | model_root = Path("./model") # path to model saving files 191 | num_epochs = 20 # number of training epochs (num_epochs = 30) 192 | checkpoint_interval = 0 # interval between saving model checkpoints 193 | tb_freq = ( 194 | 50 # interval between logging to Tensorboard (iterations through the dataloader) 195 | ) 196 | 197 | # Path for Tensorboard experiment tracking logs 198 | name_run = "stl10_hadam_positive" 199 | now = datetime.now().strftime("%Y-%m-%d_%H-%M") 200 | tb_path = f"runs/runs_{name_run}_nonoise_m{meas_op.M}/{now}" 201 | 202 | # Train the network 203 | if mode_run: 204 | pinv_net, train_info = train_model( 205 | pinv_net, 206 | criterion, 207 | optimizer, 208 | scheduler, 209 | dataloaders, 210 | device, 211 | model_root, 212 | num_epochs=num_epochs, 213 | disp=True, 214 | do_checkpoint=checkpoint_interval, 215 | tb_path=tb_path, 216 | tb_freq=tb_freq, 217 | ) 218 | else: 219 | train_info = {} 220 | 221 | ############################################################################### 222 | # .. note:: 223 | # 224 | # To launch Tensorboard type in a new console: 225 | # 226 | # tensorboard --logdir runs 227 | # 228 | # and open the provided link in a browser. The training process can be monitored 229 | # in real time in the "Scalars" tab. The "Images" tab allows to visualize the 230 | # reconstructed images at different iterations :attr:`tb_freq`. 231 | 232 | # %% 233 | # Training history 234 | # ----------------------------------------------------------------------------- 235 | 236 | ############################################################################### 237 | # We save the model so that it can later be utilized. We save the network's 238 | # architecture, the training parameters and the training history. 
239 | 240 | from spyrit.core.train import save_net 241 | 242 | title = "tuto_4b" 243 | 244 | Path(model_root).mkdir(parents=True, exist_ok=True) 245 | model_path = model_root / (title + ".pth") 246 | train_path = model_root / (title + ".pkl") 247 | 248 | if checkpoint_interval: 249 | Path(model_path).mkdir(parents=True, exist_ok=True) 250 | 251 | save_net(model_path, pinv_net.denoi) 252 | # save_net(model_root/(title+"_cnn.pth"), pinv_net.denoi.denoi) 253 | 254 | # Save training history 255 | import pickle 256 | 257 | 258 | if mode_run: 259 | from spyrit.core.train import Train_par 260 | 261 | reg = 1e-7 # Default value 262 | params = Train_par(batch_size, lr, h, reg=reg) 263 | params.set_loss(train_info) 264 | 265 | train_path = model_root / (title + ".pkl") 266 | 267 | with open(train_path, "wb") as param_file: 268 | pickle.dump(params, param_file) 269 | torch.cuda.empty_cache() 270 | 271 | else: 272 | from spyrit.misc.load_data import download_girder 273 | 274 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 275 | dataID = "68639a2af39e1d2884b09abc" # unique ID of the file 276 | 277 | download_girder(url, dataID, model_root) 278 | 279 | with open(train_path, "rb") as param_file: 280 | params = pickle.load(param_file) 281 | 282 | train_info["train"] = params.train_loss 283 | train_info["val"] = params.val_loss 284 | 285 | 286 | # %% 287 | # Validation and training losses 288 | # ----------------------------------------------------------------------------- 289 | 290 | ############################################################################### 291 | # We plot the training loss and validation loss 292 | 293 | import matplotlib.pyplot as plt 294 | import numpy as np 295 | 296 | epoch = np.arange(1, num_epochs + 1) 297 | 298 | fig = plt.figure() 299 | plt.semilogy(epoch, train_info["train"], label="train") 300 | plt.semilogy(epoch, train_info["val"], label="val") 301 | plt.xticks([5, 10, 15, 20]) 302 | plt.xlabel("Epochs", fontsize=20) 303 | 
plt.ylabel("Loss", fontsize=20) 304 | plt.legend(fontsize=20) 305 | plt.show() 306 | -------------------------------------------------------------------------------- /spyrit/misc/metrics.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim import lr_scheduler 11 | import numpy as np 12 | import torchvision 13 | from torchvision import datasets, models, transforms 14 | import torch.nn.functional as F 15 | import imageio 16 | import matplotlib.pyplot as plt 17 | 18 | # import skimage.metrics as skm 19 | 20 | 21 | def batch_psnr(torch_batch, output_batch): 22 | list_psnr = [] 23 | for i in range(torch_batch.shape[0]): 24 | img = torch_batch[i, 0, :, :] 25 | img_out = output_batch[i, 0, :, :] 26 | img = img.cpu().detach().numpy() 27 | img_out = img_out.cpu().detach().numpy() 28 | list_psnr.append(psnr(img, img_out)) 29 | return list_psnr 30 | 31 | 32 | def batch_psnr_(torch_batch, output_batch, r=2): 33 | list_psnr = [] 34 | for i in range(torch_batch.shape[0]): 35 | img = torch_batch[i, 0, :, :] 36 | img_out = output_batch[i, 0, :, :] 37 | img = img.cpu().detach().numpy() 38 | img_out = img_out.cpu().detach().numpy() 39 | list_psnr.append(psnr_(img, img_out, r=r)) 40 | return list_psnr 41 | 42 | 43 | def batch_ssim(torch_batch, output_batch): 44 | list_ssim = [] 45 | for i in range(torch_batch.shape[0]): 46 | img = torch_batch[i, 0, :, :] 47 | img_out = output_batch[i, 0, :, :] 48 | img = img.cpu().detach().numpy() 49 | img_out = img_out.cpu().detach().numpy() 50 | list_ssim.append(ssim(img, img_out)) 51 | return list_ssim 52 | 
53 | 54 | def dataset_meas(dataloader, model, device): 55 | meas = [] 56 | for inputs, labels in dataloader: 57 | inputs = inputs.to(device) 58 | # with torch.no_grad(): 59 | b, c, h, w = inputs.shape 60 | net_output = model.acquire(inputs, b, c, h, w) 61 | raw = net_output[:, 0, :] 62 | raw = raw.cpu().detach().numpy() 63 | meas.extend(raw) 64 | return meas 65 | 66 | 67 | # 68 | # def dataset_psnr_different_measures(dataloader, model, model_2, device): 69 | # psnr = []; 70 | # #psnr_fc = []; 71 | # for inputs, labels in dataloader: 72 | # inputs = inputs.to(device) 73 | # m = model_2.normalized measure(inputs); 74 | # net_output = model.forward_reconstruct(inputs); 75 | # #net_output2 = model.evaluate_fcl(inputs); 76 | # 77 | # psnr += batch_psnr(inputs, net_output); 78 | # #psnr_fc += batch_psnr(inputs, net_output2); 79 | # psnr = np.array(psnr); 80 | # #psnr_fc = np.array(psnr_fc); 81 | # return psnr; 82 | # 83 | 84 | 85 | def dataset_psnr(dataloader, model, device): 86 | psnr = [] 87 | psnr_fc = [] 88 | for inputs, labels in dataloader: 89 | inputs = inputs.to(device) 90 | # with torch.no_grad(): 91 | # b,c,h,w = inputs.shape; 92 | 93 | net_output = model.evaluate(inputs) 94 | net_output2 = model.evaluate_fcl(inputs) 95 | 96 | psnr += batch_psnr(inputs, net_output) 97 | psnr_fc += batch_psnr(inputs, net_output2) 98 | psnr = np.array(psnr) 99 | psnr_fc = np.array(psnr_fc) 100 | return psnr, psnr_fc 101 | 102 | 103 | def dataset_ssim(dataloader, model, device): 104 | ssim = [] 105 | ssim_fc = [] 106 | for inputs, labels in dataloader: 107 | inputs = inputs.to(device) 108 | # evaluate full model and fully connected layer 109 | net_output = model.evaluate(inputs) 110 | net_output2 = model.evaluate_fcl(inputs) 111 | # compute SSIM and concatenate 112 | ssim += batch_ssim(inputs, net_output) 113 | ssim_fc += batch_ssim(inputs, net_output2) 114 | ssim = np.array(ssim) 115 | ssim_fc = np.array(ssim_fc) 116 | return ssim, ssim_fc 117 | 118 | 119 | def 
dataset_psnr_ssim(dataloader, model, device): 120 | # init lists 121 | psnr = [] 122 | ssim = [] 123 | # loop over batches 124 | for inputs, labels in dataloader: 125 | inputs = inputs.to(device) 126 | # evaluate full model 127 | net_output = model.evaluate(inputs) 128 | # compute PSNRs and concatenate 129 | psnr += batch_psnr(inputs, net_output) 130 | # compute SSIMs and concatenate 131 | ssim += batch_ssim(inputs, net_output) 132 | # convert 133 | psnr = np.array(psnr) 134 | ssim = np.array(ssim) 135 | return psnr, ssim 136 | 137 | 138 | def dataset_psnr_ssim_fcl(dataloader, model, device): 139 | # init lists 140 | psnr = [] 141 | ssim = [] 142 | # loop over batches 143 | for inputs, labels in dataloader: 144 | inputs = inputs.to(device) 145 | # evaluate fully connected layer 146 | net_output = model.evaluate_fcl(inputs) 147 | # compute PSNRs and concatenate 148 | psnr += batch_psnr(inputs, net_output) 149 | # compute SSIMs and concatenate 150 | ssim += batch_ssim(inputs, net_output) 151 | # convert 152 | psnr = np.array(psnr) 153 | ssim = np.array(ssim) 154 | return psnr, ssim 155 | 156 | 157 | def psnr(I1, I2): 158 | """ 159 | Computes the psnr between two images I1 and I2 160 | """ 161 | d = np.amax(I1) - np.amin(I1) 162 | diff = np.square(I2 - I1) 163 | MSE = diff.sum() / I1.size 164 | Psnr = 10 * np.log(d**2 / MSE) / np.log(10) 165 | return Psnr 166 | 167 | 168 | def psnr_(img1, img2, r=2): 169 | """ 170 | Computes the psnr between two image with values expected in a given range 171 | 172 | Args: 173 | img1, img2 (np.ndarray): images 174 | r (float): image range 175 | 176 | Returns: 177 | Psnr (float): Peak signal-to-noise ratio 178 | 179 | """ 180 | MSE = np.mean((img1 - img2) ** 2) 181 | Psnr = 10 * np.log(r**2 / MSE) / np.log(10) 182 | return Psnr 183 | 184 | 185 | def psnr_torch(img_gt, img_rec, mask=None, dim=(-2, -1), img_dyn=None): 186 | r""" 187 | Computes the Peak Signal-to-Noise Ratio (PSNR) between two images. 188 | 189 | .. 
math:: 190 | 191 | \text{PSNR} = 20 \, \log_{10} \left( \frac{\text{d}}{\sqrt{\text{MSE}}} \right), \\ 192 | \text{MSE} = \frac{1}{L}\sum_{\ell=1}^L \|I_\ell - \tilde{I}_\ell\|^2_2, 193 | 194 | where :math:`d` is the image dynamic and :math:`\{I_\ell\}` (resp. :math:`\{\tilde{I}_\ell\}`) is the set of ground truth (resp. reconstructed) images. 195 | 196 | Args: 197 | :attr:`img_gt`: Tensor containing the *ground-truth* image. 198 | 199 | :attr:`img_rec`: Tensor containing the reconstructed image. 200 | 201 | :attr:`mask`: Mask where the squared error is computed. Defaults :attr:`None`, i.e., no mask is considered. 202 | 203 | :attr:`dim`: Dimensions where the squared error is computed. If mask is :attr:`None`, defaults to :attr:`-1` (i.e., the last dimension). Othewise defaults to :attr:`(-2,-1)` (i.e., the last two dimensions). 204 | 205 | :attr:`img_dyn`: Image dynamic range (e.g., 1.0 for normalized images, 255 for 8-bit images). When :attr:`img_dyn` is :attr:`None`, the dynamic range is computed from the ground-truth image. 206 | 207 | Returns: 208 | PSNR value. 209 | 210 | .. note:: 211 | :attr:`psnr_torch(img_gt, img_rec)` is different from :attr:`psnr_torch(img_rec, img_gt)`. The first expression assumes :attr:`img_gt` is the ground truth while the second assumes that this is :attr:`img_rec`. This leads to different dynamic ranges. 212 | 213 | Example 1: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise 214 | >>> x = torch.rand(10,1,64,64) 215 | >>> n = x + 0.05*torch.randn(x.shape) 216 | >>> out = psnr_torch(x,n) 217 | >>> print(out.shape) 218 | torch.Size([10, 1]) 219 | 220 | Example 2: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise 221 | >>> psnr_torch(n,x) 222 | tensor(...) 223 | >>> psnr_torch(x,n) 224 | tensor(...) 225 | >>> psnr_torch(n,x,img_dyn=1.0) 226 | tensor(...) 
227 | 228 | """ 229 | if mask is not None: 230 | dim = -1 231 | img_gt = img_gt[mask > 0] 232 | img_rec = img_rec[mask > 0] 233 | print("mask") 234 | 235 | mse = (img_gt - img_rec) ** 2 236 | mse = torch.mean(mse, dim=dim) 237 | 238 | if img_dyn is None: 239 | img_dyn = torch.amax(img_gt, dim=dim) - torch.amin(img_gt, dim=dim) 240 | 241 | return 10 * torch.log10(img_dyn**2 / mse) 242 | 243 | 244 | def ssim(I1, I2): 245 | """ 246 | Computes the ssim between two images I1 and I2 247 | """ 248 | L = np.amax(I1) - np.amin(I1) 249 | mu1 = np.mean(I1) 250 | mu2 = np.mean(I2) 251 | s1 = np.std(I1) 252 | s2 = np.std(I2) 253 | s12 = np.mean(np.multiply((I1 - mu1), (I2 - mu2))) 254 | c1 = (0.01 * L) ** 2 255 | c2 = (0.03 * L) ** 2 256 | result = ((2 * mu1 * mu2 + c1) * (2 * s12 + c2)) / ( 257 | (mu1**2 + mu2**2 + c1) * (s1**2 + s2**2 + c2) 258 | ) 259 | return result 260 | 261 | 262 | # def ssim_sk(x_gt, x, img_dyn=None): 263 | # """ 264 | # SSIM from skimage 265 | 266 | # Args: 267 | # torch tensors 268 | 269 | # Returns: 270 | # torch tensor 271 | # """ 272 | # if not isinstance(x, np.ndarray): 273 | # x = x.cpu().detach().numpy().squeeze() 274 | # x_gt = x_gt.cpu().detach().numpy().squeeze() 275 | # ssim_val = np.zeros(x.shape[0]) 276 | # for i in range(x.shape[0]): 277 | # ssim_val[i] = skm.structural_similarity(x_gt[i], x[i], data_range=img_dyn) 278 | # return torch.tensor(ssim_val) 279 | 280 | 281 | def batch_psnr_vid(input_batch, output_batch): 282 | list_psnr = [] 283 | batch_size, seq_length, c, h, w = input_batch.shape 284 | input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w) 285 | output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w) 286 | for i in range(input_batch.shape[0]): 287 | img = input_batch[i, 0, :, :] 288 | img_out = output_batch[i, 0, :, :] 289 | img = img.cpu().detach().numpy() 290 | img_out = img_out.cpu().detach().numpy() 291 | list_psnr.append(psnr(img, img_out)) 292 | return list_psnr 293 | 294 | 295 | def 
batch_ssim_vid(input_batch, output_batch): 296 | list_ssim = [] 297 | batch_size, seq_length, c, h, w = input_batch.shape 298 | input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w) 299 | output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w) 300 | for i in range(input_batch.shape[0]): 301 | img = input_batch[i, 0, :, :] 302 | img_out = output_batch[i, 0, :, :] 303 | img = img.cpu().detach().numpy() 304 | img_out = img_out.cpu().detach().numpy() 305 | list_ssim.append(ssim(img, img_out)) 306 | return list_ssim 307 | 308 | 309 | def compare_video_nets_supervised(net_list, testloader, device): 310 | psnr = [[] for i in range(len(net_list))] 311 | ssim = [[] for i in range(len(net_list))] 312 | for batch, (inputs, labels) in enumerate(testloader): 313 | [batch_size, seq_length, c, h, w] = inputs.shape 314 | print("Batch :{}/{}".format(batch + 1, len(testloader))) 315 | inputs = inputs.to(device) 316 | labels = labels.to(device) 317 | with torch.no_grad(): 318 | for i in range(len(net_list)): 319 | outputs = net_list[i].evaluate(inputs) 320 | psnr[i] += batch_psnr_vid(labels, outputs) 321 | ssim[i] += batch_ssim_vid(labels, outputs) 322 | return psnr, ssim 323 | 324 | 325 | def compare_nets_unsupervised(net_list, testloader, device): 326 | psnr = [[] for i in range(len(net_list))] 327 | ssim = [[] for i in range(len(net_list))] 328 | for batch, (inputs, labels) in enumerate(testloader): 329 | [batch_size, seq_length, c, h, w] = inputs.shape 330 | print("Batch :{}/{}".format(batch + 1, len(testloader))) 331 | inputs = inputs.to(device) 332 | labels = labels.to(device) 333 | with torch.no_grad(): 334 | for i in range(len(net_list)): 335 | outputs = net_list[i].evaluate(inputs) 336 | psnr[i] += batch_psnr_vid(outputs, labels) 337 | ssim[i] += batch_ssim_vid(outputs, labels) 338 | return psnr, ssim 339 | 340 | 341 | def print_mean_std(x, tag=""): 342 | print("{}psnr = {} +/- {}".format(tag, np.mean(x), np.std(x))) 343 | 
-------------------------------------------------------------------------------- /tutorial/wip/_tuto_06_dcnet_split_measurements.py: -------------------------------------------------------------------------------- 1 | r""" 2 | ========================================= 3 | 06. Denoised Completion Network (DCNet) 4 | ========================================= 5 | .. _tuto_dcnet_split_measurements: 6 | This tutorial shows how to perform image reconstruction using the denoised 7 | completion network (DCNet) with a trainable image denoiser. In the next 8 | tutorial, we will plug a denoiser into a DCNet, which requires no training. 9 | 10 | .. figure:: ../fig/tuto3.png 11 | :width: 600 12 | :align: center 13 | :alt: Reconstruction and neural network denoising architecture sketch using split measurements 14 | """ 15 | 16 | ###################################################################### 17 | # .. note:: 18 | # As in the previous tutorials, we consider a split Hadamard operator and 19 | # measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `). 20 | 21 | # %% 22 | # Load a batch of images 23 | # ========================================= 24 | 25 | ###################################################################### 26 | # Update search path 27 | 28 | # sphinx_gallery_thumbnail_path = 'fig/tuto6.png' 29 | if False: 30 | 31 | import os 32 | 33 | import torch 34 | import torchvision 35 | import matplotlib.pyplot as plt 36 | 37 | import spyrit.core.torch as spytorch 38 | from spyrit.misc.disp import imagesc 39 | from spyrit.misc.statistics import transform_gray_norm 40 | 41 | spyritPath = os.getcwd() 42 | imgs_path = os.path.join(spyritPath, "images/") 43 | 44 | ###################################################################### 45 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. 
46 | 47 | h = 64 # image is resized to h x h 48 | transform = transform_gray_norm(img_size=h) 49 | 50 | ###################################################################### 51 | # Create a data loader from some dataset (images must be in the folder `images/test/`) 52 | 53 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 54 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 55 | 56 | x, _ = next(iter(dataloader)) 57 | print(f"Shape of input images: {x.shape}") 58 | 59 | ###################################################################### 60 | # Select the `i`-th image in the batch 61 | i = 1 # Image index (modify to change the image) 62 | x = x[i : i + 1, :, :, :] 63 | x = x.detach().clone() 64 | print(f"Shape of selected image: {x.shape}") 65 | b, c, h, w = x.shape 66 | 67 | ###################################################################### 68 | # Plot the selected image 69 | 70 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 71 | 72 | # %% 73 | # Forward operators for split measurements 74 | # ========================================= 75 | 76 | ###################################################################### 77 | # We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retaines the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `). 78 | 79 | ###################################################################### 80 | # First, we download the covariance matrix from our warehouse. 
81 | 82 | import girder_client 83 | from spyrit.misc.load_data import download_girder 84 | 85 | # Get covariance matrix 86 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 87 | dataId = "672207cbf03a54733161e95d" 88 | data_folder = "./stat/" 89 | cov_name = "Cov_64x64.pt" 90 | # download 91 | file_abs_path = download_girder(url, dataId, data_folder, cov_name) 92 | 93 | try: 94 | Cov = torch.load(file_abs_path, weights_only=True) 95 | print(f"Cov matrix {cov_name} loaded") 96 | except: 97 | Cov = torch.eye(h * h) 98 | print(f"Cov matrix {cov_name} not found! Set to the identity") 99 | 100 | ###################################################################### 101 | # We define the measurement, noise and preprocessing operators and then simulate 102 | # a measurement vector corrupted by Poisson noise. As in the previous tutorials, 103 | # we simulate an accelerated acquisition by subsampling the measurement matrix 104 | # by retaining only the first rows of a Hadamard matrix that is permuted looking 105 | # at the diagonal of the covariance matrix. 
106 | 107 | from spyrit.core.meas import HadamSplit 108 | from spyrit.core.noise import Poisson 109 | from spyrit.core.prep import SplitPoisson 110 | 111 | # Measurement parameters 112 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels) 113 | alpha = 100.0 # number of photons 114 | 115 | # Measurement and noise operators 116 | Ord = spytorch.Cov2Var(Cov) 117 | meas_op = HadamSplit(M, h, Ord) 118 | noise_op = Poisson(meas_op, alpha) 119 | prep_op = SplitPoisson(alpha, meas_op) 120 | 121 | print(f"Shape of image: {x.shape}") 122 | 123 | # Measurements 124 | y = noise_op(x) # a noisy measurement vector 125 | m = prep_op(y) # preprocessed measurement vector 126 | 127 | m_plot = spytorch.meas2img(m, Ord) 128 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$") 129 | 130 | # %% 131 | # Pseudo inverse solution 132 | # ========================================= 133 | 134 | ###################################################################### 135 | # We compute the pseudo inverse solution using :class:`spyrit.core.recon.PinvNet` class as in the previous tutorial. 136 | 137 | # Instantiate a PinvNet (with no denoising by default) 138 | from spyrit.core.recon import PinvNet 139 | 140 | pinvnet = PinvNet(noise_op, prep_op) 141 | 142 | # Use GPU, if available 143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 144 | print("Using device: ", device) 145 | pinvnet = pinvnet.to(device) 146 | y = y.to(device) 147 | 148 | # Reconstruction 149 | with torch.no_grad(): 150 | z_invnet = pinvnet.reconstruct(y) 151 | 152 | # %% 153 | # Denoised completion network (DCNet) 154 | # ========================================= 155 | 156 | ###################################################################### 157 | # .. 
image:: ../fig/dcnet.png 158 | # :width: 400 159 | # :align: center 160 | # :alt: Sketch of the DCNet architecture 161 | 162 | ###################################################################### 163 | # The DCNet is based on four sequential steps: 164 | # 165 | # i) Denoising in the measurement domain. 166 | # 167 | # ii) Estimation of the missing measurements from the denoised ones. 168 | # 169 | # iii) Image-domain mapping. 170 | # 171 | # iv) (Learned) Denoising in the image domain. 172 | # 173 | # Typically, only the last step involves learnable parameters. 174 | 175 | # %% 176 | # Denoised completion 177 | # ========================================= 178 | 179 | ###################################################################### 180 | # The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing 181 | # 182 | # .. math:: 183 | # \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}}, 184 | # 185 | # where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completation can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details). 186 | 187 | ###################################################################### 188 | # In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior. 
189 | 190 | from spyrit.core.recon import DCNet 191 | 192 | dcnet = DCNet(noise_op, prep_op, Cov) 193 | 194 | # Use GPU, if available 195 | dcnet = dcnet.to(device) 196 | y = y.to(device) 197 | 198 | with torch.no_grad(): 199 | z_dcnet = dcnet.reconstruct(y) 200 | 201 | ###################################################################### 202 | # .. note:: 203 | # In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction. 204 | 205 | # %% 206 | # (Learned) Denoising in the image domain 207 | # ========================================= 208 | 209 | ###################################################################### 210 | # To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`. 211 | 212 | from spyrit.core.nnet import Unet 213 | 214 | denoi = Unet() 215 | dcnet_unet = DCNet(noise_op, prep_op, Cov, denoi) 216 | dcnet_unet = dcnet_unet.to(device) # Use GPU, if available 217 | 218 | ######################################################################## 219 | # We load pretrained weights for the UNet 220 | 221 | from spyrit.core.train import load_net 222 | 223 | local_folder = "./model/" 224 | # Create model folder 225 | if os.path.exists(local_folder): 226 | print(f"{local_folder} found") 227 | else: 228 | os.mkdir(local_folder) 229 | print(f"Created {local_folder}") 230 | 231 | # Load pretrained model 232 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 233 | dataID = "67221559f03a54733161e960" # unique ID of the file 234 | data_name = "tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth" 235 | model_unet_path = os.path.join(local_folder, data_name) 236 | 237 | if os.path.exists(model_unet_path): 238 | print(f"Model found : {data_name}") 239 | 240 | else: 241 | print(f"Model not found : {data_name}") 242 | print(f"Downloading model... 
", end="") 243 | try: 244 | gc = girder_client.GirderClient(apiUrl=url) 245 | gc.downloadFile(dataID, model_unet_path) 246 | print("Done") 247 | except Exception as e: 248 | print("Failed with error: ", e) 249 | 250 | # Load pretrained model 251 | load_net(model_unet_path, dcnet_unet, device, False) 252 | 253 | ###################################################################### 254 | # We reconstruct the image 255 | with torch.no_grad(): 256 | z_dcnet_unet = dcnet_unet.reconstruct(y) 257 | 258 | # %% 259 | # Results 260 | # ========================================= 261 | 262 | from spyrit.misc.disp import add_colorbar, noaxis 263 | 264 | f, axs = plt.subplots(2, 2, figsize=(10, 10)) 265 | 266 | # Plot the ground-truth image 267 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray") 268 | axs[0, 0].set_title("Ground-truth image", fontsize=16) 269 | noaxis(axs[0, 0]) 270 | add_colorbar(im1, "bottom") 271 | 272 | # Plot the pseudo inverse solution 273 | im2 = axs[0, 1].imshow(z_invnet.cpu()[0, 0, :, :], cmap="gray") 274 | axs[0, 1].set_title("Pseudo inverse", fontsize=16) 275 | noaxis(axs[0, 1]) 276 | add_colorbar(im2, "bottom") 277 | 278 | # Plot the solution obtained from denoised completion 279 | im3 = axs[1, 0].imshow(z_dcnet.cpu()[0, 0, :, :], cmap="gray") 280 | axs[1, 0].set_title(f"Denoised completion", fontsize=16) 281 | noaxis(axs[1, 0]) 282 | add_colorbar(im3, "bottom") 283 | 284 | # Plot the solution obtained from denoised completion with UNet denoising 285 | im4 = axs[1, 1].imshow(z_dcnet_unet.cpu()[0, 0, :, :], cmap="gray") 286 | axs[1, 1].set_title(f"Denoised completion with UNet denoising", fontsize=16) 287 | noaxis(axs[1, 1]) 288 | add_colorbar(im4, "bottom") 289 | 290 | plt.show() 291 | 292 | ###################################################################### 293 | # .. note:: 294 | # While the pseudo inverse reconstrcution is pixelized, the solution obtained by denoised completion is smoother. 
DCNet with UNet denoising in the image domain provides the best reconstruction. 295 | 296 | ###################################################################### 297 | # .. note:: 298 | # We refer to `spyrit-examples tutorials `_ for a comparison of different solutions (pinvNet, DCNet and DRUNet) that can be run in colab. 299 | --------------------------------------------------------------------------------