├── .github └── workflows │ └── main.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── AUTHORS ├── CHANGELOG.md ├── Codemeta.json ├── LICENSE.md ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ └── css │ │ └── sg_README.css │ ├── _templates │ ├── spyrit-class-template.rst │ ├── spyrit-method-template.rst │ └── spyrit-module-template.rst │ ├── build-tuto-local.md │ ├── conf.py │ ├── fig │ ├── dcnet.png │ ├── direct_net.png │ ├── drunet.png │ ├── full.png │ ├── lpgd.png │ ├── pinvnet.png │ ├── pinvnet_cnn.png │ ├── principle.png │ ├── spi_principle.png │ ├── tuto1.png │ ├── tuto2.png │ ├── tuto3.png │ ├── tuto3_pinv.png │ ├── tuto6.png │ └── tuto9.png │ ├── index.rst │ ├── organisation.rst │ ├── sg_execution_times.rst │ └── single_pixel.rst ├── pyproject.toml ├── requirements.txt ├── spyrit ├── __init__.py ├── core │ ├── __init__.py │ ├── inverse.py │ ├── meas.py │ ├── nnet.py │ ├── noise.py │ ├── prep.py │ ├── recon.py │ ├── torch.py │ ├── train.py │ └── warp.py ├── dev │ ├── meas.py │ ├── prep.py │ └── recon.py ├── external │ ├── __init__.py │ └── drunet.py ├── hadamard_matrix │ ├── __init__.py │ ├── create_hadamard_matrix_with_sage.py │ └── download_hadamard_matrix.py └── misc │ ├── __init__.py │ ├── color.py │ ├── data_visualisation.py │ ├── disp.py │ ├── examples.py │ ├── load_data.py │ ├── matrix_tools.py │ ├── metrics.py │ ├── pattern_choice.py │ ├── sampling.py │ ├── statistics.py │ └── walsh_hadamard.py └── tutorial ├── README.txt ├── images └── test │ ├── ILSVRC2012_test_00000001.jpeg │ ├── ILSVRC2012_test_00000002.jpeg │ ├── ILSVRC2012_test_00000003.jpeg │ ├── ILSVRC2012_test_00000004.jpeg │ ├── ILSVRC2012_test_00000005.jpeg │ ├── ILSVRC2012_test_00000006.jpeg │ └── ILSVRC2012_test_00000007.jpeg ├── tuto_01_a_acquisition_operators.py ├── tuto_01_b_splitting.py ├── tuto_01_c_HadamSplit2d.py ├── tuto_02_noise.py ├── tuto_03_pseudoinverse_linear.py └── wip ├── _tuto_03_pseudoinverse_cnn_linear.py ├── 
_tuto_04_train_pseudoinverse_cnn_linear.py ├── _tuto_05_recon_hadamSplit.py ├── _tuto_06_dcnet_split_measurements.py ├── _tuto_07_drunet_split_measurements.py ├── _tuto_08_lpgd_split_measurements.py ├── _tuto_09_dynamic.py └── _tuto_bonus_advanced_methods_colab.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [ master ] 7 | tags: 8 | - '*' 9 | pull_request: 10 | branches: 11 | - '*' 12 | schedule: 13 | - cron: '0 0 * * 0' 14 | workflow_dispatch: 15 | 16 | jobs: 17 | build_wheel: 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [ubuntu-latest] 23 | python-version: [3.9] 24 | 25 | steps: 26 | - name: Checkout github repo 27 | uses: actions/checkout@v4 28 | - name: Checkout submodules 29 | run: git submodule update --init --recursive 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | architecture: 'x64' 35 | - name: Create Wheel 36 | run: | 37 | pip install build 38 | python -m build 39 | mkdir wheelhouse 40 | cp dist/spyrit-* wheelhouse/ 41 | ls wheelhouse 42 | rm -r dist 43 | mv wheelhouse dist 44 | - name: Upload wheels 45 | uses: actions/upload-artifact@v4 46 | with: 47 | name: dist 48 | path: dist/ 49 | 50 | test_install: 51 | runs-on: ${{ matrix.os }} 52 | strategy: 53 | fail-fast: false 54 | matrix: 55 | os: [ubuntu-latest, windows-latest, macos-13, macos-14] 56 | python-version: [3.9, "3.10", "3.11", "3.12"] 57 | exclude: 58 | - os: macos-13 59 | python-version: '3.10' 60 | - os: macos-13 61 | python-version: '3.11' 62 | - os: macos-13 63 | python-version: '3.12' 64 | - os: macos-14 65 | python-version: 3.9 66 | - os: macos-14 67 | python-version: '3.10' 68 | 69 | steps: 70 | - name: Checkout github repo 71 | uses: actions/checkout@v4 72 | - name: Checkout submodules 73 | run: git submodule update 
--init --recursive 74 | - name: Set up Python ${{ matrix.python-version }} 75 | uses: actions/setup-python@v5 76 | with: 77 | python-version: ${{ matrix.python-version }} 78 | architecture: 'x64' 79 | - name: Run the tests on Mac and Linux 80 | if: matrix.os != 'windows-latest' 81 | run: | 82 | pip install pytest 83 | pip install -e . 84 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit/dev --ignore=spyrit/hadamard_matrix || exit -1 85 | - name: Run the tests on Windows 86 | if: matrix.os == 'windows-latest' 87 | shell: cmd 88 | run: | 89 | pip install pytest 90 | pip install -e . 91 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit\dev --ignore=spyrit\hadamard_matrix || exit /b -1 92 | 93 | test_wheel: 94 | runs-on: ${{ matrix.os }} 95 | needs: [build_wheel] 96 | strategy: 97 | fail-fast: false 98 | matrix: 99 | os: [ubuntu-latest, windows-latest, macos-13] 100 | python-version: [3.9, "3.10", "3.11", "3.12"] 101 | 102 | steps: 103 | - name: Checkout github repo 104 | uses: actions/checkout@v4 105 | - name: Checkout submodules 106 | run: git submodule update --init --recursive 107 | - name: Set up Python ${{ matrix.python-version }} 108 | uses: actions/setup-python@v5 109 | with: 110 | python-version: ${{ matrix.python-version }} 111 | architecture: 'x64' 112 | - uses: actions/download-artifact@v4 113 | with: 114 | pattern: dist* 115 | merge-multiple: true 116 | path: dist/ 117 | - name: Run tests on Mac and Linux 118 | if: matrix.os != 'windows-latest' 119 | run: | 120 | cd dist 121 | pip install spyrit-*.whl 122 | - name: Run the tests on Windows 123 | if: matrix.os == 'windows-latest' 124 | run: | 125 | cd dist 126 | $package=dir -Path . 
-Filter spyrit*.whl | %{$_.FullName} 127 | echo $package 128 | pip install $package 129 | 130 | publish_wheel: 131 | runs-on: ubuntu-latest 132 | needs: [build_wheel, test_wheel, test_install] 133 | steps: 134 | - name: Checkout github repo 135 | uses: actions/checkout@v4 136 | - name: Checkout submodules 137 | run: git submodule update --init --recursive 138 | - uses: actions/download-artifact@v4 139 | with: 140 | pattern: dist* 141 | merge-multiple: true 142 | path: dist/ 143 | - name: Publish to PyPI 144 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') 145 | uses: pypa/gh-action-pypi-publish@release/v1 146 | with: 147 | user: __token__ 148 | password: ${{ secrets.PYPI }} 149 | skip_existing: true 150 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.py[co] 3 | *.swp 4 | *.pdf 5 | *.png 6 | *.sh 7 | *.ipynb 8 | *.npy 9 | *.npz 10 | 11 | #folders 12 | data 13 | img 14 | models 15 | dist 16 | build 17 | spyrit.egg-info 18 | stats_walsh 19 | **/.ipynb_checkpoints/* 20 | docs/source/_build* 21 | model/ 22 | runs/ 23 | spyrit/drunet/ 24 | !spyrit/images/tuto/*.png 25 | docs/source/html 26 | docs/source/_autosummary 27 | docs/source/_templates 28 | docs/source/api 29 | docs/source/gallery 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - repo: https://github.com/psf/black 8 | rev: 25.1.0 9 | hooks: 10 | - id: black 11 | ci: 12 | autofix_commit_msg: | 13 | [pre-commit.ci] Automatic python formatting 14 | autofix_prs: true 15 | submodules: false 16 | 
-------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Optionally build your docs in additional formats such as PDF 9 | formats: all # pdf, epub and htlmzip 10 | 11 | # Optionally set the version of Python and requirements required to build your docs 12 | build: 13 | os: ubuntu-22.04 14 | tools: 15 | python: "3.11" 16 | python: 17 | install: 18 | - requirements: requirements.txt 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: docs/source/conf.py 23 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Nicolas Ducros 2 | Romain Phan 3 | Thomas Baudier 4 | Juan Abascal 5 | Fadoua Taia-Alaoui 6 | Claire Mouton 7 | Guilherme Beneti 8 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ### Notations 4 | 5 | \- removals 6 | / changes 7 | \+ additions 8 | 9 | --- 10 | 11 |
12 | 13 | 14 | ## Changes to come in a future version 15 | 16 | 17 | 18 | 19 |
20 | 21 | --- 22 | 23 |
24 | 25 | ## v2.3.4 26 | 27 | 28 | ### spyrit.core 29 | * #### General changes 30 | * The input and output shapes have been standardized across operators. All still images (i.e. not videos) have shape `(*, h, w)`, where `*` is any batch dimension (e.g. batch size and number of channels), and `h` and `w` are the height and width of the image. All measurements have shape `(*, M)`, where `*` is the same batch dimension as the images they come from. Videos have shape `(*, t, c, h, w)`, where `t` is the time dimension (the number of frames in the video) and `c` is the number of channels. Dynamic measurements from videos will thus have shape `(*, c, M)`. 31 | * GPU usage has been improved. Every class of the `core` module now exposes `self.device`, which tracks the device on which its parameters are stored. 32 | * #### spyrit.core.meas 33 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization. 34 | * / Fixed .pinv() output shape (it was transposed with some regularization methods) 35 | * / Fixed some device errors when using CUDA with .pinv() 36 | * / The measurement matrix H is now stored with the data type it is given to the constructor (it was previously converted to torch.float32 for memory reasons) 37 | * \+ Added a `diff` parameter to the .pinv() method, enabling differentiated reconstructions (subtracting the negative patterns/measurements from the positive patterns/measurements); only available for dynamic operators. 38 | * / For HadamSplit, pinv() has been overridden to use a fast Walsh-Hadamard transform, zero-padding the measurements if necessary (in the case of subsampling). The inverse() method has been deprecated and will be removed in a future release.
39 | * #### spyrit.core.recon 40 | * \- The class core.recon.Denoise_layer is deprecated and will be removed in a future version 41 | * / The class TikhonovMeasurementPriorDiag no longer uses Denoise_layer and instead uses an internal method to handle the denoising. 42 | * #### spyrit.core.train 43 | * / load_net() now calls torch.load() with the weights_only=True parameter. Documentation updated. 44 | * #### spyrit.core.warp 45 | * / The warping operation (forward method) now has to be performed on (b, c, h, w) input tensors, and returns (b, time, c, h, w) output tensors. 46 | * / The AffineDeformationField no longer stores the field as an attribute; the field is instead generated on the fly. This allows for more efficient memory management. 47 | * / In AffineDeformationField the image size can be changed. 48 | * \+ It is now possible to use biquintic (5th-order) warping. This uses scikit-image's (skimage) warp function, which relies on numpy arrays. 49 | 50 | ### Tutorials 51 | * Tutorial 2 integrates the change from 'L1' to 'rcond' 52 | * All Tutorials have been updated to include the above-mentioned changes. 53 | 54 |
55 | 56 | --- 57 | 58 |
59 | 60 | ## v2.3.3 61 | 62 | 63 | ### spyrit.core 64 | * #### spyrit.core.meas 65 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization. 66 | * #### spyrit.core.recon 67 | * / The documentation for the class core.recon.Denoise_layer has been clarified. 68 | 69 | ### Tutorials 70 | 71 | * Tutorial 2 integrated the change from 'L1' to 'rcond' 72 | 73 |
74 | 75 | --- 76 | 77 |
78 | 79 | ## v2.3.2 80 | 81 | 82 | ### spyrit.core 83 | * #### spyrit.core.meas 84 | * / The method forward_H has been optimized for the HadamSplit class 85 | * #### spyrit.core.torch 86 | * \+ Added spyrit.core.torch.fwht, a PyTorch implementation of the fast Walsh-Hadamard transform for natural and Walsh ordered transforms. 87 | * \+ Added spyrit.core.torch.fwht_2d, a PyTorch implementation of the fast Walsh-Hadamard transform in 2 dimensions for natural and Walsh ordered transforms. 88 | 89 | ### spyrit.misc 90 | 91 | * #### spyrit.misc.statistics 92 | * / The function spyrit.misc.statistics.Cov2Var has been sped up and now supports an output shape for non-square images 93 | * #### spyrit.misc.walsh_hadamard 94 | * / The function spyrit.misc.walsh_hadamard.fwht has been significantly sped up, especially for sequency-ordered Walsh-Hadamard transforms. 95 | * \- fwht_torch is now deprecated. Use spyrit.core.torch.fwht instead. 96 | * \- walsh_torch is now deprecated. Use spyrit.core.torch.fwht instead. 97 | * \- walsh2_torch is now deprecated. Use spyrit.core.torch.fwht_2d instead. 98 | * #### spyrit.misc.load_data 99 | * \+ New function download_girder that downloads files identified by their hexadecimal ID from a server URL 100 | 101 | ### Tutorials 102 | 103 | * Tutorials 3, 4, 6, 7, 8 now download data from our own servers instead of using Google Drive and the gdown library. The dependency on the gdown library will be fully removed in a future version. 104 | 105 |
106 | 107 | --- 108 | 109 |
110 | 111 | ## v2.3.1 112 | 113 | 114 | ### spyrit.core 115 | 116 | * #### spyrit.core.meas 117 | * \+ For static classes, self.set_H_pinv has been renamed to self.build_H_pinv to match with the dynamic classes. 118 | * \+ The dynamic classes now support bicubic dynamic reconstruction (spyrit.core.meas.DynamicLinear.build_h_dyn()). This uses cubic B-splines. 119 | * #### spyrit.core.train 120 | * load_net() must take the full path, **with** the extension name (xyz.pth). 121 | 122 | ### Tutorials 123 | 124 | * Tutorial 6 has been changed accordingly to the modification of spyrit.core.train.load_net(). 125 | * Tutorial 8 is now available. 126 | 127 |
128 | 129 | --- 130 | 131 |
132 | 133 | ## v2.3.0 134 | 135 | 136 |
137 | 138 | ### spyrit.core 139 | 140 | 141 | * / no longer supports numpy.array as input, must use torch.tensor 142 | * #### spyrit.core.meas 143 | * \- class LinearRowSplit (use LinearSplit instead) 144 | * \+ 3 dynamic classes: DynamicLinear, DynamicLinearSplit, DynamicHadamSplit that allow measurements over time 145 | * spyrit.core.meas.Linear 146 | * \- self.get_H() deprecated (use self.H) 147 | * \- self.H_adjoint (you might want to use self.H.T) 148 | * / constructor argument 'reg' renamed to 'rtol' 149 | * / self.H no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 150 | * / self.H_pinv no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 151 | * \+ self.__init__() has 'Ord' and 'meas_shape' optional arguments 152 | * \+ self.pinv() now supports lstsq image reconstruction if self.H_pinv is not defined 153 | * \+ self.set_H_pinv(), self.reindex() inherited from spyrit.misc.torch 154 | * \+ self.meas_shape, self.indices, self.Ord, self.H_static 155 | * spyrit.core.meas.LinearSplit 156 | * / [includes changes from Linear] 157 | * / self.P no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable) 158 | * spyrit.core.meas.HadamSplit 159 | * / [includes changes from LinearSplit] 160 | * \- self.__init__() does not need 'meas_shape' argument, it is taken as (h,h) 161 | * \- self.Perm (use self.reindex() instead) 162 | * #### spyrit.core.noise 163 | * spyrit.core.noise.NoNoise 164 | * \+ self.reindex() inherited from spyrit.core.meas.Linear.reindex() 165 | * #### spyrit.core.prep 166 | * \- class SplitRowPoisson (was used with LinearRowSplit) 167 | * #### spyrit.core.recon 168 | * spyrit.core.recon.PseudoInverse 169 | * / self.forward() now has **kwargs that are passed to meas_op.pinv(), useful for lstsq image reconstruction 170 | * #### \+ spyrit.core.torch 171 | contains torch-specific functions that are commonly used in spyrit.core. 
Mirrors some spyrit.misc functions that are numpy-specific 172 | * #### \+ spyrit.core.warp 173 | * \+ class AffineDeformationField 174 | warps an image using an affine transformation matrix 175 | * \+ class DeformationField 176 | warps an image using a deformation field 177 |
178 | 179 |
180 | 181 | ### spyrit.misc 182 | 183 | 184 | * #### spyrit.misc.matrix_tools 185 | * \- Permutation_Matrix() is deprecated (already defined in spyrit.misc.sampling.Permutation_Matrix()) 186 | * #### spyrit.misc.sampling 187 | * \- meas2img2() is deprecated (use meas2img() instead) 188 | * / meas2img() can now handle batches of images 189 | * \+ sort_by_significance() & reindex() to speed up permutation matrix multiplication 190 |
191 | -------------------------------------------------------------------------------- /Codemeta.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": "https://doi.org/10.5063/schema/codemeta-2.0", 3 | "type": "SoftwareSourceCode", 4 | "applicationCategory": "Single-pixel imaging", 5 | "codeRepository": "https://github.com/openspyrit/spyrit", 6 | "dateCreated": "2020-12-10", 7 | "datePublished": "2021-03-11", 8 | "description": "SPyRiT is a PyTorch-based deep image reconstruction package primarily designed for single-pixel imaging.", 9 | "keywords": [ 10 | "Single-pixel imaging", 11 | "pytorch" 12 | ], 13 | "license": "https://spdx.org/licenses/LGPL-3.0", 14 | "name": "SPyRiT", 15 | "operatingSystem": [ 16 | "Linux", 17 | "Windows", 18 | "MacOS" 19 | ], 20 | "programmingLanguage": "Python 3", 21 | "contIntegration": "https://github.com/openspyrit/spyrit/actions", 22 | "codemeta:continuousIntegration": { 23 | "id": "https://github.com/openspyrit/spyrit/actions" 24 | }, 25 | "issueTracker": "https://github.com/openspyrit/spyrit/issues" 26 | } 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 
18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/openspyrit/spyrit?logo=github) 2 | [![GitHub](https://img.shields.io/github/license/openspyrit/spyrit?style=plastic)](https://github.com/openspyrit/spyrit/blob/master/LICENSE.md) 3 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/spyrit.svg)](https://pypi.python.org/pypi/spyrit/) 4 | [![Docs](https://readthedocs.org/projects/spyrit/badge/?version=master&style=flat)](https://spyrit.readthedocs.io/en/master/) 5 | 6 | # SPyRiT 7 | SPyRiT is a [PyTorch](https://pytorch.org/)-based deep image reconstruction package primarily designed for single-pixel imaging. 8 | 9 | # Installation 10 | The spyrit package is available for Linux, macOS and Windows. We recommend using a virtual environment. 11 | ## Linux and macOS 12 | (user mode) 13 | ``` 14 | pip install spyrit 15 | ``` 16 | (developer mode) 17 | ``` 18 | git clone https://github.com/openspyrit/spyrit.git 19 | cd spyrit 20 | pip install -e . 21 | ``` 22 | 23 | ## Windows 24 | On Windows you may need to install PyTorch first. It may also be necessary to run the following commands using administrator rights (e.g., starting your Python environment with administrator rights).
25 | 26 | Adapt the two examples below to your configuration (see [here](https://pytorch.org/get-started/locally/) for the latest instructions). 27 | 28 | (CPU version using `pip`) 29 | 30 | ``` 31 | pip3 install torch torchvision torchaudio 32 | ``` 33 | 34 | (GPU version using `conda`) 35 | 36 | ``` shell 37 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia 38 | ``` 39 | 40 | Then, install SPyRiT using `pip`: 41 | 42 | (user mode) 43 | ``` 44 | pip install spyrit 45 | ``` 46 | (developer mode) 47 | ``` 48 | git clone https://github.com/openspyrit/spyrit.git 49 | cd spyrit 50 | pip install -e . 51 | ``` 52 | 53 | 54 | ## Test 55 | To check the installation, run in your Python terminal: 56 | ``` 57 | import spyrit 58 | ``` 59 | 60 | ## Get started - Examples 61 | To start, check the [documentation tutorials](https://spyrit.readthedocs.io/en/master/gallery/index.html). These tutorials must be run from the `tutorial` folder (they load image samples from `tutorial/images/`): 62 | ``` 63 | cd spyrit/tutorial/ 64 | ``` 65 | 66 | More advanced reconstruction examples can be found in [spyrit-examples/tutorial](https://github.com/openspyrit/spyrit-examples/tree/master/tutorial). Run the advanced tutorial in Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_core_2d_drunet.ipynb) 67 | 68 | 69 | # API Documentation 70 | https://spyrit.readthedocs.io/ 71 | 72 | # Contributors (alphabetical order) 73 | * Juan Abascal - [Website](https://juanabascal78.wixsite.com/juan-abascal-webpage) 74 | * Thomas Baudier 75 | * Sebastien Crombez 76 | * Nicolas Ducros - [Website](https://www.creatis.insa-lyon.fr/~ducros/WebPage/index.html) 77 | * Antonio Tomas Lorente Mur - [Website](https://sites.google.com/view/antonio-lorente-mur/) 78 | * Romain Phan 79 | * Fadoua Taia-Alaoui 80 | 81 | # How to cite?
82 | When using SPyRiT in scientific publications, please cite the following paper: 83 | 84 | * G. Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," Optics Express, Vol. 31, No. 10, (2023). https://doi.org/10.1364/OE.483937. 85 | 86 | When using SPyRiT specifically for the denoised completion network, please cite the following paper: 87 | 88 | * A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," Opt. Express 29, 17097-17110 (2021). https://doi.org/10.1364/OE.424228. 89 | 90 | # License 91 | This project is licensed under the LGPL-3.0 license - see the [LICENSE.md](LICENSE.md) file for details 92 | 93 | # Acknowledgments 94 | * [Jin LI](https://github.com/happyjin/ConvGRU-pytorch) for his implementation of Convolutional Gated Recurrent Units for PyTorch 95 | * [Erik Lindernoren](https://github.com/eriklindernoren/Action-Recognition) for his processing of the UCF-101 Dataset. 96 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/css/sg_README.css: -------------------------------------------------------------------------------- 1 | .sphx-glr-thumbnails { 2 | width: 100%; 3 | margin: 0px 0px 0px 0px; 4 | 5 | /* align thumbnails on a grid */ 6 | justify-content: space-between; 7 | display: grid; 8 | /* each grid column should be at least 160px (this will determine 9 | the actual number of columns) and then take as much of the 10 | remaining width as possible */ 11 | grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)) !important; 12 | gap: 20px; 13 | } 14 | .sphx-glr-thumbcontainer { 15 | width: 100% !important; 16 | min-height: 210px !important; 17 | margin: 0px !important; 18 | } 19 | .sphx-glr-thumbcontainer 
.figure { 20 | min-width: 100px !important; 21 | height: 100px !important; 22 | } 23 | .sphx-glr-thumbcontainer img { 24 | display: inline !important; 25 | object-fit: cover !important; 26 | max-height: 150px !important; 27 | min-width: 300px !important; 28 | } 29 | -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | 8 | {% block methods %} 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | :toctree: 14 | :template: spyrit-method-template.rst 15 | {% for item in methods %} 16 | {%- if item is in members %} 17 | ~{{ name }}.{{ item }} 18 | {%- endif %} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-method-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. automethod:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/spyrit-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | :show-inheritance: 5 | 6 | {% block attributes %} 7 | {% if attributes %} 8 | .. rubric:: {{ _('Module Attributes') }} 9 | 10 | .. autosummary:: 11 | :toctree: 12 | {% for item in attributes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. 
autosummary:: 23 | :toctree: 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. autosummary:: 35 | :toctree: 36 | :template: spyrit-class-template.rst 37 | {% for item in classes %} 38 | {{ item }} 39 | {%- endfor %} 40 | {% endif %} 41 | {% endblock %} 42 | 43 | {% block exceptions %} 44 | {% if exceptions %} 45 | .. rubric:: {{ _('Exceptions') }} 46 | 47 | .. autosummary:: 48 | :toctree: 49 | {% for item in exceptions %} 50 | {{ item }} 51 | {%- endfor %} 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block modules %} 56 | {% if modules %} 57 | .. rubric:: Modules 58 | 59 | .. autosummary:: 60 | :toctree: 61 | :template: spyrit-module-template.rst 62 | :recursive: 63 | {% for item in modules %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /docs/source/build-tuto-local.md: -------------------------------------------------------------------------------- 1 | 2 | ``` shell 3 | git clone --no-single-branch --depth 50 https://github.com/openspyrit/spyrit . 4 | git checkout --force origin/gallery 5 | git clean -d -f -f 6 | cat .readthedocs.yml 7 | ``` 8 | 9 | # Linux 10 | ``` shell 11 | python3.7 -mvirtualenv $READTHEDOCS_VIRTUALENV_PATH 12 | python -m pip install --upgrade --no-cache-dir pip setuptools 13 | python -m pip install --upgrade --no-cache-dir pillow==5.4.1 mock==1.0.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext<2.3 14 | python -m pip install --exists-action=w --no-cache-dir -r requirements.txt 15 | cat docs/source/conf.py 16 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . 
$READTHEDOCS_OUTPUT/html 17 | ``` 18 | 19 | # Windows using conda 20 | ``` powershell 21 | conda create --name readthedoc 22 | conda activate readthedoc 23 | conda install pip 24 | python.exe -m pip install --upgrade --no-cache-dir pip setuptools 25 | pip install --upgrade --no-cache-dir pillow==10.0.0 mock==1.0.1 alabaster==0.7.13 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext==2.2.2 26 | cd .\myenv\spyrit\ # replace myenv by the environment in which spyrit is installed 27 | pip install --exists-action=w --no-cache-dir -r requirements.txt 28 | cd .\docs\source\ 29 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . html 30 | ``` 31 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | 13 | import os 14 | import sys 15 | from sphinx_gallery.sorting import ExampleTitleSortKey 16 | 17 | # paths relative to this file 18 | sys.path.insert(0, os.path.abspath("../..")) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | project = "spyrit" 22 | copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" 23 | author = "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" 24 | 25 | # The full version, including alpha/beta/rc tags. NOTE(review): pyproject.toml declares version 3.0.1; confirm whether this value is meant to differ. 26 | release = "2.4.0" 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.mathjax", 37 | "sphinx.ext.todo", 38 | "sphinx.ext.autosummary", 39 | "sphinx.ext.napoleon", 40 | "sphinx.ext.viewcode", 41 | "sphinx_gallery.gen_gallery", 42 | "sphinx.ext.coverage", 43 | ] 44 | 45 | # Napoleon settings 46 | napoleon_google_docstring = False 47 | napoleon_numpy_docstring = True 48 | napoleon_include_private_with_doc = False 49 | napoleon_include_special_with_doc = False 50 | napoleon_use_admonition_for_examples = False 51 | napoleon_use_admonition_for_notes = False 52 | napoleon_use_admonition_for_references = False 53 | napoleon_use_ivar = True 54 | napoleon_use_param = False 55 | napoleon_use_rtype = False 56 | 57 | autodoc_member_order = "bysource" 58 | autosummary_generate = True 59 | todo_include_todos = True 60 | 61 | # Add any paths that contain templates here, relative to this directory. 62 | templates_path = ["_templates"] 63 | 64 | # List of patterns, relative to source directory, that match files and 65 | # directories to ignore when looking for source files.
66 | # This pattern also affects html_static_path and html_extra_path. 67 | exclude_patterns = [] 68 | 69 | sphinx_gallery_conf = { 70 | # path to the example scripts 71 | "examples_dirs": [ 72 | "../../tutorial", 73 | ], 74 | # path where the generated gallery examples are saved 75 | "gallery_dirs": ["gallery"], 76 | "filename_pattern": "/tuto_", 77 | "ignore_pattern": "/_", 78 | # resize the thumbnails, original size = 400x280 79 | "thumbnail_size": (400, 280), 80 | # Remove the "Download all examples" button from the top level gallery 81 | "download_all_examples": False, 82 | # Sort gallery examples by their title (ExampleTitleSortKey) instead of the default sort key 83 | "within_subsection_order": ExampleTitleSortKey, 84 | # directory where function granular galleries are stored 85 | "backreferences_dir": "api/generated/backreferences", 86 | # Modules for which function level galleries are created. 87 | "doc_module": "spyrit", 88 | # Insert links to documentation of objects in the examples 89 | "reference_url": {"spyrit": None}, 90 | } 91 | 92 | # -- Options for HTML output ------------------------------------------------- 93 | 94 | # The theme to use for HTML and HTML Help pages. See the documentation for 95 | # a list of builtin themes. 96 | html_theme = "sphinx_rtd_theme" 97 | 98 | # (the custom CSS file used to produce bigger thumbnails is registered via html_static_path / html_css_files below) 99 | 100 | # on_rtd is whether we are on readthedocs.org 101 | on_rtd = os.environ.get("READTHEDOCS", None) == "True" 102 | 103 | # Add any paths that contain custom static files (such as style sheets) here, 104 | # relative to this directory. They are copied after the builtin static files, 105 | # so a file named "default.css" will overwrite the builtin "default.css". 106 | # By default, this is set to include the _static path. 107 | html_static_path = ["_static"] 108 | html_css_files = ["css/sg_README.css"] 109 | 110 | # The master toctree document.
111 | master_doc = "index" 112 | 113 | html_sidebars = { 114 | "**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"] 115 | } 116 | 117 | # http://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports 118 | # autodoc_mock_imports incompatible with autosummary somehow 119 | # autodoc_mock_imports = "numpy matplotlib mpl_toolkits scipy torch torchvision Pillow opencv-python imutils PyWavelets pywt wget imageio".split() 120 | 121 | 122 | # exclude all members inherited from torch.nn.Module / torch.nn.Sequential (except the names listed in always_document, e.g. forward) from the docs: 123 | import torch 124 | 125 | # Sphinx 'autodoc-skip-member' callback: returning True skips the member, returning None defers to autodoc's default decision. 126 | def skip_member_handler(app, what, name, obj, skip, options): 127 | always_document = [ # complete this list if needed by adding methods 128 | "forward", # you *always* want to see documented 129 | ] 130 | if name in always_document: 131 | return None # never force-skip these; let autodoc decide as usual 132 | if name in dir(torch.nn.Module): # used for most of the classes in spyrit 133 | return True 134 | if name in dir(torch.nn.Sequential): # used for FullNet and child classes 135 | return True 136 | return None # not a torch container member: defer to autodoc's default 137 | 138 | # Sphinx extension entry point for this conf.py: register the skip-member callback. 139 | def setup(app): 140 | app.connect("autodoc-skip-member", skip_member_handler) 141 | -------------------------------------------------------------------------------- /docs/source/fig/dcnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/dcnet.png -------------------------------------------------------------------------------- /docs/source/fig/direct_net.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/direct_net.png -------------------------------------------------------------------------------- /docs/source/fig/drunet.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/drunet.png -------------------------------------------------------------------------------- /docs/source/fig/full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/full.png -------------------------------------------------------------------------------- /docs/source/fig/lpgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/lpgd.png -------------------------------------------------------------------------------- /docs/source/fig/pinvnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/pinvnet.png -------------------------------------------------------------------------------- /docs/source/fig/pinvnet_cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/pinvnet_cnn.png -------------------------------------------------------------------------------- /docs/source/fig/principle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/principle.png -------------------------------------------------------------------------------- /docs/source/fig/spi_principle.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/spi_principle.png -------------------------------------------------------------------------------- /docs/source/fig/tuto1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto1.png -------------------------------------------------------------------------------- /docs/source/fig/tuto2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto2.png -------------------------------------------------------------------------------- /docs/source/fig/tuto3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto3.png -------------------------------------------------------------------------------- /docs/source/fig/tuto3_pinv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto3_pinv.png -------------------------------------------------------------------------------- /docs/source/fig/tuto6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto6.png -------------------------------------------------------------------------------- /docs/source/fig/tuto9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto9.png 
-------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. spyrit documentation master file, created by 2 | sphinx-quickstart on Fri Mar 12 11:04:59 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | SPyRiT 7 | ##################################################################### 8 | 9 | SPyRiT is a `PyTorch `_-based image reconstruction 10 | package designed for `single-pixel imaging `_. SPyRiT has a `modular organisation `_ and may be useful for other inverse problems. 11 | 12 | Github repository: `openspyrit/spyrit `_ 13 | 14 | 15 | Installation 16 | ================================== 17 | 18 | SPyRiT is available for Linux, MacOs and Windows:: 19 | 20 | pip install spyrit 21 | 22 | See `here `_ for advanced installation guidelines. 23 | 24 | 25 | Getting started 26 | ================================== 27 | 28 | Please check our `tutorials `_ as well as the `examples `_ on GitHub. 29 | 30 | Cite us 31 | ================================== 32 | 33 | When using SPyRiT in scientific publications, please cite [v3]_ for SPyRiT v3, [v2]_ for SPyRiT v2, and [v1]_ for DC-Net. 34 | 35 | .. [v3] JFJP Abascal, T Baudier, R Phan, A Repetti, N Ducros, "SPyRiT 3.0: an open source package for single-pixel imaging based on deep learning," Preprint (2024). 36 | .. [v2] G Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," *Optics Express*, Vol. 31, Issue 10, (2023). `DOI `_. 37 | .. [v1] A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," *Opt. Express*, Vol. 29, Issue 11, 17097-17110 (2021). `DOI `_. 
38 | 39 | 40 | Join the project 41 | ================================== 42 | 43 | The list of contributors can be found `here `_. Feel free to contact us by `e-mail `_ for any questions. Direct contributions via pull requests (PRs) are welcome. 44 | 45 | .. toctree:: 46 | :maxdepth: 2 47 | :hidden: 48 | 49 | single_pixel 50 | organisation 51 | 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | :caption: Tutorials 56 | :hidden: 57 | 58 | gallery/index 59 | 60 | 61 | Contents 62 | ======== 63 | 64 | .. autosummary:: 65 | :toctree: _autosummary 66 | :template: spyrit-module-template.rst 67 | :recursive: 68 | :caption: Contents 69 | 70 | spyrit.core 71 | spyrit.misc 72 | spyrit.external 73 | -------------------------------------------------------------------------------- /docs/source/organisation.rst: -------------------------------------------------------------------------------- 1 | Organisation of the package 2 | ================================== 3 | 4 | .. figure:: fig/direct_net.png 5 | :width: 600 6 | :align: center 7 | 8 | 9 | SPyRiT's typical pipeline. 10 | 11 | SPyRiT can simulate measurements and perform image reconstruction using 12 | a full network. A full network includes a measurement operator 13 | :math:`A`, a noise operator :math:`\mathcal{N}`, a preprocessing 14 | operator :math:`B`, a reconstruction operator :math:`\mathcal{R}`, 15 | and a learnable neural network :math:`\mathcal{G}_{\theta}`. All operators 16 | inherit from :class:`torch.nn.Module`. 17 | 18 | 19 | Submodules 20 | ----------------------------------- 21 | 22 | SPyRiT has a modular structure with the core functionality organised in the 8 submodules of 23 | :mod:`spyrit.core`. 24 | 25 | 1.
:mod:`spyrit.core.noise` provides noise operators corresponding to :math:`\mathcal{N}` in Eq. :eq:`eq_acquisition`. 28 | 29 | 3. :mod:`spyrit.core.prep` provides preprocessing operators for the operator :math:`B` introduced in Eq. :eq:`eq_prep`. 30 | 31 | 4. :mod:`spyrit.core.nnet` provides known neural networks corresponding to :math:`\mathcal{G}` in Eq. :eq:`eq_recon_direct` or Eq. :eq:`eq_pgd_no_Gamma`. 32 | 33 | 5. :mod:`spyrit.core.recon` returns the reconstruction operator corresponding to :math:`\mathcal{R}`. 34 | 35 | 6. :mod:`spyrit.core.train` provides the functionality to solve the minimisation problem of Eq. :eq:`eq_train`. 36 | 37 | 7. :mod:`spyrit.core.warp` contains the operators used for dynamic acquisitions. 38 | 39 | 8. :mod:`spyrit.core.torch` contains utility functions. 40 | 41 | In addition, :mod:`spyrit.misc` contains various utility functions for Numpy / PyTorch that can be used independently of the core functions. 42 | 43 | Finally, :mod:`spyrit.external` provides access to `DR-UNet `_. 44 | -------------------------------------------------------------------------------- /docs/source/sg_execution_times.rst: -------------------------------------------------------------------------------- 1 | 2 | :orphan: 3 | 4 | .. _sphx_glr_sg_execution_times: 5 | 6 | 7 | Computation times 8 | ================= 9 | **00:00.000** total execution time for 7 files **from all galleries**: 10 | 11 | .. container:: 12 | 13 | .. raw:: html 14 | 15 | 19 | 20 | 21 | 22 | 27 | 28 | .. 
list-table:: 29 | :header-rows: 1 30 | :class: table table-striped sg-datatable 31 | 32 | * - Example 33 | - Time 34 | - Mem (MB) 35 | * - :ref:`sphx_glr_gallery_tuto_01_acquisition_operators.py` (``..\..\tutorial\tuto_01_acquisition_operators.py``) 36 | - 00:00.000 37 | - 0.0 38 | * - :ref:`sphx_glr_gallery_tuto_02_pseudoinverse_linear.py` (``..\..\tutorial\tuto_02_pseudoinverse_linear.py``) 39 | - 00:00.000 40 | - 0.0 41 | * - :ref:`sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_03_pseudoinverse_cnn_linear.py``) 42 | - 00:00.000 43 | - 0.0 44 | * - :ref:`sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_04_train_pseudoinverse_cnn_linear.py``) 45 | - 00:00.000 46 | - 0.0 47 | * - :ref:`sphx_glr_gallery_tuto_05_acquisition_split_measurements.py` (``..\..\tutorial\tuto_05_acquisition_split_measurements.py``) 48 | - 00:00.000 49 | - 0.0 50 | * - :ref:`sphx_glr_gallery_tuto_06_dcnet_split_measurements.py` (``..\..\tutorial\tuto_06_dcnet_split_measurements.py``) 51 | - 00:00.000 52 | - 0.0 53 | * - :ref:`sphx_glr_gallery_tuto_bonus_advanced_methods_colab.py` (``..\..\tutorial\tuto_bonus_advanced_methods_colab.py``) 54 | - 00:00.000 55 | - 0.0 56 | -------------------------------------------------------------------------------- /docs/source/single_pixel.rst: -------------------------------------------------------------------------------- 1 | Single-pixel imaging 2 | ================================== 3 | .. _principle: 4 | .. figure:: fig/spi_principle.png 5 | :width: 800 6 | :align: center 7 | 8 | Overview of the principle of single-pixel imaging. 9 | 10 | 11 | Simulation of the measurements 12 | ----------------------------------- 13 | Single-pixel imaging aims to recover an unknown image :math:`x\in\mathbb{R}^N` from a few noisy observations 14 | 15 | .. 
math:: 16 | m \approx Hx, 17 | 18 | where :math:`H\in\mathbb{R}^{M\times N}` is a linear measurement operator, :math:`M` is the number of measurements and :math:`N` is the number of pixels in the image. 19 | 20 | In practice, measurements are obtained by uploading a set of light patterns onto a spatial light modulator (e.g., a digital micromirror device (DMD), see :ref:`principle`). Therefore, only positive patterns can be implemented. We model the actual acquisition process as 21 | 22 | 23 | .. math:: 24 | :label: eq_acquisition 25 | 26 | y = \mathcal{N}(Ax) 27 | 28 | where :math:`\mathcal{N} \colon \mathbb{R}^J \to \mathbb{R}^J` represents a noise operator (e.g., Poisson or Poisson-Gaussian), :math:`A \in \mathbb{R}_+^{J\times N}` is the actual acquisition operator that models the (positive) DMD patterns, and :math:`J` is the number of DMD patterns. 29 | 30 | Handling non-negativity with preprocessing 31 | ---------------------------------------------------------------------- 32 | We may preprocess the measurements before reconstruction to transform the actual measurements into the target measurements 33 | 34 | .. math:: 35 | :label: eq_prep 36 | 37 | m = By \approx Hx 38 | 39 | 40 | where :math:`B\colon\mathbb{R}^{J}\to \mathbb{R}^{M}` is the preprocessing operator chosen such that :math:`BA=H`. Note that the noise of the preprocessed measurements :math:`m=By` is not the same as that of the actual measurements :math:`y`. 41 | 42 | Data-driven image reconstruction 43 | ----------------------------------- 44 | Data-driven methods based on deep learning aim to find an estimate :math:`x^*\in \mathbb{R}^N` of the unknown image :math:`x` from the preprocessed measurements :math:`By`, using a reconstruction operator :math:`\mathcal{R}_{\theta^*} \colon \mathbb{R}^M \to \mathbb{R}^N` 45 | 46 | .. math:: 47 | \mathcal{R}_{\theta^*}(m) = x^* \approx x, 48 | 49 | where :math:`\theta^*` represents the parameters learned during a training procedure.
50 | 51 | Learning phase 52 | ----------------------------------- 53 | In the case of supervised learning, it is assumed that a training dataset :math:`\{x^{(i)},y^{(i)}\}_{1 \le i \le I}` of :math:`I` pairs of ground truth images in :math:`\mathbb{R}^N` and measurements in :math:`\mathbb{R}^J` is available. :math:`\theta^*` is then obtained by solving 54 | 55 | .. math:: 56 | :label: eq_train 57 | 58 | \min_{\theta}\,{\sum_{i =1}^I \mathcal{L}\left(x^{(i)},\mathcal{R}_\theta(By^{(i)})\right)}, 59 | 60 | 61 | where :math:`\mathcal{L}` is the training loss (e.g., squared error). In the case where only ground truth images :math:`\{x^{(i)}\}_{1 \le i \le I}` are available, the associated measurements are simulated as :math:`y^{(i)} = \mathcal{N}(Ax^{(i)})`, :math:`1 \le i \le I`. 62 | 63 | 64 | Reconstruction operator 65 | ----------------------------------- 66 | A simple yet efficient method consists in correcting a traditional (e.g., linear) reconstruction by a data-driven nonlinear step 67 | 68 | .. math:: 69 | :label: eq_recon_direct 70 | 71 | \mathcal{R}_\theta = \mathcal{G}_\theta \circ \mathcal{R}, 72 | 73 | where :math:`\mathcal{R}\colon\mathbb{R}^{M}\to\mathbb{R}^N` is a traditional hand-crafted (e.g., regularized) reconstruction operator and :math:`\mathcal{G}_\theta\colon\mathbb{R}^{N}\to\mathbb{R}^N` is a nonlinear neural network that acts in the image domain. 74 | 75 | Algorithm unfolding consists in defining :math:`\mathcal{R}_\theta` from an iterative scheme 76 | 77 | .. math:: 78 | :label: eq_pgd_no_Gamma 79 | 80 | \mathcal{R}_\theta = \mathcal{R}_{\theta_K} \circ \cdots \circ \mathcal{R}_{\theta_1}, 81 | 82 | where :math:`\mathcal{R}_{\theta_k}` can be interpreted as the computation of the :math:`k`-th iteration of the iterative scheme and :math:`\theta = \bigcup_{k} \theta_k`.
83 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=67", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.setuptools] 9 | include-package-data = true 10 | zip-safe = false 11 | script-files = [ 12 | "tutorial/tuto_01_a_acquisition_operators.py", 13 | "tutorial/tuto_01_b_splitting.py", 14 | "tutorial/tuto_01_c_HadamSplit2d.py", 15 | "tutorial/tuto_02_noise.py", 16 | "tutorial/tuto_03_pseudoinverse_linear.py" 17 | ] 18 | 19 | [tool.setuptools.dynamic] 20 | readme = {file = ["README.md"]} 21 | 22 | [tool.setuptools.packages] 23 | find = {} # Scanning implicit namespaces is active by default 24 | 25 | [project] 26 | name = "spyrit" 27 | version = "3.0.1" 28 | dynamic = ["readme"] 29 | authors = [{name = "Nicolas Ducros", email = "Nicolas.Ducros@insa-lyon.fr"}] 30 | description = "Toolbox for deep image reconstruction" 31 | license = {file = "LICENSE.md"} 32 | classifiers = [ 33 | "Programming Language :: Python", 34 | "Programming Language :: Python :: 3.9", 35 | "Programming Language :: Python :: 3.10", 36 | "Programming Language :: Python :: 3.11", 37 | "Programming Language :: Python :: Implementation :: PyPy", 38 | "Operating System :: OS Independent", 39 | ] 40 | dependencies = [ 41 | "numpy", 42 | "matplotlib", 43 | "scipy", 44 | "torch", 45 | "torchvision", 46 | "Pillow", 47 | "PyWavelets", 48 | "wget", 49 | "sympy", 50 | "imageio", 51 | "astropy", 52 | "requests", 53 | "tqdm", 54 | "girder-client", 55 | ] 56 | requires-python = ">=3.9" 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | torch 5 | torchvision 6 | Pillow 7 | opencv-python 8 | imutils 9 | PyWavelets 10 | wget 11 | 
sympy 12 | imageio 13 | astropy 14 | tensorboard 15 | sphinx_gallery 16 | sphinx_rtd_theme 17 | girder-client 18 | gdown==v4.6.3 19 | -------------------------------------------------------------------------------- /spyrit/__init__.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | # from __future__ import division, print_function, absolute_import 8 | # from distutils.version import LooseVersion 9 | # 10 | # 11 | # from . import spyritest 12 | # 13 | # 14 | # __all__ = [s for s in dir() if not s.startswith('_')] 15 | 16 | # from . import core 17 | # from . import misc 18 | # from . import external 19 | -------------------------------------------------------------------------------- /spyrit/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core module for Spyrit package, containing the main classes and functions.""" 2 | -------------------------------------------------------------------------------- /spyrit/dev/meas.py: -------------------------------------------------------------------------------- 1 | # ================================================================================== 2 | class Linear_shift(Linear): 3 | # ================================================================================== 4 | r"""Linear with shifted pattern matrix of size :math:`(M+1,N)` and :math:`Perm` matrix of size :math:`(N,N)`.
5 | 6 | Args: 7 | - Hsub: subsampled Hadamard matrix 8 | - Perm: Permutation matrix 9 | 10 | Shape: 11 | - Input1: :math:`(M, N)` 12 | - Input2: :math:`(N, N)` 13 | 14 | Example: 15 | >>> Hsub = np.array(np.random.random([400,32*32])) 16 | >>> Perm = np.array(np.random.random([32*32,32*32])) 17 | >>> FO_Shift = Linear_shift(Hsub, Perm) 18 | 19 | """ 20 | 21 | def __init__(self, Hsub, Perm): 22 | super().__init__(Hsub) 23 | 24 | # Todo: Use index rather than permutation (see misc.walsh_hadamard) 25 | self.Perm = nn.Linear(self.N, self.N, False) 26 | self.Perm.weight.data = torch.from_numpy(Perm.T) 27 | self.Perm.weight.data = self.Perm.weight.data.float() 28 | self.Perm.weight.requires_grad = False 29 | 30 | H_shift = torch.cat((torch.ones((1, self.N)), (self.Hsub.weight.data + 1) / 2)) 31 | 32 | self.H_shift = nn.Linear(self.N, self.M + 1, False) 33 | self.H_shift.weight.data = H_shift # include the all-one pattern 34 | self.H_shift.weight.data = self.H_shift.weight.data.float() # cast to float32 (NOTE(review): original author questioned whether this cast is needed) 35 | self.H_shift.weight.requires_grad = False 36 | 37 | def forward(self, x: torch.tensor) -> torch.tensor: 38 | r"""Applies Linear transform such that :math:`y = \begin{bmatrix}{1}\\{(H_{sub}+1)/2}\end{bmatrix}x`, i.e. the all-one pattern followed by the shifted patterns. 39 | 40 | Args: 41 | :math:`x`: batch of images. 42 | 43 | Shape: 44 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image. 45 | - Output: :math:`(b*c, M+1)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M+1` the number of measurements + 1.
46 | 47 | Example: 48 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 49 | >>> y = FO_Shift(x) 50 | >>> print(y.shape) 51 | torch.Size([10, 401]) 52 | """ 53 | # input x is a set of images with shape (b*c, N) 54 | # output is a set of measurement vectors with shape (b*c, M+1) 55 | x = self.H_shift(x) 56 | return x 57 | 58 | # x_shift = super().forward(x) - x_dark.expand(x.shape[0],self.M) # (H-1/2)x 59 | 60 | 61 | # ================================================================================== 62 | class Linear_pos(Linear): 63 | # ================================================================================== 64 | r"""Linear with Permutation Matrix :math:`Perm` of size :math:`(N,N)`. 65 | 66 | Args: 67 | - Hsub: subsampled Hadamard matrix 68 | - Perm: Permutation matrix 69 | 70 | Shape: 71 | - Input1: :math:`(M, N)` 72 | - Input2: :math:`(N, N)` 73 | 74 | Example: 75 | >>> Hsub = np.array(np.random.random([400,32*32])) 76 | >>> Perm = np.array(np.random.random([32*32,32*32])) 77 | >>> meas_op_pos = Linear_pos(Hsub, Perm) 78 | """ 79 | 80 | def __init__(self, Hsub, Perm): 81 | super().__init__(Hsub) 82 | 83 | # Todo: Use index rather than permutation (see misc.walsh_hadamard) 84 | self.Perm = nn.Linear(self.N, self.N, False) 85 | self.Perm.weight.data = torch.from_numpy(Perm.T) 86 | self.Perm.weight.data = self.Perm.weight.data.float() 87 | self.Perm.weight.requires_grad = False 88 | 89 | def forward(self, x: torch.tensor) -> torch.tensor: 90 | r"""Computes :math:`y` according to :math:`y=0.5(H_{sub}x+\sum_{j=1}^{N}x_{j})` where :math:`j` is the pixel (column) index of :math:`x`. 91 | 92 | Args: 93 | :math:`x`: Batch of images. 94 | 95 | Shape: 96 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image. 97 | - Output: :math:`(b*c, M)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M` the number of measurements.
98 | 99 | Example: 100 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 101 | >>> y = meas_op_pos(x) 102 | >>> print(y.shape) 103 | torch.Size([10, 400]) 104 | """ 105 | # input x is a set of images with shape (b*c, N) 106 | # output is a set of measurement vectors with shape (b*c, M) 107 | 108 | # compute 1/2 (H+1) x = 1/2 Hx + 1/2 (sum of pixels), the pixel sum being broadcast to all M measurements 109 | x = super().forward(x) + x.sum(dim=1, keepdim=True).expand(-1, self.M) 110 | x *= 0.5 111 | 112 | return x 113 | 114 | 115 | # ================================================================================== 116 | class Linear_shift_had(Linear_shift): 117 | # ================================================================================== 118 | r"""Linear_shift operator with inverse method. 119 | 120 | Args: 121 | - Hsub: subsampled Hadamard matrix 122 | - Perm: Permuation matrix 123 | 124 | Shape: 125 | - Input1: :math:`(M, N)` 126 | - Input2: :math:`(N, N)`. 127 | 128 | Example: 129 | >>> Hsub = np.array(np.random.random([400,32*32])) 130 | >>> Perm = np.array(np.random.random([32*32,32*32])) 131 | >>> FO_Shift_Had = Linear_shift_had(Hsub, Perm) 132 | """ 133 | 134 | def __init__(self, Hsub, Perm): 135 | super().__init__(Hsub, Perm) 136 | 137 | def inverse(self, x: torch.tensor, n: Union[None, int] = None) -> torch.tensor: 138 | r"""Inverse transform such that :math:`x = \frac{1}{N}H_{sub}y`. 139 | 140 | Args: 141 | :math:`x`: Batch of completed measurements. 142 | 143 | Shape: 144 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of measurements. 145 | - Output: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of reconstructed. pixels.
146 | 147 | Example: 148 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float) 149 | >>> x_reconstruct = FO_Shift_Had.inverse(y_pad) 150 | >>> print(x_reconstruct.shape) 151 | torch.Size([10, 1024]) 152 | """ 153 | # rearrange the terms + inverse transform 154 | # maybe needs to be initialised with a permutation matrix as well! 155 | # Permutation matrix may be sparsified when sparse tensors are no longer in 156 | # beta (as of pytorch 1.11, it is still in beta). 157 | 158 | # --> Use index rather than permutation (see misc.walsh_hadamard) 159 | 160 | # input x is a set of **measurements** with shape (b*c, N) 161 | # output is a set of **images** with shape (b*c, N) 162 | bc, N = x.shape 163 | x = self.Perm(x) 164 | 165 | if n is None: 166 | n = int(np.sqrt(N)) 167 | 168 | # Inverse transform 169 | x = x.reshape(bc, 1, n, n) 170 | x = ( 171 | 1 / self.N * walsh2_torch(x) 172 | ) # todo: initialize with 1D transform to speed up 173 | x = x.reshape(bc, N) 174 | return x 175 | -------------------------------------------------------------------------------- /spyrit/dev/prep.py: -------------------------------------------------------------------------------- 1 | # ================================================================================== 2 | class Preprocess_shift_poisson(nn.Module): # header needs to be updated! 
3 | # ================================================================================== 4 | r"""Preprocess the measurements acquired using shifted patterns corrupted 5 | by Poisson noise 6 | 7 | Computes: 8 | m = (2 m_shift - m_offset)/N_0 9 | var = 4*Diag(m_shift + m_offset)/alpha**2 10 | Warning: dark measurement is assumed to be the 0-th entry of raw measurements 11 | 12 | Args: 13 | - :math:`alpha`: noise level 14 | - :math:`M`: number of measurements 15 | - :math:`N`: number of image pixels 16 | 17 | Shape: 18 | - Input1: scalar 19 | - Input2: scalar 20 | - Input3: scalar 21 | 22 | Example: 23 | >>> PSP = Preprocess_shift_poisson(9, 400, 32*32) 24 | """ 25 | 26 | def __init__(self, alpha, M, N): 27 | super().__init__() 28 | self.alpha = alpha 29 | self.N = N 30 | self.M = M 31 | 32 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor: 33 | r""" 34 | 35 | Warning: 36 | - The offset measurement is the 0-th entry of the raw measurements. 37 | 38 | Args: 39 | - :math:`x`: Batch of images in Hadamard domain shifted by 1 40 | - :math:`meas_op`: Forward_operator 41 | 42 | Shape: 43 | - Input: :math:`(b*c, M+1)` 44 | - Output: :math:`(b*c, M)` 45 | 46 | Example: 47 | >>> Hsub = np.array(np.random.random([400,32*32])) 48 | >>> FO = Forward_operator(Hsub) 49 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float) 50 | >>> y_PSP = PSP(x, FO) 51 | >>> print(y_PSP.shape) 52 | torch.Size([10, 400]) 53 | 54 | """ 55 | y = self.offset(x) 56 | x = 2 * x[:, 1:] - y.expand( 57 | x.shape[0], self.M 58 | ) # Warning: dark measurement is the 0-th entry 59 | x = x / self.alpha 60 | x = 2 * x - meas_op.H( 61 | torch.ones(x.shape[0], self.N).to(x.device) 62 | ) # to shift images in [-1,1]^N 63 | return x 64 | 65 | def sigma(self, x): 66 | r""" 67 | Args: 68 | - :math:`x`: Batch of images in Hadamard domain shifted by 1 69 | 70 | Shape: 71 | - Input: :math:`(b*c, M+1)` 72 | 73 | Example: 74 | >>> x = torch.tensor(np.random.random([10, 400+1]), 
dtype=torch.float) 75 | >>> sigma_PSP = PSP.sigma(x) 76 | >>> print(sigma_PSP.shape) 77 | torch.Size([10, 400]) 78 | """ 79 | # input x is a set of measurement vectors with shape (b*c, M+1) 80 | # output is a set of measurement vectors with shape (b*c,M) 81 | y = self.offset(x) 82 | x = 4 * x[:, 1:] + y.expand(x.shape[0], self.M) 83 | x = x / (self.alpha**2) 84 | x = 4 * x # to shift images in [-1,1]^N 85 | return x 86 | 87 | def cov(self, x): # return a full matrix ? It is such that Diag(a) + b 88 | return x 89 | 90 | def sigma_from_image(self, x, meas_op): # should check this! 91 | # input x is a set of images with shape (b*c, N) 92 | # input meas_op is a Forward_operator 93 | x = meas_op.H(x) 94 | y = self.offset(x) 95 | x = x[:, 1:] + y.expand(x.shape[0], self.M) 96 | x = x / (self.alpha) # here the alpha contribution is not squared. 97 | return x 98 | 99 | def offset(self, x): 100 | r"""Get offset component from bach of shifted images. 101 | 102 | Args: 103 | - :math:`x`: Batch of shifted images 104 | 105 | Shape: 106 | - Input: :math:`(bc, M+1)` 107 | - Output: :math:`(bc, 1)` 108 | 109 | Example: 110 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float) 111 | >>> y = PSP.offset(x) 112 | >>> print(y.shape) 113 | torch.Size([10, 1]) 114 | 115 | """ 116 | y = x[:, 0, None] 117 | return y 118 | 119 | 120 | # ================================================================================== 121 | class Preprocess_pos_poisson(nn.Module): # header needs to be updated! 122 | # ================================================================================== 123 | r"""Preprocess the measurements acquired using positive (shifted) patterns 124 | corrupted by Poisson noise 125 | 126 | The output value of the layer with input size :math:`(B*C, M)` can be 127 | described as: 128 | 129 | .. 
math:: 130 | \text{out}((B*C)_i, M_j}) = 2*\text{input}((B*C)_i, M_j}) - 131 | \sum_{k = 1}^{M-1} \text{input}((B*C)_i, M_k}) 132 | 133 | The output size of the layer is :math:`(B*C, M)`, which is the imput size 134 | 135 | 136 | Warning: 137 | dark measurement is assumed to be the 0-th entry of raw measurements 138 | 139 | Args: 140 | - :math:`alpha`: noise level 141 | - :math:`M`: number of measurements 142 | - :math:`N`: number of image pixels 143 | 144 | Shape: 145 | - Input1: scalar 146 | - Input2: scalar 147 | - Input3: scalar 148 | 149 | Example: 150 | >>> PPP = Preprocess_pos_poisson(9, 400, 32*32) 151 | 152 | """ 153 | 154 | def __init__(self, alpha, M, N): 155 | super().__init__() 156 | self.alpha = alpha 157 | self.N = N 158 | self.M = M 159 | 160 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor: 161 | r""" 162 | Args: 163 | - :math:`x`: noise level 164 | - :math:`meas_op`: Forward_operator 165 | 166 | Shape: 167 | - Input1: :math:`(bc, M)` 168 | - Input2: None 169 | - Output: :math:`(bc, M)` 170 | 171 | Example: 172 | >>> Hsub = np.array(np.random.random([400,32*32])) 173 | >>> meas_op = Forward_operator(Hsub) 174 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float) 175 | >>> y = PPP(x, meas_op) 176 | torch.Size([10, 400]) 177 | 178 | """ 179 | y = self.offset(x) 180 | x = 2 * x - y.expand(-1, self.M) 181 | x = x / self.alpha 182 | x = 2 * x - meas_op.H( 183 | torch.ones(x.shape[0], self.N).to(x.device) 184 | ) # to shift images in [-1,1]^N 185 | return x 186 | 187 | def offset(self, x): 188 | r"""Get offset component from bach of shifted images. 
189 | 190 | Args: 191 | - :math:`x`: Batch of shifted images 192 | 193 | Shape: 194 | - Input: :math:`(bc, M)` 195 | - Output: :math:`(bc, 1)` 196 | 197 | Example: 198 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float) 199 | >>> y = PPP.offset(x) 200 | >>> print(y.shape) 201 | torch.Size([10, 1]) 202 | 203 | """ 204 | y = 2 / (self.M - 2) * x[:, 1:].sum(dim=1, keepdim=True) 205 | return y 206 | -------------------------------------------------------------------------------- /spyrit/external/__init__.py: -------------------------------------------------------------------------------- 1 | """This module uses a modified version of the Unet presented in https://github.com/cszn/DPIR/blob/master/models/network_unet.py""" 2 | 3 | # from . import drunet 4 | -------------------------------------------------------------------------------- /spyrit/hadamard_matrix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/spyrit/hadamard_matrix/__init__.py -------------------------------------------------------------------------------- /spyrit/hadamard_matrix/create_hadamard_matrix_with_sage.py: -------------------------------------------------------------------------------- 1 | from sage.all import * 2 | from sage.combinat.matrices.hadamard_matrix import ( 3 | hadamard_matrix, 4 | skew_hadamard_matrix, 5 | is_hadamard_matrix, 6 | is_skew_hadamard_matrix, 7 | ) 8 | import numpy as np 9 | import glob 10 | 11 | # Get all Hadamard matrices of order 4*n for Sage 12 | # https://github.com/sagemath/sage/ 13 | # run in conda env with: 14 | # sage create_hadamard_matrix_with_sage.py 15 | 16 | k = Integer(2000) 17 | for n in range(Integer(1), k + Integer(1)): 18 | try: 19 | H = hadamard_matrix(Integer(4) * n, check=False) 20 | 21 | if is_hadamard_matrix(H): 22 | print(n * 4) 23 | a = np.array(H) 24 | a[a == -1] = 0 25 | a = a.astype(bool) 
26 | 27 | # find the files with that order 28 | files = glob.glob("had." + str(n * 4) + "*.npz") 29 | already_saved = False 30 | for file in files: 31 | b = np.load(file) 32 | if a == b: 33 | already_saved = True 34 | if already_saved: 35 | break 36 | 37 | if not already_saved: 38 | name = "had." + str(n * 4) + ".sage.npz" 39 | np.savez_compressed(name, a) 40 | except ValueError as e: 41 | pass 42 | -------------------------------------------------------------------------------- /spyrit/hadamard_matrix/download_hadamard_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import requests 3 | import os 4 | import glob 5 | import importlib.util 6 | import tqdm 7 | 8 | 9 | def download_from_girder(): 10 | """ 11 | Download Hadamard matrices from the Girder repository into hadamard_matrix folder. 12 | """ 13 | 14 | hadamard_matrix_path = os.path.dirname(__file__) 15 | if os.path.isfile( 16 | os.path.join(hadamard_matrix_path, "had.236.sage.cooper-wallis.npz") 17 | ): 18 | return 19 | print("Downloading Hadamard matrices (>2300) from Girder repository...") 20 | print( 21 | "The matrices were downloaded from http://neilsloane.com/hadamard/ Sloane et al." 
22 | ) 23 | import girder_client 24 | 25 | gc = girder_client.GirderClient( 26 | apiUrl="https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 27 | ) 28 | 29 | collection_id = "66796d3cbaa5a90007058946" 30 | folder_id = "6800c6891240141f6aa53845" 31 | limit = 50 # Number of items to retrieve per request 32 | offset = 0 # Starting point 33 | pbar = tqdm.tqdm(total=0) 34 | 35 | while True: 36 | items = gc.get( 37 | "item", 38 | parameters={ 39 | "parentType": "collection", 40 | "parentId": collection_id, 41 | "folderId": folder_id, 42 | "limit": limit, 43 | "offset": offset, 44 | }, 45 | ) 46 | if not items: 47 | break 48 | pbar.total += len(items) 49 | pbar.refresh() 50 | for item in items: 51 | files = gc.get(f'item/{item["_id"]}/files') 52 | for file in files: 53 | pbar.update(1) 54 | gc.downloadFile( 55 | file["_id"], os.path.join(hadamard_matrix_path, file["name"]) 56 | ) 57 | offset += limit 58 | pbar.close() 59 | 60 | 61 | def read_text_file_from_url(url): 62 | response = requests.get(url) 63 | content = response.text 64 | return content 65 | 66 | 67 | def download_from_sloane(): 68 | from selenium import webdriver 69 | from selenium.webdriver.common.by import By 70 | from selenium.webdriver.chrome.service import Service 71 | from webdriver_manager.chrome import ChromeDriverManager 72 | 73 | # Set up the WebDriver 74 | driver = webdriver.Chrome(service=Service(ChromeDriverManager().install())) 75 | 76 | # Open the website 77 | driver.get("http://neilsloane.com/hadamard/") 78 | 79 | # Find all links to Hadamard matrices 80 | links = driver.find_elements(By.XPATH, "//a[contains(@href, 'had.')]") 81 | 82 | # Extract the URLs 83 | hadamard_urls = set([link.get_attribute("href") for link in links]) 84 | 85 | # Print the URLs 86 | for url in hadamard_urls: 87 | print(url) 88 | # Read the text file from the URL 89 | file_content = read_text_file_from_url(url) 90 | # Split the content into lines 91 | lines = file_content.splitlines() 92 | 93 | # Print the content 
of the file 94 | if "+" in file_content or "0" in file_content or "-1" in file_content: 95 | if len(lines) > 1: 96 | size = len(lines[1]) 97 | else: 98 | size = len(lines[0]) 99 | array = [] 100 | for line in lines: 101 | if len(line) == size: 102 | line = line.replace("-1", "0") 103 | tmp = [] 104 | for e in line: 105 | if e == "+" or e == "1": 106 | tmp += [1] 107 | elif e == "-" or e == "0": 108 | tmp += [0] 109 | elif e == " ": 110 | pass 111 | else: 112 | print("Error during reading of " + url) 113 | array += [tmp] 114 | np_array = np.array(array, dtype=bool) 115 | 116 | name = url.split("/")[-1][:-4] 117 | order = int(name.split(".")[1]) 118 | 119 | # Check if the file already exists 120 | files = glob.glob("had." + str(order) + "*.npz") 121 | already_saved = False 122 | for file in files: 123 | b = np.load(file) 124 | if np.all(np_array == b): 125 | already_saved = True 126 | if already_saved: 127 | break 128 | 129 | if not already_saved: 130 | np.savez_compressed(name + ".npz", np_array) 131 | else: 132 | print("no ok for " + url) 133 | # print(file_content) 134 | 135 | # Close the WebDriver 136 | driver.quit() 137 | 138 | 139 | if __name__ == "__main__": 140 | download_from_sloane() 141 | -------------------------------------------------------------------------------- /spyrit/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | """Contains miscellaneous Numpy / Pytorch functions useful for spyrit.core.""" 8 | 9 | # from . import color 10 | # from . import data_visualisation 11 | # from . import disp 12 | # from . import examples 13 | # from . import matrix_tools 14 | # from . 
import metrics 15 | # from . import pattern_choice 16 | # from . import sampling 17 | # from . import statistics 18 | # from . import walsh_hadamard 19 | -------------------------------------------------------------------------------- /spyrit/misc/color.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Dec 2 20:53:59 2024 4 | 5 | @author: ducros 6 | """ 7 | import numpy as np 8 | import warnings 9 | from typing import Tuple 10 | 11 | import warnings 12 | from typing import Tuple 13 | 14 | import numpy as np 15 | 16 | 17 | # %% 18 | def wavelength_to_rgb( 19 | wavelength: float, gamma: float = 0.8 20 | ) -> Tuple[float, float, float]: 21 | """Converts wavelength to RGB. 22 | 23 | Based on https://gist.github.com/friendly/67a7df339aa999e2bcfcfec88311abfc. 24 | Itself based on code by Dan Bruton: 25 | http://www.physics.sfasu.edu/astro/color/spectra.html 26 | 27 | Args: 28 | wavelength (float): 29 | Single wavelength to be converted to RGB. 30 | gamma (float, optional): 31 | Gamma correction. Defaults to 0.8. 32 | 33 | Returns: 34 | Tuple[float, float, float]: 35 | RGB value. 
36 | """ 37 | 38 | if np.min(wavelength) < 380 or np.max(wavelength) > 750: 39 | warnings.warn("Some wavelengths are not in the visible range [380-750] nm") 40 | 41 | if wavelength >= 380 and wavelength <= 440: 42 | attenuation = 0.3 + 0.7 * (wavelength - 380) / (440 - 380) 43 | R = ((-(wavelength - 440) / (440 - 380)) * attenuation) ** gamma 44 | G = 0.0 45 | B = (1.0 * attenuation) ** gamma 46 | 47 | elif wavelength >= 440 and wavelength <= 490: 48 | R = 0.0 49 | G = ((wavelength - 440) / (490 - 440)) ** gamma 50 | B = 1.0 51 | 52 | elif wavelength >= 490 and wavelength <= 510: 53 | R = 0.0 54 | G = 1.0 55 | B = (-(wavelength - 510) / (510 - 490)) ** gamma 56 | 57 | elif wavelength >= 510 and wavelength <= 580: 58 | R = ((wavelength - 510) / (580 - 510)) ** gamma 59 | G = 1.0 60 | B = 0.0 61 | 62 | elif wavelength >= 580 and wavelength <= 645: 63 | R = 1.0 64 | G = (-(wavelength - 645) / (645 - 580)) ** gamma 65 | B = 0.0 66 | 67 | elif wavelength >= 645 and wavelength <= 750: 68 | attenuation = 0.3 + 0.7 * (750 - wavelength) / (750 - 645) 69 | R = (1.0 * attenuation) ** gamma 70 | G = 0.0 71 | B = 0.0 72 | 73 | else: 74 | R = 0.0 75 | G = 0.0 76 | B = 0.0 77 | 78 | return R, G, B 79 | 80 | 81 | def wavelength_to_rgb_mat(wav_range, gamma=1): 82 | 83 | rgb_mat = np.zeros((len(wav_range), 3)) 84 | 85 | for i, wav in enumerate(wav_range): 86 | rgb_mat[i, :] = wavelength_to_rgb(wav, gamma) 87 | 88 | return rgb_mat 89 | 90 | 91 | def spectral_colorization(M_gray, wav, axis=None): 92 | """ 93 | Colorize the last dimension of an array 94 | 95 | Args: 96 | M_gray (np.ndarray): Grayscale array where the last dimension is the 97 | spectral dimension. This is an A-by-C array, where A can indicate multiple 98 | dimensions (e.g., 4-by-3-by-7) and C is the number of spectral channels. 99 | 100 | wav (np.ndarray): Wavelenth. This is a 1D array of size C. 101 | 102 | axis (None or int or tuple of ints, optional): Axis or axes along which 103 | the grayscale input is normalized. 
By default, global normalization 104 | across all axes is considered. 105 | 106 | Returns: 107 | M_color (np.ndarray): Color array with an extra dimension. This is an A-by-C-by-3 array. 108 | 109 | """ 110 | 111 | # Normalize to adjust contrast 112 | M_gray_min = M_gray.min(keepdims=True, axis=axis) 113 | M_gray_max = M_gray.max(keepdims=True, axis=axis) 114 | M_gray = (M_gray - M_gray_min) / (M_gray_max - M_gray_min) 115 | 116 | # 117 | rgb_mat = wavelength_to_rgb_mat(wav, gamma=1) 118 | M_red = M_gray @ np.diag(rgb_mat[:, 0]) 119 | M_green = M_gray @ np.diag(rgb_mat[:, 1]) 120 | M_blue = M_gray @ np.diag(rgb_mat[:, 2]) 121 | 122 | M_color = np.stack((M_red, M_green, M_blue), axis=-1) 123 | 124 | return M_color 125 | 126 | 127 | def colorize(im, color, clip_percentile=0.1): 128 | """ 129 | Helper function to create an RGB image from a single-channel image using a 130 | specific color. 131 | """ 132 | # Check that we just have a 2D image 133 | if im.ndim > 2 and im.shape[2] != 1: 134 | raise ValueError("This function expects a single-channel image!") 135 | 136 | # Rescale the image according to how we want to display it 137 | im_scaled = im.astype(np.float32) - np.percentile(im, clip_percentile) 138 | im_scaled = im_scaled / np.percentile(im_scaled, 100 - clip_percentile) 139 | print( 140 | f"Norm: min={np.percentile(im, clip_percentile)}, max={np.percentile(im_scaled, 100 - clip_percentile)}" 141 | ) 142 | print(f"New: min={im_scaled.min()}, max={im_scaled.max()}") 143 | im_scaled = np.clip(im_scaled, 0, 1) 144 | 145 | # Need to make sure we have a channels dimension for the multiplication to work 146 | im_scaled = np.atleast_3d(im_scaled) 147 | 148 | # Reshape the color (here, we assume channels last) 149 | color = np.asarray(color).reshape((1, 1, -1)) 150 | return im_scaled * color 151 | -------------------------------------------------------------------------------- /spyrit/misc/data_visualisation.py: 
-------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | #!/usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | """ 10 | Created on Thu Jan 16 08:56:13 2020 11 | 12 | @author: crombez 13 | """ 14 | 15 | from astropy.io import fits 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | # Show basic information of a fits image acquired 20 | # with the andor zyla and plot the image 21 | def show_image_and_infos(path, file): 22 | hdul = fits.open(path + file) 23 | show_images_infos(path, file) 24 | plt.figure() 25 | plt.imshow(hdul[0].data[0]) 26 | plt.show() 27 | 28 | 29 | def show_images_infos(path, file): # Show basic information of a fits image acquired 30 | hdul = fits.open(path + file) 31 | print("***** Name file : " + file + " *****") 32 | print("Type de données : " + hdul[0].header["DATATYPE"]) 33 | print("Mode d'acquisition : " + hdul[0].header["ACQMODE"]) 34 | print("Temps d'exposition : " + str(hdul[0].header["EXPOSURE"])) 35 | print("Temps de lecture : " + str(hdul[0].header["READTIME"])) 36 | print("Longeur d'onde de Rayleigh : " + str(hdul[0].header["RAYWAVE"])) 37 | print("Longeur d'onde détectée : " + str(hdul[0].header["DTNWLGTH"])) 38 | print("***********************************" + "\n") 39 | 40 | 41 | # Plot the resulting fuction of to set of 1D data with the same dimension 42 | def simple_plot_2D( 43 | Lx, Ly, fig=None, title=None, xlabel=None, ylabel=None, style_color="b" 44 | ): 45 | plt.figure(fig) 46 | plt.clf() 47 | plt.title(title) 48 | plt.xlabel(xlabel) 49 | plt.ylabel(ylabel) 50 | plt.plot(Lx, Ly, style_color) 51 | plt.show() 52 | 53 | 54 | # Plot a 2D matrix 55 | def plot_im2D(Im, fig=None, title=None, 
xlabel=None, ylabel=None, cmap="viridis"): 56 | plt.figure(fig) 57 | plt.clf() 58 | plt.title(title) 59 | plt.xlabel(xlabel) 60 | plt.ylabel(ylabel) 61 | plt.imshow(Im, cmap=cmap) 62 | plt.colorbar() 63 | plt.show() 64 | -------------------------------------------------------------------------------- /spyrit/misc/disp.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | from PIL import Image 10 | import numpy as np 11 | from numpy import linalg as LA 12 | import time 13 | from scipy import signal 14 | from scipy import misc 15 | from scipy import sparse 16 | import torch 17 | import math 18 | 19 | 20 | def display_vid(video, fps, title="", colormap=plt.cm.gray): 21 | """ 22 | video is a numpy array of shape [nb_frames, 1, nx, ny] 23 | """ 24 | plt.ion() 25 | (nb_frames, channels, nx, ny) = video.shape 26 | fig = plt.figure() 27 | ax = fig.add_subplot(1, 1, 1) 28 | for i in range(nb_frames): 29 | current_frame = video[i, 0, :, :] 30 | plt.imshow(current_frame, cmap=colormap) 31 | plt.title(title) 32 | divider = make_axes_locatable(ax) 33 | cax = plt.axes([0.85, 0.1, 0.075, 0.8]) 34 | plt.colorbar(cax=cax) 35 | plt.show() 36 | plt.pause(fps) 37 | plt.ioff() 38 | 39 | 40 | def display_rgb_vid(video, fps, title=""): 41 | """ 42 | video is a numpy array of shape [nb_frames, 3, nx, ny] 43 | """ 44 | plt.ion() 45 | (nb_frames, channels, nx, ny) = video.shape 46 | fig = plt.figure() 47 | ax = fig.add_subplot(1, 1, 1) 48 | for i in range(nb_frames): 49 | current_frame = video[i, :, :, :] 50 | current_frame = 
np.moveaxis(current_frame, 0, -1) 51 | plt.imshow(current_frame) 52 | plt.title(title) 53 | plt.show() 54 | plt.pause(fps) 55 | plt.ioff() 56 | 57 | 58 | def fitPlots(N, aspect=(16, 9)): 59 | width = aspect[0] 60 | height = aspect[1] 61 | area = width * height * 1.0 62 | factor = (N / area) ** (1 / 2.0) 63 | cols = math.floor(width * factor) 64 | rows = math.floor(height * factor) 65 | rowFirst = width < height 66 | while rows * cols < N: 67 | if rowFirst: 68 | rows += 1 69 | else: 70 | cols += 1 71 | rowFirst = not (rowFirst) 72 | return rows, cols 73 | 74 | 75 | def Multi_plots( 76 | img_list, 77 | title_list, 78 | shape, 79 | suptitle="", 80 | colormap=plt.cm.gray, 81 | axis_off=True, 82 | aspect=(16, 9), 83 | savefig="", 84 | fontsize=14, 85 | ): 86 | [rows, cols] = shape 87 | plt.figure() 88 | plt.suptitle(suptitle, fontsize=16) 89 | if (len(img_list) < rows * cols) or (len(title_list) < rows * cols): 90 | for k in range(max(rows * cols - len(img_list), rows * cols - len(title_list))): 91 | img_list.append(np.zeros((64, 64))) 92 | title_list.append("") 93 | 94 | for k in range(rows * cols): 95 | ax = plt.subplot(rows, cols, k + 1) 96 | ax.imshow(img_list[k], cmap=colormap) 97 | ax.set_title(title_list[k], fontsize=fontsize) 98 | if axis_off: 99 | plt.axis("off") 100 | if savefig: 101 | plt.savefig(savefig, bbox_inches="tight") 102 | plt.show() 103 | 104 | 105 | def compare_video_frames( 106 | vid_list, 107 | nb_disp_frames, 108 | title_list, 109 | suptitle="", 110 | colormap=plt.cm.gray, 111 | aspect=(16, 9), 112 | savefig="", 113 | fontsize=14, 114 | ): 115 | rows = len(vid_list) 116 | cols = nb_disp_frames 117 | plt.figure(figsize=aspect) 118 | plt.suptitle(suptitle, fontsize=16) 119 | for i in range(rows): 120 | for j in range(cols): 121 | k = (j + 1) + (i) * (cols) 122 | i 123 | # print(k) 124 | ax = plt.subplot(rows, cols, k) 125 | # print("i = {}, j = {}".format(i,j)) 126 | ax.imshow(vid_list[i][0, j, 0, :, :], cmap=colormap) 127 | 
ax.set_title(title_list[i][j], fontsize=fontsize) 128 | plt.axis("off") 129 | if savefig: 130 | plt.savefig(savefig, bbox_inches="tight") 131 | plt.show() 132 | 133 | 134 | def torch2numpy(torch_tensor): 135 | return torch_tensor.cpu().detach().numpy() 136 | 137 | 138 | def uint8(dsp): 139 | x = (dsp - np.amin(dsp)) / (np.amax(dsp) - np.amin(dsp)) * 255 140 | x = x.astype("uint8") 141 | return x 142 | 143 | 144 | def imagesc( 145 | Img, 146 | title="", 147 | colormap=plt.cm.gray, 148 | show=True, 149 | figsize=None, 150 | cbar_pos=None, 151 | title_fontsize=16, 152 | ): 153 | """ 154 | imagesc(IMG) Display image Img with scaled colors with greyscale 155 | colormap and colorbar 156 | imagesc(IMG, title=ttl) Display image Img with scaled colors with 157 | greyscale colormap and colorbar, with the title ttl 158 | imagesc(IMG, title=ttl, colormap=cmap) Display image Img with scaled colors 159 | with colormap and colorbar specified by cmap (choose between 'plasma', 160 | 'jet', and 'grey'), with the title ttl 161 | """ 162 | fig = plt.figure(figsize=figsize) 163 | ax = fig.add_subplot(1, 1, 1) 164 | plt.imshow(Img, cmap=colormap) 165 | plt.title(title, fontsize=title_fontsize) 166 | divider = make_axes_locatable(ax) 167 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes 168 | 169 | if cbar_pos == "bottom": 170 | cax = inset_axes( 171 | ax, width="100%", height="5%", loc="lower center", borderpad=-5 172 | ) 173 | plt.colorbar(cax=cax, orientation="horizontal") 174 | else: 175 | cax = plt.axes([0.85, 0.1, 0.075, 0.8]) 176 | plt.colorbar(cax=cax, orientation="vertical") 177 | 178 | # fig.tight_layout() # it raises warnings in some cases 179 | if show is True: 180 | plt.show() 181 | 182 | 183 | def imagecomp( 184 | Img1, 185 | Img2, 186 | suptitle="", 187 | title1="", 188 | title2="", 189 | colormap1=plt.cm.gray, 190 | colormap2=plt.cm.gray, 191 | ): 192 | f, (ax1, ax2) = plt.subplots(1, 2) 193 | im1 = ax1.imshow(Img1, cmap=colormap1) 194 | ax1.set_title(title1) 
195 | cax = plt.axes([0.43, 0.3, 0.025, 0.4]) 196 | plt.colorbar(im1, cax=cax) 197 | plt.suptitle(suptitle, fontsize=16) 198 | # 199 | im2 = ax2.imshow(Img2, cmap=colormap2) 200 | ax2.set_title(title2) 201 | cax = plt.axes([0.915, 0.3, 0.025, 0.4]) 202 | plt.colorbar(im2, cax=cax) 203 | plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9) 204 | plt.show() 205 | 206 | 207 | def imagepanel( 208 | Img1, 209 | Img2, 210 | Img3, 211 | Img4, 212 | suptitle="", 213 | title1="", 214 | title2="", 215 | title3="", 216 | title4="", 217 | colormap1=plt.cm.gray, 218 | colormap2=plt.cm.gray, 219 | colormap3=plt.cm.gray, 220 | colormap4=plt.cm.gray, 221 | ): 222 | fig, axarr = plt.subplots(2, 2, figsize=(20, 10)) 223 | plt.suptitle(suptitle, fontsize=16) 224 | 225 | im1 = axarr[0, 0].imshow(Img1, cmap=colormap1) 226 | axarr[0, 0].set_title(title1) 227 | cax = plt.axes([0.4, 0.54, 0.025, 0.35]) 228 | plt.colorbar(im1, cax=cax) 229 | 230 | im2 = axarr[0, 1].imshow(Img2, cmap=colormap2) 231 | axarr[0, 1].set_title(title2) 232 | cax = plt.axes([0.90, 0.54, 0.025, 0.35]) 233 | plt.colorbar(im2, cax=cax) 234 | 235 | im3 = axarr[1, 0].imshow(Img3, cmap=colormap3) 236 | axarr[1, 0].set_title(title3) 237 | cax = plt.axes([0.4, 0.12, 0.025, 0.35]) 238 | plt.colorbar(im3, cax=cax) 239 | 240 | im4 = axarr[1, 1].imshow(Img4, cmap=colormap4) 241 | axarr[1, 1].set_title(title4) 242 | cax = plt.axes([0.9, 0.12, 0.025, 0.35]) 243 | plt.colorbar(im4, cax=cax) 244 | 245 | plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9) 246 | plt.show() 247 | 248 | 249 | def plot(x, y, title="", xlabel="", ylabel="", color="black"): 250 | fig = plt.figure() 251 | ax = fig.add_subplot(1, 1, 1) 252 | plt.plot(x, y, color=color) 253 | plt.title(title) 254 | plt.xlabel(xlabel) 255 | plt.ylabel(ylabel) 256 | plt.show() 257 | 258 | 259 | def add_colorbar(mappable, position="right", size="5%"): 260 | """ 261 | Example: 262 | f, axs = plt.subplots(1, 2) 263 | im = axs[0].imshow(img1, cmap='gray') 
264 | add_colorbar(im) 265 | im = axs[0].imshow(img2, cmap='gray') 266 | add_colorbar(im) 267 | """ 268 | if position == "bottom": 269 | orientation = "horizontal" 270 | else: 271 | orientation = "vertical" 272 | 273 | last_axes = plt.gca() 274 | ax = mappable.axes 275 | fig = ax.figure 276 | divider = make_axes_locatable(ax) 277 | cax = divider.append_axes(position, size="5%", pad=0.05) 278 | cbar = fig.colorbar(mappable, cax=cax, orientation=orientation) 279 | plt.sca(last_axes) 280 | return cbar 281 | 282 | 283 | def noaxis(axs): 284 | if type(axs) is np.ndarray: 285 | for ax in axs: 286 | ax.get_xaxis().set_visible(False) 287 | ax.get_yaxis().set_visible(False) 288 | else: 289 | axs.get_xaxis().set_visible(False) 290 | axs.get_yaxis().set_visible(False) 291 | 292 | 293 | def string_mean_std(x, prec=3): 294 | return "{:.{p}f} +/- {:.{p}f}".format(np.mean(x), np.std(x), p=prec) 295 | 296 | 297 | def print_mean_std(x, tag="", prec=3): 298 | print("{} = {:.{p}f} +/- {:.{p}f}".format(tag, np.mean(x), np.std(x), p=prec)) 299 | 300 | 301 | def histogram(s): 302 | count, bins, ignored = plt.hist(s, 30, density=True) 303 | plt.show() 304 | 305 | 306 | def vid2batch(root, img_dim, start_frame, end_frame): 307 | from imutils.video import FPS 308 | import imutils 309 | import cv2 310 | 311 | stream = cv2.VideoCapture(root) 312 | fps = FPS().start() 313 | frame_nb = 0 314 | output_batch = torch.zeros(1, end_frame - start_frame, 1, img_dim, img_dim) 315 | while True: 316 | (grabbed, frame) = stream.read() 317 | if not grabbed: 318 | break 319 | 320 | frame_nb += 1 321 | if (frame_nb >= start_frame) & (frame_nb < end_frame): 322 | frame = cv2.resize(frame, (img_dim, img_dim)) 323 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) 324 | output_batch[0, frame_nb - start_frame, 0, :, :] = torch.Tensor( 325 | frame[:, :, 1] 326 | ) 327 | 328 | return output_batch 329 | 330 | 331 | def pre_process_video(video, crop_patch, kernel_size): 332 | import cv2 333 | 334 | batch_size, 
seq_length, c, h, w = video.shape 335 | batched_frames = video.reshape(batch_size * seq_length * c, h, w) 336 | output_batch = torch.zeros(batched_frames.shape) 337 | 338 | for i in range(batch_size * seq_length * c): 339 | img = torch2numpy(batched_frames[i, :, :]) 340 | img[crop_patch] = 0 341 | median_frame = cv2.medianBlur(img, kernel_size) 342 | output_batch[i, :, :] = torch.Tensor(median_frame) 343 | output_batch = output_batch.reshape(batch_size, seq_length, c, h, w) 344 | return output_batch 345 | -------------------------------------------------------------------------------- /spyrit/misc/examples.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | import numpy as np 8 | 9 | 10 | def translation_matrix(img_size, nb_pixels): 11 | init_ind = np.reshape(np.arange(img_size**2), (img_size, img_size)) 12 | final_ind = np.zeros((img_size, img_size)) 13 | final_ind[:, : (img_size - nb_pixels)] = init_ind[:, nb_pixels:] 14 | final_ind[:, (img_size - nb_pixels) :] = init_ind[:, :nb_pixels] 15 | 16 | final_ind = np.reshape(final_ind, (img_size**2, 1)) 17 | init_ind = np.reshape(init_ind, (img_size**2, 1)) 18 | F = permutation_matrix(final_ind, init_ind) 19 | return F 20 | 21 | 22 | def permutation_matrix(A, B): 23 | N = A.shape[0] 24 | I = np.eye(N) 25 | P = np.zeros((N, N)) 26 | 27 | for i in range(N): 28 | pat = np.matlib.repmat(A[i, :], N, 1) 29 | ind = np.where(np.sum((pat == B), axis=1)) 30 | P[ind, :] = I[i, :] 31 | 32 | return P 33 | 34 | 35 | def circle(img_size, R, x_max): 36 | x = np.linspace(-x_max, x_max, img_size) 37 | X, Y = np.meshgrid(x, x) 38 | return 1.0 * (X**2 + Y**2 < R) 39 | 
-------------------------------------------------------------------------------- /spyrit/misc/load_data.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | #!/usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | """ 10 | Created on Wed Jan 15 17:06:19 2020 11 | 12 | @author: crombez 13 | """ 14 | 15 | import os 16 | import sys 17 | import glob 18 | import numpy as np 19 | import PIL 20 | from typing import Union 21 | 22 | 23 | def Files_names(Path, name_type): 24 | files = glob.glob(Path + name_type) 25 | print 26 | files.sort(key=os.path.getmtime) 27 | return [os.path.basename(x) for x in files] 28 | 29 | 30 | def load_data_recon_3D(Path_files, list_files, Nl, Nc, Nh): 31 | Data = np.zeros((Nl, Nc, Nh)) 32 | 33 | for i in range(0, 2 * Nh, 2): 34 | Data[:, :, i // 2] = np.rot90( 35 | np.array(PIL.Image.open(Path_files + list_files[i])) 36 | ) - np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1]))) 37 | 38 | return Data 39 | 40 | 41 | # Load the data of the hSPIM and compresse the spectrale dimensions to do the reconstruction for every lambda 42 | # odl convention the set of data has to be arranged in such way that the positive part of the hadamard motifs comes first 43 | def load_data_Comp_1D_old(Path_files, list_files, Nh, Nl, Nc): 44 | Data = np.zeros((Nl, Nh)) 45 | 46 | for i in range(0, 2 * Nh, 2): 47 | Data[:, i // 2] = Sum_coll( 48 | np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc 49 | ) - Sum_coll( 50 | np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), 51 | Nl, 52 | Nc, 53 | ) 54 | 55 | return Data 56 | 57 | 58 | # Load the data of the hSPIM and 
compresse the spectrale dimensions to do the reconstruction for every lambda 59 | # new convention the set of data has to be arranged in such way that the negative part of the hadamard motifs comes first 60 | def load_data_Comp_1D_new(Path_files, list_files, Nh, Nl, Nc): 61 | Data = np.zeros((Nl, Nh)) 62 | 63 | for i in range(0, 2 * Nh, 2): 64 | Data[:, i // 2] = Sum_coll( 65 | np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), 66 | Nl, 67 | Nc, 68 | ) - Sum_coll( 69 | np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc 70 | ) 71 | 72 | return Data 73 | 74 | 75 | def download_girder( 76 | server_url: str, 77 | hex_ids: Union[str, list[str]], 78 | local_folder: str, 79 | file_names: Union[str, list[str]] = None, 80 | ): 81 | """ 82 | Downloads data from a Girder server and saves it locally. 83 | 84 | This function first creates the local folder if it does not exist. Then, it 85 | connects to the Girder server and gets the file names for the files 86 | whose name are not provided. For each file, it checks if it already exists 87 | by checking if the file name is already in the local folder. If not, it 88 | downloads the file. 89 | 90 | Args: 91 | server_url (str): The URL of the Girder server. 92 | 93 | hex_id (str or list[str]): The hexadecimal id of the file(s) to download. 94 | If a list is provided, the files are downloaded in the same order and 95 | are saved in the same folder. 96 | 97 | local_folder (str): The path to the local folder where the files will 98 | be saved. If it does not exist, it will be created. 99 | 100 | file_name (str or list[str], optional): The name of the file(s) to save. 101 | If a list is provided, it must have the same length as hex_id. Each 102 | element equal to `None` will be replaced by the name of the file on the 103 | server. If None, all the names will be obtained from the server. 104 | Default is None. All names include the extension. 
105 | 106 | Raises: 107 | ValueError: If the number of file names provided does not match the 108 | number of files to download. 109 | 110 | Returns: 111 | list[str]: The absolute paths to the downloaded files. 112 | """ 113 | # leave import in function, so that the module can be used without 114 | # girder_client 115 | import girder_client 116 | 117 | # check the local folder exists 118 | if not os.path.exists(local_folder): 119 | print("Local folder not found, creating it... ", end="") 120 | os.makedirs(local_folder) 121 | print("done.") 122 | 123 | # connect to the server 124 | gc = girder_client.GirderClient(apiUrl=server_url) 125 | 126 | # create lists if strings are provided 127 | if type(hex_ids) is str: 128 | hex_ids = [hex_ids] 129 | if file_names is None: 130 | file_names = [None] * len(hex_ids) 131 | elif type(file_names) is str: 132 | file_names = [file_names] 133 | 134 | if len(file_names) != len(hex_ids): 135 | raise ValueError("There must be as many file names as hex ids.") 136 | 137 | abs_paths = [] 138 | 139 | # for each file, check if it exists and download if necessary 140 | for id, name in zip(hex_ids, file_names): 141 | 142 | if name is None: 143 | # get the file name 144 | name = gc.getFile(id)["name"] 145 | 146 | # check the file exists 147 | if not os.path.exists(os.path.join(local_folder, name)): 148 | # connect to the server to download the file 149 | print(f"Downloading {name}... ", end="\r") 150 | gc.downloadFile(id, os.path.join(local_folder, name)) 151 | print(f"Downloading {name}... 
done.") 152 | 153 | else: 154 | print("File already exists at", os.path.join(local_folder, name)) 155 | 156 | abs_paths.append(os.path.abspath(os.path.join(local_folder, name))) 157 | 158 | return abs_paths[0] if len(abs_paths) == 1 else abs_paths 159 | -------------------------------------------------------------------------------- /spyrit/misc/matrix_tools.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | #!/usr/bin/env python3 8 | # -*- coding: utf-8 -*- 9 | """ 10 | Created on Wed Jan 15 16:37:27 2020 11 | 12 | @author: crombez 13 | """ 14 | import warnings 15 | 16 | warnings.simplefilter("always", DeprecationWarning) 17 | 18 | import numpy as np 19 | 20 | import spyrit.misc.sampling as samp 21 | 22 | 23 | def Permutation_Matrix(mat): 24 | r""" 25 | Returns permutation matrix from sampling matrix 26 | 27 | Args: 28 | Mat (np.ndarray): 29 | N-by-N sampling matrix, where high values indicate high significance. 30 | 31 | Returns: 32 | P (np.ndarray): N^2-by-N^2 permutation matrix (boolean) 33 | 34 | .. warning:: 35 | This function is a duplicate of 36 | :func:`spyrit.misc.sampling.Permutation_Matrix` and will be removed 37 | in a future release. 38 | 39 | .. note:: 40 | Consider using :func:`sort_by_significance` for increased 41 | computational performance if using :func:`Permutation_Matrix` to 42 | reorder a matrix as follows: 43 | ``y = Permutation_Matrix(Ord) @ Mat`` 44 | """ 45 | warnings.warn( 46 | "\nspyrit.misc.matrix_tools.Permutation_Matrix is deprecated and will" 47 | + " be removed in a future release. 
Use\n" 48 | + "spyrit.misc.sampling.Permutation_Matrix instead.", 49 | DeprecationWarning, 50 | ) 51 | return samp.Permutation_Matrix(mat) 52 | 53 | 54 | def expend_vect(Vect, N1, N2): # Expened a vectors of siez N1 to N2 55 | V_out = np.zeros(N2) 56 | S = int(N2 / N1) 57 | j = 0 58 | ad = 0 59 | for i in range(N1): 60 | for j in range(0, S): 61 | V_out[i + j + ad] = Vect[i] 62 | ad += S - 1 63 | return V_out 64 | 65 | 66 | def data_conv_hadamard(H, Data, N): 67 | for i in range(N): 68 | H[:, :, i] = H[:, :, i] * Data 69 | return H 70 | 71 | 72 | def Sum_coll(Mat, N_lin, N_coll): # Return the sum of all the raw of the N1xN2 matrix 73 | Mturn = np.zeros(N_lin) 74 | 75 | for i in range(N_coll): 76 | Mturn += Mat[:, i] 77 | 78 | return Mturn 79 | 80 | 81 | def compression_1D( 82 | H, Nl, Nc, Nh 83 | ): # Compress a Matrix of N1xN2xN3 into a matrix of N1xN3 by summing the raw 84 | H_1D = np.zeros((Nl, Nh)) 85 | for i in range(Nh): 86 | H_1D[:, i] = Sum_coll(H[:, :, i], Nl, Nc) 87 | 88 | return H_1D 89 | 90 | 91 | def normalize_mat_2D(Mat): # Normalise a N1xN2 matrix by is maximum value 92 | Max = np.amax(Mat) 93 | return Mat * (1 / Max) 94 | 95 | 96 | def normalize_by_median_mat_2D(Mat): # Normalise a N1xN2 matrix by is median value 97 | Median = np.median(Mat) 98 | return Mat * (1 / Median) 99 | 100 | 101 | def remove_offset_mat_2D(Mat): # Substract the mean value of the matrix 102 | Mean = np.mean(Mat) 103 | return Mat - Mean 104 | 105 | 106 | def resize(Mat, Nl, Nc, Nh): # Re-size a matrix of N1xN2 into N1xN3 107 | Mres = np.zeros((Nl, Nc)) 108 | for i in range(Nl): 109 | Mres[i, :] = expend_vect(Mat[i, :], Nh, Nc) 110 | return Mres 111 | 112 | 113 | def stack_depth_matrice( 114 | Mat, Nl, Nc, Nd 115 | ): # Stack a 3 by 3 matrix along its third dimensions 116 | M_out = np.zeros((Nl, Nc)) 117 | for i in range(Nd): 118 | M_out += Mat[:, :, i] 119 | return M_out 120 | 121 | 122 | # fuction that need to be better difended 123 | 124 | 125 | def smooth(y, box_pts): # 
Smooth a vectors 126 | box = np.ones(box_pts) / box_pts 127 | y_smooth = np.convolve(y, box, mode="same") 128 | return y_smooth 129 | 130 | 131 | def reject_outliers(data, m=2): # Remove 132 | return np.where(abs(data - np.mean(data)) < m * np.std(data), data, 0) 133 | 134 | 135 | def clean_out(Data, Nl, Nc, Nh, m=2): 136 | Mout = np.zeros((Nl, Nc, Nh)) 137 | for i in range(Nh): 138 | Mout[:, :, i] = reject_outliers(Data[:, :, i], m) 139 | return Data 140 | -------------------------------------------------------------------------------- /spyrit/misc/metrics.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # This software is distributed under the terms 3 | # of the GNU Lesser General Public Licence (LGPL) 4 | # See LICENSE.md for further details 5 | # ----------------------------------------------------------------------------- 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim import lr_scheduler 11 | import numpy as np 12 | import torchvision 13 | from torchvision import datasets, models, transforms 14 | import torch.nn.functional as F 15 | import imageio 16 | import matplotlib.pyplot as plt 17 | 18 | # import skimage.metrics as skm 19 | 20 | 21 | def batch_psnr(torch_batch, output_batch): 22 | list_psnr = [] 23 | for i in range(torch_batch.shape[0]): 24 | img = torch_batch[i, 0, :, :] 25 | img_out = output_batch[i, 0, :, :] 26 | img = img.cpu().detach().numpy() 27 | img_out = img_out.cpu().detach().numpy() 28 | list_psnr.append(psnr(img, img_out)) 29 | return list_psnr 30 | 31 | 32 | def batch_psnr_(torch_batch, output_batch, r=2): 33 | list_psnr = [] 34 | for i in range(torch_batch.shape[0]): 35 | img = torch_batch[i, 0, :, :] 36 | img_out = output_batch[i, 0, :, :] 37 | img = img.cpu().detach().numpy() 38 | img_out = img_out.cpu().detach().numpy() 39 | list_psnr.append(psnr_(img, img_out, 
r=r)) 40 | return list_psnr 41 | 42 | 43 | def batch_ssim(torch_batch, output_batch): 44 | list_ssim = [] 45 | for i in range(torch_batch.shape[0]): 46 | img = torch_batch[i, 0, :, :] 47 | img_out = output_batch[i, 0, :, :] 48 | img = img.cpu().detach().numpy() 49 | img_out = img_out.cpu().detach().numpy() 50 | list_ssim.append(ssim(img, img_out)) 51 | return list_ssim 52 | 53 | 54 | def dataset_meas(dataloader, model, device): 55 | meas = [] 56 | for inputs, labels in dataloader: 57 | inputs = inputs.to(device) 58 | # with torch.no_grad(): 59 | b, c, h, w = inputs.shape 60 | net_output = model.acquire(inputs, b, c, h, w) 61 | raw = net_output[:, 0, :] 62 | raw = raw.cpu().detach().numpy() 63 | meas.extend(raw) 64 | return meas 65 | 66 | 67 | # 68 | # def dataset_psnr_different_measures(dataloader, model, model_2, device): 69 | # psnr = []; 70 | # #psnr_fc = []; 71 | # for inputs, labels in dataloader: 72 | # inputs = inputs.to(device) 73 | # m = model_2.normalized measure(inputs); 74 | # net_output = model.forward_reconstruct(inputs); 75 | # #net_output2 = model.evaluate_fcl(inputs); 76 | # 77 | # psnr += batch_psnr(inputs, net_output); 78 | # #psnr_fc += batch_psnr(inputs, net_output2); 79 | # psnr = np.array(psnr); 80 | # #psnr_fc = np.array(psnr_fc); 81 | # return psnr; 82 | # 83 | 84 | 85 | def dataset_psnr(dataloader, model, device): 86 | psnr = [] 87 | psnr_fc = [] 88 | for inputs, labels in dataloader: 89 | inputs = inputs.to(device) 90 | # with torch.no_grad(): 91 | # b,c,h,w = inputs.shape; 92 | 93 | net_output = model.evaluate(inputs) 94 | net_output2 = model.evaluate_fcl(inputs) 95 | 96 | psnr += batch_psnr(inputs, net_output) 97 | psnr_fc += batch_psnr(inputs, net_output2) 98 | psnr = np.array(psnr) 99 | psnr_fc = np.array(psnr_fc) 100 | return psnr, psnr_fc 101 | 102 | 103 | def dataset_ssim(dataloader, model, device): 104 | ssim = [] 105 | ssim_fc = [] 106 | for inputs, labels in dataloader: 107 | inputs = inputs.to(device) 108 | # evaluate full 
model and fully connected layer 109 | net_output = model.evaluate(inputs) 110 | net_output2 = model.evaluate_fcl(inputs) 111 | # compute SSIM and concatenate 112 | ssim += batch_ssim(inputs, net_output) 113 | ssim_fc += batch_ssim(inputs, net_output2) 114 | ssim = np.array(ssim) 115 | ssim_fc = np.array(ssim_fc) 116 | return ssim, ssim_fc 117 | 118 | 119 | def dataset_psnr_ssim(dataloader, model, device): 120 | # init lists 121 | psnr = [] 122 | ssim = [] 123 | # loop over batches 124 | for inputs, labels in dataloader: 125 | inputs = inputs.to(device) 126 | # evaluate full model 127 | net_output = model.evaluate(inputs) 128 | # compute PSNRs and concatenate 129 | psnr += batch_psnr(inputs, net_output) 130 | # compute SSIMs and concatenate 131 | ssim += batch_ssim(inputs, net_output) 132 | # convert 133 | psnr = np.array(psnr) 134 | ssim = np.array(ssim) 135 | return psnr, ssim 136 | 137 | 138 | def dataset_psnr_ssim_fcl(dataloader, model, device): 139 | # init lists 140 | psnr = [] 141 | ssim = [] 142 | # loop over batches 143 | for inputs, labels in dataloader: 144 | inputs = inputs.to(device) 145 | # evaluate fully connected layer 146 | net_output = model.evaluate_fcl(inputs) 147 | # compute PSNRs and concatenate 148 | psnr += batch_psnr(inputs, net_output) 149 | # compute SSIMs and concatenate 150 | ssim += batch_ssim(inputs, net_output) 151 | # convert 152 | psnr = np.array(psnr) 153 | ssim = np.array(ssim) 154 | return psnr, ssim 155 | 156 | 157 | def psnr(I1, I2): 158 | """ 159 | Computes the psnr between two images I1 and I2 160 | """ 161 | d = np.amax(I1) - np.amin(I1) 162 | diff = np.square(I2 - I1) 163 | MSE = diff.sum() / I1.size 164 | Psnr = 10 * np.log(d**2 / MSE) / np.log(10) 165 | return Psnr 166 | 167 | 168 | def psnr_(img1, img2, r=2): 169 | """ 170 | Computes the psnr between two image with values expected in a given range 171 | 172 | Args: 173 | img1, img2 (np.ndarray): images 174 | r (float): image range 175 | 176 | Returns: 177 | Psnr (float): 
Peak signal-to-noise ratio 178 | 179 | """ 180 | MSE = np.mean((img1 - img2) ** 2) 181 | Psnr = 10 * np.log(r**2 / MSE) / np.log(10) 182 | return Psnr 183 | 184 | 185 | def psnr_torch(img_gt, img_rec, mask=None, dim=(-2, -1), img_dyn=None): 186 | r""" 187 | Computes the Peak Signal-to-Noise Ratio (PSNR) between two images. 188 | 189 | .. math:: 190 | 191 | \text{PSNR} = 20 \, \log_{10} \left( \frac{\text{d}}{\sqrt{\text{MSE}}} \right), \\ 192 | \text{MSE} = \frac{1}{L}\sum_{\ell=1}^L \|I_\ell - \tilde{I}_\ell\|^2_2, 193 | 194 | where :math:`d` is the image dynamic and :math:`\{I_\ell\}` (resp. :math:`\{\tilde{I}_\ell\}`) is the set of ground truth (resp. reconstructed) images. 195 | 196 | Args: 197 | :attr:`img_gt`: Tensor containing the *ground-truth* image. 198 | 199 | :attr:`img_rec`: Tensor containing the reconstructed image. 200 | 201 | :attr:`mask`: Mask where the squared error is computed. Defaults :attr:`None`, i.e., no mask is considered. 202 | 203 | :attr:`dim`: Dimensions where the squared error is computed. If mask is :attr:`None`, defaults to :attr:`-1` (i.e., the last dimension). Othewise defaults to :attr:`(-2,-1)` (i.e., the last two dimensions). 204 | 205 | :attr:`img_dyn`: Image dynamic range (e.g., 1.0 for normalized images, 255 for 8-bit images). When :attr:`img_dyn` is :attr:`None`, the dynamic range is computed from the ground-truth image. 206 | 207 | Returns: 208 | PSNR value. 209 | 210 | .. note:: 211 | :attr:`psnr_torch(img_gt, img_rec)` is different from :attr:`psnr_torch(img_rec, img_gt)`. The first expression assumes :attr:`img_gt` is the ground truth while the second assumes that this is :attr:`img_rec`. This leads to different dynamic ranges. 
212 | 213 | Example 1: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise 214 | >>> x = torch.rand(10,1,64,64) 215 | >>> n = x + 0.05*torch.randn(x.shape) 216 | >>> out = psnr_torch(x,n) 217 | >>> print(out.shape) 218 | torch.Size([10, 1]) 219 | 220 | Example 2: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise 221 | >>> psnr_torch(n,x) 222 | tensor(...) 223 | >>> psnr_torch(x,n) 224 | tensor(...) 225 | >>> psnr_torch(n,x,img_dyn=1.0) 226 | tensor(...) 227 | 228 | """ 229 | if mask is not None: 230 | dim = -1 231 | img_gt = img_gt[mask > 0] 232 | img_rec = img_rec[mask > 0] 233 | print("mask") 234 | 235 | mse = (img_gt - img_rec) ** 2 236 | mse = torch.mean(mse, dim=dim) 237 | 238 | if img_dyn is None: 239 | img_dyn = torch.amax(img_gt, dim=dim) - torch.amin(img_gt, dim=dim) 240 | 241 | return 10 * torch.log10(img_dyn**2 / mse) 242 | 243 | 244 | def ssim(I1, I2): 245 | """ 246 | Computes the ssim between two images I1 and I2 247 | """ 248 | L = np.amax(I1) - np.amin(I1) 249 | mu1 = np.mean(I1) 250 | mu2 = np.mean(I2) 251 | s1 = np.std(I1) 252 | s2 = np.std(I2) 253 | s12 = np.mean(np.multiply((I1 - mu1), (I2 - mu2))) 254 | c1 = (0.01 * L) ** 2 255 | c2 = (0.03 * L) ** 2 256 | result = ((2 * mu1 * mu2 + c1) * (2 * s12 + c2)) / ( 257 | (mu1**2 + mu2**2 + c1) * (s1**2 + s2**2 + c2) 258 | ) 259 | return result 260 | 261 | 262 | # def ssim_sk(x_gt, x, img_dyn=None): 263 | # """ 264 | # SSIM from skimage 265 | 266 | # Args: 267 | # torch tensors 268 | 269 | # Returns: 270 | # torch tensor 271 | # """ 272 | # if not isinstance(x, np.ndarray): 273 | # x = x.cpu().detach().numpy().squeeze() 274 | # x_gt = x_gt.cpu().detach().numpy().squeeze() 275 | # ssim_val = np.zeros(x.shape[0]) 276 | # for i in range(x.shape[0]): 277 | # ssim_val[i] = skm.structural_similarity(x_gt[i], x[i], data_range=img_dyn) 278 | # return torch.tensor(ssim_val) 279 | 280 | 281 | def batch_psnr_vid(input_batch, output_batch): 282 | list_psnr = [] 283 | batch_size, 
seq_length, c, h, w = input_batch.shape 284 | input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w) 285 | output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w) 286 | for i in range(input_batch.shape[0]): 287 | img = input_batch[i, 0, :, :] 288 | img_out = output_batch[i, 0, :, :] 289 | img = img.cpu().detach().numpy() 290 | img_out = img_out.cpu().detach().numpy() 291 | list_psnr.append(psnr(img, img_out)) 292 | return list_psnr 293 | 294 | 295 | def batch_ssim_vid(input_batch, output_batch): 296 | list_ssim = [] 297 | batch_size, seq_length, c, h, w = input_batch.shape 298 | input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w) 299 | output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w) 300 | for i in range(input_batch.shape[0]): 301 | img = input_batch[i, 0, :, :] 302 | img_out = output_batch[i, 0, :, :] 303 | img = img.cpu().detach().numpy() 304 | img_out = img_out.cpu().detach().numpy() 305 | list_ssim.append(ssim(img, img_out)) 306 | return list_ssim 307 | 308 | 309 | def compare_video_nets_supervised(net_list, testloader, device): 310 | psnr = [[] for i in range(len(net_list))] 311 | ssim = [[] for i in range(len(net_list))] 312 | for batch, (inputs, labels) in enumerate(testloader): 313 | [batch_size, seq_length, c, h, w] = inputs.shape 314 | print("Batch :{}/{}".format(batch + 1, len(testloader))) 315 | inputs = inputs.to(device) 316 | labels = labels.to(device) 317 | with torch.no_grad(): 318 | for i in range(len(net_list)): 319 | outputs = net_list[i].evaluate(inputs) 320 | psnr[i] += batch_psnr_vid(labels, outputs) 321 | ssim[i] += batch_ssim_vid(labels, outputs) 322 | return psnr, ssim 323 | 324 | 325 | def compare_nets_unsupervised(net_list, testloader, device): 326 | psnr = [[] for i in range(len(net_list))] 327 | ssim = [[] for i in range(len(net_list))] 328 | for batch, (inputs, labels) in enumerate(testloader): 329 | [batch_size, seq_length, c, h, w] = inputs.shape 330 | 
print("Batch :{}/{}".format(batch + 1, len(testloader))) 331 | inputs = inputs.to(device) 332 | labels = labels.to(device) 333 | with torch.no_grad(): 334 | for i in range(len(net_list)): 335 | outputs = net_list[i].evaluate(inputs) 336 | psnr[i] += batch_psnr_vid(outputs, labels) 337 | ssim[i] += batch_ssim_vid(outputs, labels) 338 | return psnr, ssim 339 | 340 | 341 | def print_mean_std(x, tag=""): 342 | print("{}psnr = {} +/- {}".format(tag, np.mean(x), np.std(x))) 343 | -------------------------------------------------------------------------------- /tutorial/README.txt: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | This series of tutorials should guide you through the use of the SPyRiT pipeline. 5 | 6 | .. figure:: ../fig/direct_net.png 7 | :width: 600 8 | :align: center 9 | :alt: SPyRiT pipeline 10 | 11 | | 12 | 13 | Each tutorial focuses on a specific submodule of the full pipeline. 14 | 15 | * :ref:`Tutorial 1 `.a introduces the basics of measurement operators. 16 | 17 | * :ref:`Tutorial 1 `.b introduces the splitting of measurement operators. 18 | 19 | * :ref:`Tutorial 1 `.c introduces the 2d Hadamard transform with subsampling. 20 | 21 | * :ref:`Tutorial 2 ` introduces the noise operators. 22 | 23 | * :ref:`Tutorial 3 ` demonstrates pseudo-inverse reconstructions from Hadamard measurements. 24 | 25 | 26 | .. note:: 27 | 28 | The Python script (*.py*) or Jupyter notebook (*.ipynb*) corresponding to each tutorial can be downloaded at the bottom of the page. The images used in these files can be found on `GitHub`_. 29 | 30 | The tutorials below will gradually be updated to be compatible with SPyRiT 3 (work in progress, in the meantime see SPyRiT `2.4.0`_). 
31 | 32 | * :ref:`Tutorial 3 ` uses a CNN to denoise the image if necessary. 33 | 34 | * :ref:`Tutorial 4 ` trains the CNN introduced in Tutorial 3. 35 | 36 | * :ref:`Tutorial 5 ` introduces a new type of measurement operator ('split') that simulates positive and negative measurements. 37 | 38 | * :ref:`Tutorial 6 ` uses a Denoised Completion Network with a trainable image denoiser to improve the results obtained in Tutorial 5. 39 | 40 | * :ref:`Tutorial 7 ` shows how to perform image reconstruction using a pretrained plug-and-play denoising network. 41 | 42 | * :ref:`Tutorial 8 ` shows how to perform image reconstruction using a learnt proximal gradient descent. 43 | 44 | * :ref:`Tutorial 9 ` explains motion simulation from an image, dynamic measurements and reconstruction. 45 | 46 | * Explore the :ref:`Bonus Tutorial ` if you want to go deeper into SPyRiT's capabilities. 47 | 48 | 49 | .. _GitHub: https://github.com/openspyrit/spyrit/tree/3895b5e61fb6d522cff5e8b32a36da89b807b081/tutorial/images/test 50 | 51 | ..
_2.4.0: https://spyrit.readthedocs.io/en/2.4.0/gallery/index.html 52 | -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000001.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000001.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000002.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000002.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000003.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000003.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000004.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000004.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000005.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000005.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000006.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000006.jpeg -------------------------------------------------------------------------------- /tutorial/images/test/ILSVRC2012_test_00000007.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000007.jpeg -------------------------------------------------------------------------------- /tutorial/tuto_01_a_acquisition_operators.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 01.a. Acquisition operators (basic) 3 | ==================================================== 4 | .. _tuto_acquisition_operators: 5 | 6 | This tutorial shows how to simulate measurements using the :mod:`spyrit.core.meas` submodule. 7 | 8 | 9 | .. image:: ../fig/tuto1.png 10 | :width: 600 11 | :align: center 12 | :alt: Reconstruction architecture sketch 13 | 14 | | 15 | 16 | All simulations are based on the :class:`spyrit.core.meas.Linear` base class, which simulates linear measurements 17 | 18 | .. math:: 19 | m = Hx, 20 | 21 | where :math:`H\in\mathbb{R}^{M\times N}` is the acquisition matrix, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`M` is the number of measurements, and :math:`N` is the dimension of the signal. 22 | 23 | .. important:: 24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g., an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization, i.e., :math:`x = \texttt{vec}(X)`.
25 | 26 | """ 27 | 28 | # %% 29 | # 1D Measurements 30 | # ----------------------------------------------------------------------------- 31 | 32 | ############################################################################### 33 | # We instantiate a measurement operator from a matrix of shape (10, 15). 34 | import torch 35 | from spyrit.core.meas import Linear 36 | 37 | H = torch.randn(10, 15) 38 | meas_op = Linear(H) 39 | 40 | ############################################################################### 41 | # We consider 3 signals of length 15 42 | x = torch.randn(3, 15) 43 | 44 | ############################################################################### 45 | # We apply the operator to the batch of signals, which produces 3 measurement 46 | # vectors of length 10 47 | m = meas_op(x) 48 | print(m.shape) 49 | 50 | ############################################################################### 51 | # We now plot the matrix-vector products 52 | 53 | from spyrit.misc.disp import add_colorbar, noaxis 54 | import matplotlib.pyplot as plt 55 | 56 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 57 | axs[0].set_title("Forward matrix H") 58 | im = axs[0].imshow(H, cmap="gray") 59 | add_colorbar(im, "bottom") 60 | 61 | axs[1].set_title("Signals x") 62 | im = axs[1].imshow(x.T, cmap="gray") 63 | add_colorbar(im, "bottom") 64 | 65 | axs[2].set_title("Measurements m") 66 | im = axs[2].imshow(m.T, cmap="gray") 67 | add_colorbar(im, "bottom") 68 | 69 | noaxis(axs) 70 | # sphinx_gallery_thumbnail_number = 1 71 | 72 | # %% 73 | # 2D Measurements 74 | # ----------------------------------------------------------------------------- 75 | 76 | ############################################################################### 77 | # We load a batch of images from the :attr:`/images/` folder. Using the 78 | # :func:`transform_gray_norm` function with the :attr:`normalize=False` 79 | # argument returns images with values in (0,1).
80 | import os 81 | import torchvision 82 | from spyrit.misc.statistics import transform_gray_norm 83 | 84 | spyritPath = os.getcwd() 85 | imgs_path = os.path.join(spyritPath, "images/") 86 | 87 | # Grayscale images of size (32, 32), no normalization to keep values in (0,1) 88 | transform = transform_gray_norm(img_size=32, normalize=False) 89 | 90 | # Create dataset and loader (expects class folder :attr:'images/test/') 91 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 92 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 93 | 94 | x, _ = next(iter(dataloader)) 95 | 96 | ############################################################################### 97 | # We crop the batch to get image of shape (9, 25). 98 | x = x[:, :, :9, :25] 99 | print(f"Shape of input images: {x.shape}") 100 | 101 | ############################################################################### 102 | # We plot the second image. 103 | from spyrit.misc.disp import imagesc 104 | 105 | imagesc(x[1, 0, :, :], "Image X") 106 | 107 | 108 | ############################################################################### 109 | # We instantiate a measurement operator from a random matrix with shape (10, 9*25). To indicate that the operator works in 2D, we use the :attr:`meas_shape` argument. 110 | H = torch.randn(10, 9 * 25) 111 | meas_op = Linear(H, meas_shape=(9, 25)) 112 | 113 | ############################################################################### 114 | # We apply the operator to the batch of images, which produces a batch of measurement vectors of length 10. 115 | m = meas_op(x) 116 | print(m.shape) 117 | 118 | 119 | ############################################################################### 120 | # We now plot the matrix-vector products corresponding to the second image in the batch. 
121 | 122 | ############################################################################### 123 | # We first select the second image and the second measurement vector in the batch. 124 | x_plot = x[1, 0, :, :] 125 | m_plot = m[1] 126 | 127 | ############################################################################### 128 | # Then we vectorize the image to get a 1D array of length 9*25. 129 | x_plot = x_plot.reshape(1, -1) 130 | 131 | print(f"Vectorised image with shape: {x_plot.shape}") 132 | 133 | ############################################################################### 134 | # We finally plot the matrix-vector products :math:`m = H x = H \texttt{vec}(X)`. 135 | 136 | from spyrit.misc.disp import add_colorbar, noaxis 137 | import matplotlib.pyplot as plt 138 | 139 | f, axs = plt.subplots(1, 3) 140 | axs[0].set_title("Forward matrix H") 141 | im = axs[0].imshow(H, cmap="gray") 142 | # add_colorbar(im, "bottom") 143 | 144 | axs[1].set_title("x = vec(X)") 145 | im = axs[1].imshow(x_plot.mT, cmap="gray") 146 | # add_colorbar(im, "bottom") 147 | 148 | axs[2].set_title("Measurements m") 149 | im = axs[2].imshow(m_plot.mT, cmap="gray") 150 | # add_colorbar(im, "bottom") 151 | 152 | noaxis(axs) 153 | -------------------------------------------------------------------------------- /tutorial/tuto_01_b_splitting.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 01.b. Acquisition operators (splitting) 3 | ==================================================== 4 | .. _tuto_acquisition_operators_splitting: 5 | 6 | This tutorial shows how to simulate linear measurements by splitting an acquisition matrix :math:`H\in \mathbb{R}^{M\times N}` that contains negative values. It is based on the :class:`spyrit.core.meas.LinearSplit` class of the :mod:`spyrit.core.meas` submodule. 7 | 8 | 9 | ..
image:: ../fig/tuto1.png 10 | :width: 600 11 | :align: center 12 | :alt: Reconstruction architecture sketch 13 | 14 | | 15 | 16 | In practice, only positive values can be implemented using a digital micromirror device (DMD). Therefore, we acquire 17 | 18 | .. math:: 19 | y = Ax, 20 | 21 | where :math:`A \in \mathbb{R}_+^{2M\times N}` is the acquisition matrix that contains positive DMD patterns, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`2M` is the number of DMD patterns, and :math:`N` is the dimension of the signal. 22 | 23 | .. important:: 24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g., an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization, i.e., :math:`x = \texttt{vec}(X)`. 25 | 26 | Given a matrix :math:`H` with negative values, we define the positive DMD patterns :math:`A` from the positive and negative components of :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while the odd rows of :math:`A` contain the negative components of :math:`H`. 27 | 28 | .. math:: 29 | \begin{cases} 30 | A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\ 31 | A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H). 32 | \end{cases} 33 | 34 | """ 35 | 36 | # %% 37 | # Splitting in 1D 38 | # ----------------------------------------------------------------------------- 39 | 40 | ############################################################################### 41 | # We instantiate a measurement operator from a matrix of shape (10, 15). 42 | import torch 43 | from spyrit.core.meas import LinearSplit 44 | 45 | H = torch.randn(10, 15) 46 | meas_op = LinearSplit(H) 47 | 48 | ############################################################################### 49 | # We consider 3 signals of length 15.
50 | x = torch.randn(3, 15) 51 | 52 | ############################################################################### 53 | # We apply the operator to the batch of images, which produces 3 measurements 54 | # of length 10*2 = 20. 55 | y = meas_op(x) 56 | print(y.shape) 57 | 58 | ############################################################################### 59 | # .. note:: 60 | # The number of measurements is twice the number of rows of the matrix H that contains negative values. 61 | 62 | # %% 63 | # Illustration 64 | # ----------------------------------------------------------------------------- 65 | 66 | ############################################################################### 67 | # We plot the positive and negative components of H that are concatenated in the matrix A. 68 | 69 | A = meas_op.A 70 | H_pos = meas_op.A[::2, :] # Even rows 71 | H_neg = meas_op.A[1::2, :] # Odd rows 72 | 73 | from spyrit.misc.disp import add_colorbar, noaxis 74 | import matplotlib.pyplot as plt 75 | 76 | fig = plt.figure(figsize=(10, 5)) 77 | gs = fig.add_gridspec(2, 2) 78 | 79 | ax1 = fig.add_subplot(gs[:, 0]) 80 | ax2 = fig.add_subplot(gs[0, 1]) 81 | ax3 = fig.add_subplot(gs[1, 1]) 82 | 83 | ax1.set_title("Forward matrix A") 84 | im = ax1.imshow(A, cmap="gray") 85 | add_colorbar(im) 86 | 87 | ax2.set_title("Forward matrix H_pos") 88 | im = ax2.imshow(H_pos, cmap="gray") 89 | add_colorbar(im) 90 | 91 | ax3.set_title("Measurements H_neg") 92 | im = ax3.imshow(H_neg, cmap="gray") 93 | add_colorbar(im) 94 | 95 | noaxis(ax1) 96 | noaxis(ax2) 97 | noaxis(ax3) 98 | # sphinx_gallery_thumbnail_number = 1 99 | 100 | ############################################################################### 101 | # We can verify numerically that H = H_pos - H_neg 102 | 103 | H = meas_op.H 104 | diff = torch.linalg.norm(H - (H_pos - H_neg)) 105 | 106 | print(f"|| H - (H_pos - H_neg) || = {diff}") 107 | 108 | ############################################################################### 109 | # We 
now plot the matrix-vector products between A and x. 110 | 111 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 112 | axs[0].set_title("Forward matrix A") 113 | im = axs[0].imshow(A, cmap="gray") 114 | add_colorbar(im, "bottom") 115 | 116 | axs[1].set_title("Signals x") 117 | im = axs[1].imshow(x.T, cmap="gray") 118 | add_colorbar(im, "bottom") 119 | 120 | axs[2].set_title("Split measurements y") 121 | im = axs[2].imshow(y.T, cmap="gray") 122 | add_colorbar(im, "bottom") 123 | 124 | noaxis(axs) 125 | 126 | # %% 127 | # Simulations with noise and using the matrix H 128 | # -------------------------------------------------------------------- 129 | 130 | ###################################################################### 131 | # The operators in the :mod:`spyrit.core.meas` submodule allow for simulating noisy measurements 132 | # 133 | # .. math:: 134 | # y =\mathcal{N}\left(Ax\right), 135 | # 136 | # where :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` represents a noise operator (e.g., Gaussian). By default, no noise is applied to the measurement, i.e., :math:`\mathcal{N}` is the identity. We can consider noise by setting the :attr:`noise_model` attribute of the :class:`spyrit.core.meas.LinearSplit` class. 137 | 138 | ##################################################################### 139 | # For instance, we can consider additive Gaussian noise with standard deviation 2. 140 | 141 | from spyrit.core.noise import Gaussian 142 | 143 | meas_op.noise_model = Gaussian(2) 144 | 145 | ##################################################################### 146 | # .. note:: 147 | # To learn more about noise models, please refer to :ref:`tutorial 2 `. 
148 | 149 | ##################################################################### 150 | # We simulate the noisy measurement vectors 151 | y_noise = meas_op(x) 152 | 153 | ##################################################################### 154 | # Noiseless measurements can be simulated using the :meth:`spyrit.core.meas.LinearSplit.measure` method. 155 | y_nonoise = meas_op.measure(x) 156 | 157 | ##################################################################### 158 | # The :meth:`spyrit.core.meas.LinearSplit.measure_H` method simulates noiseless measurements using the matrix H, i.e., :math:`m = Hx`. 159 | m_nonoise = meas_op.measure_H(x) 160 | 161 | ##################################################################### 162 | # We now plot the noisy and noiseless measurements 163 | f, axs = plt.subplots(1, 3, figsize=(8, 5)) 164 | axs[0].set_title("Split measurements y \n with noise") 165 | im = axs[0].imshow(y_noise.mT, cmap="gray") 166 | add_colorbar(im) 167 | 168 | axs[1].set_title("Split measurements y \n without noise") 169 | im = axs[1].imshow(y_nonoise.mT, cmap="gray") 170 | add_colorbar(im) 171 | 172 | axs[2].set_title("Measurements m \n without noise") 173 | im = axs[2].imshow(m_nonoise.mT, cmap="gray") 174 | add_colorbar(im) 175 | 176 | noaxis(axs) 177 | -------------------------------------------------------------------------------- /tutorial/tuto_01_c_HadamSplit2d.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 01.c. Acquisition operators (HadamSplit2d) 3 | ==================================================== 4 | .. _tuto_acquisition_operators_HadamSplit2d: 5 | 6 | This tutorial shows how to simulate measurements that correspond to the 2D Hadamard transform of an image. It is based on the :class:`spyrit.core.meas.HadamSplit2d` class of the :mod:`spyrit.core.meas` submodule. 7 | 8 | 9 | ..
image:: ../fig/tuto1.png 10 | :width: 600 11 | :align: center 12 | :alt: Reconstruction architecture sketch 13 | 14 | | 15 | 16 | In practice, only positive values can be implemented using a digital micromirror device (DMD). Therefore, we acquire 17 | 18 | .. math:: 19 | y = \texttt{vec}\left(AXA^T\right), 20 | 21 | where :math:`A \in \mathbb{R}_+^{2h\times h}` is the acquisition matrix that contains the positive and negative components of a Hadamard matrix and :math:`X \in \mathbb{R}^{h\times h}` is the (2D) image. 22 | 23 | We define the positive DMD patterns :math:`A` from the positive and negative components of a Hadamard matrix :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while the odd rows of :math:`A` contain the negative components of :math:`H`. 24 | 25 | .. math:: 26 | \begin{cases} 27 | A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\ 28 | A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H). 29 | \end{cases} 30 | 31 | """ 32 | 33 | # %% 34 | # Load images 35 | # ----------------------------------------------------------------------------- 36 | 37 | ############################################################################### 38 | # We load a batch of images from the :attr:`/images/` folder with values in (-1,1).
39 | import os 40 | import torchvision 41 | import torch.nn 42 | 43 | import matplotlib.pyplot as plt 44 | 45 | from spyrit.misc.disp import imagesc 46 | from spyrit.misc.statistics import transform_gray_norm 47 | 48 | spyritPath = os.getcwd() 49 | imgs_path = os.path.join(spyritPath, "images/") 50 | 51 | # Grayscale images of size 64 x 64, values in (-1,1) 52 | transform = transform_gray_norm(img_size=64) 53 | 54 | # Create dataset and loader (expects class folder 'images/test/') 55 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 56 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 57 | 58 | x, _ = next(iter(dataloader)) 59 | print(f"Ground-truth images: {x.shape}") 60 | 61 | ############################################################################### 62 | # We select the second image in the batch and plot it. 63 | 64 | i_plot = 1 65 | imagesc(x[i_plot, 0, :, :], r"$64\times 64$ image $X$") 66 | 67 | # %% 68 | # Basic example 69 | # ----------------------------------------------------------------------------- 70 | 71 | ###################################################################### 72 | # We instantiate a HadamSplit2d object and simulate the 2D Hadamard transform of the input images. As measurements are split, this produces vectors of size :math:`64 \times 64 \times 2 = 8192`. 73 | from spyrit.core.meas import HadamSplit2d 74 | 75 | meas_op = HadamSplit2d(64) 76 | y = meas_op(x) 77 | 78 | print(y.shape) 79 | 80 | ###################################################################### 81 | # As with :class:`spyrit.core.meas.LinearSplit`, the :meth:`spyrit.core.meas.HadamSplit2d.measure_H` method simulates measurements using the matrix :math:`H`, i.e., it computes :math:`m = \texttt{vec}\left(HXH^\top\right)`. This produces vectors of size :math:`64 \times 64 = 4096`.
82 | meas_op = HadamSplit2d(64) 83 | m = meas_op.measure_H(x) 84 | 85 | print(m.shape) 86 | 87 | ###################################################################### 88 | # We plot the components of the positive and negative Hadamard transform that are concatenated in the measurement vector :math:`y` as well as the measurement vector :math:`m`. 89 | 90 | from spyrit.misc.disp import add_colorbar, noaxis 91 | 92 | y_pos = y[:, :, 0::2] 93 | y_neg = y[:, :, 1::2] 94 | 95 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 96 | axs[0].set_title(r"$H_+XH_+^\top$") 97 | im = axs[0].imshow(y_pos[1, 0].reshape(64, 64), cmap="gray") 98 | add_colorbar(im, "bottom") 99 | 100 | axs[1].set_title(r"$H_-XH_-^\top$") 101 | im = axs[1].imshow( 102 | y_neg[ 103 | 1, 104 | 0, 105 | ].reshape(64, 64), 106 | cmap="gray", 107 | ) 108 | add_colorbar(im, "bottom") 109 | 110 | axs[2].set_title(r"$HXH^\top$") 111 | im = axs[2].imshow(m[1, 0].reshape(64, 64), cmap="gray") 112 | add_colorbar(im, "bottom") 113 | 114 | noaxis(axs) 115 | # sphinx_gallery_thumbnail_number = 2 116 | 117 | 118 | # %% 119 | # Subsampling 120 | # ---------------------------------------------------------------------- 121 | 122 | ###################################################################### 123 | # To reduce the acquisition time, only a few of the measurements can be acquired. In this case, we simulate: 124 | # 125 | # .. math:: 126 | # y = \mathcal{S}\left(AXA^T\right), 127 | # 128 | # where :math:`\mathcal{S} \colon\, \mathbb{R}^{2h\times 2h} \to \mathbb{R}^{2M}` is a subsampling operator and :math:`2M` represents the number of DMD patterns that are displayed on the DMD, with :math:`M < h^2`. 129 | 130 | ###################################################################### 131 | # The subsampling operator :math:`\mathcal{S}` is defined by an order matrix :math:`O\in\mathbb{R}^{h\times h}` that ranks the measurements by decreasing significance, before retaining only the first :math:`M`.
132 | 133 | ###################################################################### 134 | # .. note:: 135 | # This process applies to both :math:`H_{+}XH_{+}^T` and :math:`H_{-}XH_{-}^T` the same way, independently. 136 | # 137 | # We consider two subsampling strategies: 138 | # 139 | # * The "naive" subsampling, which uses the linear (row-major) indexing order. This is the default subsampling strategy. 140 | # 141 | # * The variance subsampling, which sorts the Hadamard coefficient by decreasing variance. The motivation is that low variance coefficients are less informative than the others. This can be supported by principal component analysis, which states that preserving the components with largest variance leads to the best linear predictor. 142 | 143 | # %% 144 | # Naive subsampling 145 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 146 | 147 | ############################################################################### 148 | # The order matrix corresponding to the "naive" subsampling is given by linear values. 149 | Ord_naive = torch.arange(64 * 64, 0, step=-1).reshape(64, 64) 150 | print(Ord_naive) 151 | 152 | 153 | # %% 154 | # Variance subsampling 155 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 156 | 157 | ###################################################################### 158 | # The order matrix corresponding is obtained by computing the variance of the Hadamard coefficients of the images belonging to the `ImageNet 2012 dataset `_. 159 | 160 | ###################################################################### 161 | # First, we download the *covariance* matrix from our warehouse. The covariance was computed from the ImageNet 2012 dataset and has a size of (64*64, 64*64). 
162 | 163 | from spyrit.misc.load_data import download_girder 164 | 165 | # url of the warehouse 166 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 167 | dataId = "672207cbf03a54733161e95d" # for reconstruction (imageNet, 64) 168 | data_folder = "./stat/" 169 | cov_name = "Cov_64x64.pt" 170 | 171 | # download the covariance matrix and get the file path 172 | file_abs_path = download_girder(url, dataId, data_folder, cov_name) 173 | 174 | try: 175 | # Load covariance matrix for "variance subsampling" 176 | Cov = torch.load(file_abs_path, weights_only=True) 177 | print(f"Cov matrix {cov_name} loaded") 178 | except: 179 | # Set to the identity if not found for "naive subsampling" 180 | Cov = torch.eye(64 * 64) 181 | print(f"Cov matrix {cov_name} not found! Set to the identity") 182 | 183 | ###################################################################### 184 | # Then, we extract the variance from the covariance matrix. The variance matrix has a size 185 | # of (64, 64). 186 | from spyrit.core.torch import Cov2Var 187 | 188 | Ord_variance = Cov2Var(Cov) 189 | 190 | ###################################################################### 191 | # .. note:: 192 | # In this tutorial, the covariance matrix is used to define the subsampling strategy. As explained in another tutorial, the covariance matrix can also be used to reconstruct the image from the measurements. 193 | 194 | # %% 195 | # Comparison of the two subsampling strategies 196 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 197 | 198 | ###################################################################### 199 | # We plot the masks corresponding to the two order matrices for a subsampling factor of 4, which corresponds to :math:`M = 64 \times 64 / 4 = 1024` measurements. 
200 | 201 | # sphinx_gallery_thumbnail_number = 2 202 | 203 | ###################################################################### 204 | # We build the masks using the function :func:`spyrit.core.torch.sort_by_significance` and reshape them to the image size. 205 | from spyrit.core.torch import sort_by_significance 206 | 207 | M = 64 * 64 // 4 208 | mask_basis = torch.zeros(64 * 64) 209 | mask_basis[:M] = 1 210 | 211 | # Mask for the naive subsampling 212 | mask_nai = sort_by_significance(mask_basis, Ord_naive, axis="cols") 213 | mask_nai = mask_nai.reshape(64, 64) 214 | 215 | # Mask for the variance subsampling 216 | mask_var = sort_by_significance(mask_basis, Ord_variance, axis="cols") 217 | mask_var = mask_var.reshape(64, 64) 218 | 219 | ###################################################################### 220 | # We finally plot the masks. 221 | f, ax = plt.subplots(1, 2, figsize=(10, 5)) 222 | im = ax[0].imshow(mask_nai, vmin=0, vmax=1) 223 | ax[0].set_title("Mask \n'naive subsampling'", fontsize=20) 224 | add_colorbar(im, "bottom", size="20%") 225 | 226 | im = ax[1].imshow(mask_var, vmin=0, vmax=1) 227 | ax[1].set_title("Mask \n'variance subsampling'", fontsize=20) 228 | add_colorbar(im, "bottom", size="20%") 229 | 230 | noaxis(ax) 231 | 232 | # %% 233 | # Measurements for accelerated acquisitions 234 | # -------------------------------------------------------------------- 235 | 236 | ###################################################################### 237 | # We instantiate two HadamSplit2d objects corresponding to the two subsampling strategies. By default, the HadamSplit2d object uses the "naive" subsampling strategy. 238 | meas_nai = HadamSplit2d(64, M=M) 239 | 240 | ###################################################################### 241 | # For the variance subsampling strategy, we specify the order matrix using the :attr:`order` attribute. 
242 | meas_var = HadamSplit2d(64, M=M, order=Ord_variance) 243 | 244 | ############################################################################### 245 | # We now simulate the measurements from both subsampling strategies. Here, we simulate measurements using the matrix :math:`H`, i.e., we compute :math:`m = HXH^\top`. This produces vectors of size :math:`M = 64 \times 64 / 4 = 1024`. 246 | 247 | m_nai = meas_nai.measure_H(x) 248 | m_var = meas_var.measure_H(x) 249 | 250 | print(f"Shape of measurement vectors: {m_nai.shape}") 251 | 252 | ############################################################################### 253 | # We transform the two measurement vectors as images in the Hadamard domain thanks to the function :meth:`spyrit.core.torch.meas2img`. 254 | 255 | from spyrit.core.torch import meas2img 256 | 257 | m_nai_plot = meas2img(m_nai, Ord_naive) 258 | m_var_plot = meas2img(m_var, Ord_variance) 259 | 260 | print(f"Shape of measurements: {m_nai_plot.shape}") 261 | 262 | ############################################################################### 263 | # We finally plot the measurements corresponding to one image in the batch. 264 | f, ax = plt.subplots(1, 2, figsize=(10, 5)) 265 | im = ax[0].imshow(m_nai_plot[i_plot, 0, :, :], cmap="gray") 266 | ax[0].set_title("Measurements \n 'Naive' subsampling", fontsize=20) 267 | add_colorbar(im, "bottom") 268 | 269 | im = ax[1].imshow(m_var_plot[i_plot, 0, :, :], cmap="gray") 270 | ax[1].set_title("Measurements \n Variance subsampling", fontsize=20) 271 | add_colorbar(im, "bottom") 272 | 273 | noaxis(ax) 274 | 275 | ############################################################################### 276 | # We can also simulate the split measurements, i.e., the measurement obtained from the positive and negative components of the Hadamard transform. This produces vectors of size :math:`2 M = 2 \times 64 \times 64 / 4 = 2048`. 
277 | 278 | y_var = meas_var(x) 279 | print(f"Shape of split measurements: {y_var.shape}") 280 | 281 | 282 | ############################################################################### 283 | # We separate the positive and negative components of the split measurements. 284 | y_var_pos = y_var[..., ::2] # Even rows 285 | y_var_neg = y_var[..., 1::2] # Odd rows 286 | 287 | print(f"Shape of the positive component: {y_var_pos.shape}") 288 | print(f"Shape of the negative component: {y_var_neg.shape}") 289 | 290 | ############################################################################### 291 | # We now send the measurement vectors to Hadamard domain to plot them as images. 292 | m_plot_1 = meas2img(y_var_pos, Ord_variance) 293 | m_plot_2 = meas2img(y_var_neg, Ord_variance) 294 | 295 | print(f"Shape of the positive component: {m_plot_1.shape}") 296 | print(f"Shape of the negative component: {m_plot_2.shape}") 297 | 298 | ############################################################################### 299 | # We finally plot the measurements corresponding to one image in the batch 300 | f, ax = plt.subplots(1, 2, figsize=(10, 5)) 301 | im = ax[0].imshow(m_plot_1[i_plot, 0, :, :], cmap="gray") 302 | ax[0].set_title(r"$\mathcal{S}\left(H_+XH_+^\top\right)$", fontsize=20) 303 | add_colorbar(im, "bottom") 304 | 305 | im = ax[1].imshow(m_plot_2[i_plot, 0, :, :], cmap="gray") 306 | ax[1].set_title(r"$\mathcal{S}\left(H_-XH_-^\top\right)$", fontsize=20) 307 | add_colorbar(im, "bottom") 308 | 309 | noaxis(ax) 310 | -------------------------------------------------------------------------------- /tutorial/tuto_02_noise.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 02. Noise operators 3 | =================================================== 4 | .. _tuto_noise: 5 | 6 | This tutorial shows how to use noise operators using the :mod:`spyrit.core.noise` submodule. 7 | 8 | .. 
image:: ../fig/tuto2.png 9 | :width: 600 10 | :align: center 11 | :alt: Reconstruction architecture sketch 12 | 13 | | 14 | """ 15 | 16 | # %% 17 | # Load a batch of images 18 | # ----------------------------------------------------------------------------- 19 | 20 | ############################################################################### 21 | # We load a batch of images from the `/images/` folder. Using the 22 | # :func:`transform_gray_norm` function with the :attr:`normalize=False` 23 | # argument returns images with values in (0,1). 24 | import os 25 | 26 | import torch 27 | import torchvision 28 | import matplotlib.pyplot as plt 29 | 30 | from spyrit.misc.disp import imagesc 31 | from spyrit.misc.statistics import transform_gray_norm 32 | 33 | spyritPath = os.getcwd() 34 | imgs_path = os.path.join(spyritPath, "images/") 35 | 36 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1) 37 | transform = transform_gray_norm(img_size=64, normalize=False) 38 | 39 | # Create dataset and loader (expects class folder 'images/test/') 40 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 42 | 43 | x, _ = next(iter(dataloader)) 44 | print(f"Shape of input images: {x.shape}") 45 | 46 | 47 | ############################################################################### 48 | # We select the first image in the batch and plot it. 49 | 50 | i_plot = 1 51 | imagesc(x[i_plot, 0, :, :], r"$x$ in (0, 1)") 52 | 53 | 54 | # %% 55 | # Gaussian noise 56 | # ----------------------------------------------------------------------------- 57 | 58 | ############################################################################### 59 | # We consider additive Gaussiane noise, 60 | # 61 | # .. 
math:: 62 | # y \sim z + \mathcal{N}(0,\sigma^2), 63 | # 64 | # where :math:`\mathcal{N}(\mu, \sigma^2)` is a Gaussian distribution with mean :math:`\mu` and variance :math:`\sigma^2`, and :math:`z` is the noiseless image. The larger :math:`\sigma`, the lower the signal-to-noise ratio. 65 | 66 | ############################################################################### 67 | # To add 10% Gaussian noise, we instantiate a :class:`spyrit.core.noise` 68 | # operator with :attr:`sigma=0.1`. 69 | 70 | from spyrit.core.noise import Gaussian 71 | 72 | noise_op = Gaussian(sigma=0.1) 73 | x_noisy = noise_op(x) 74 | 75 | imagesc(x_noisy[1, 0, :, :], r"10% Gaussian noise") 76 | # sphinx_gallery_thumbnail_number = 2 77 | 78 | ############################################################################### 79 | # To add 2% Gaussian noise, we update the class attribute :attr:`sigma`. 80 | 81 | noise_op.sigma = 0.02 82 | x_noisy = noise_op(x) 83 | 84 | imagesc(x_noisy[1, 0, :, :], r"2% Gaussian noise") 85 | 86 | # %% 87 | # Poisson noise 88 | # ----------------------------------------------------------------------------- 89 | 90 | ############################################################################### 91 | # We now consider Poisson noise, 92 | # 93 | # .. math:: 94 | # y \sim \mathcal{P}(\alpha z), \quad z \ge 0, 95 | # 96 | # where :math:`\alpha \ge 0` is a scalar value that represents the maximum 97 | # image intensity (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio. 98 | 99 | ############################################################################### 100 | # We consider the :class:`spyrit.core.noise.Poisson` class and set :math:`\alpha` 101 | # to 100 photons (which corresponds to the Poisson parameter). 
102 | 103 | from spyrit.core.noise import Poisson 104 | from spyrit.misc.disp import add_colorbar, noaxis 105 | 106 | alpha = 100 # number of photons 107 | noise_op = Poisson(alpha) 108 | 109 | ############################################################################### 110 | # We simulate two noisy versions of the same images 111 | 112 | y1 = noise_op(x) # first sample 113 | y2 = noise_op(x) # another sample 114 | 115 | ############################################################################### 116 | # We now consider the case :math:`\alpha = 1000` photons. 117 | 118 | noise_op.alpha = 1000 119 | y3 = noise_op(x) # noisy measurement vector 120 | 121 | ############################################################################### 122 | # We finally plot the noisy images 123 | 124 | # plot 125 | f, axs = plt.subplots(1, 3, figsize=(10, 5)) 126 | axs[0].set_title("100 photons") 127 | im = axs[0].imshow(y1[1, 0].reshape(64, 64), cmap="gray") 128 | add_colorbar(im, "bottom") 129 | 130 | axs[1].set_title("100 photons") 131 | im = axs[1].imshow(y2[1, 0].reshape(64, 64), cmap="gray") 132 | add_colorbar(im, "bottom") 133 | 134 | axs[2].set_title("1000 photons") 135 | im = axs[2].imshow( 136 | y3[ 137 | 1, 138 | 0, 139 | ].reshape(64, 64), 140 | cmap="gray", 141 | ) 142 | add_colorbar(im, "bottom") 143 | 144 | noaxis(axs) 145 | 146 | ############################################################################### 147 | # As expected the signal-to-noise ratio of the measurement vector is higher for 148 | # 1,000 photons than for 100 photons 149 | # 150 | # .. note:: 151 | # Not only the signal-to-noise, but also the scale of the measurements 152 | # depends on :math:`\alpha`, which motivates the introduction of the 153 | # preprocessing operator. 
154 | -------------------------------------------------------------------------------- /tutorial/tuto_03_pseudoinverse_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 03. Pseudoinverse solution from linear measurements 3 | =================================================== 4 | .. _tuto_pseudoinverse_linear: 5 | 6 | This tutorial shows how to simulate measurements and perform image reconstruction using the :class:`spyrit.core.inverse.PseudoInverse` class of the :mod:`spyrit.core.inverse` submodule. 7 | 8 | .. image:: ../fig/tuto3_pinv.png 9 | :width: 600 10 | :align: center 11 | :alt: Reconstruction architecture sketch 12 | 13 | | 14 | """ 15 | 16 | # %% 17 | # Loads images 18 | # ----------------------------------------------------------------------------- 19 | 20 | ############################################################################### 21 | # We load a batch of images from the :attr:`/images/` folder. Using the 22 | # :func:`spyrit.misc.statistics.transform_gray_norm` function with the :attr:`normalize=False` 23 | # argument returns images with values in (0,1). 
24 | import os 25 | import torchvision 26 | import torch.nn 27 | from spyrit.misc.statistics import transform_gray_norm 28 | 29 | spyritPath = os.getcwd() 30 | imgs_path = os.path.join(spyritPath, "images/") 31 | 32 | # Grayscale images of size 64 x 64, no normalization to keep values in [0, 1] 33 | transform = transform_gray_norm(img_size=64, normalize=False) 34 | 35 | # Create dataset and loader (expects class folder 'images/test/') 36 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 38 | 39 | x, _ = next(iter(dataloader)) 40 | print(f"Ground-truth images: {x.shape}") 41 | 42 | 43 | # %% 44 | # Linear measurements without noise 45 | # ----------------------------------------------------------------------------- 46 | 47 | ############################################################################### 48 | # We consider a Hadamard matrix in "2D". The matrix has a shape of (64*64, 64*64) and values in {-1, 1}. 49 | from spyrit.core.torch import walsh_matrix_2d 50 | 51 | H = walsh_matrix_2d(64) 52 | 53 | print(f"Acquisition matrix: {H.shape}", end=" ") 54 | print(rf"with values in {{{H.min()}, {H.max()}}}") 55 | 56 | ############################################################################### 57 | # We instantiate a :class:`spyrit.core.meas.Linear` operator. To indicate that the operator works in 2D, on images with shape (64, 64), we use the :attr:`meas_shape` argument. 58 | from spyrit.core.meas import Linear 59 | 60 | meas_op = Linear(H, (64, 64)) 61 | 62 | ############################################################################### 63 | # We simulate the measurement vectors, which have a shape of (7, 1, 4096). 64 | y = meas_op(x) 65 | 66 | print(f"Measurement vectors: {y.shape}") 67 | 68 | ############################################################################### 69 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64).
70 | from spyrit.core.inverse import PseudoInverse 71 | 72 | pinv = PseudoInverse(meas_op) 73 | x_rec = pinv(y) 74 | 75 | print(f"Reconstructed images: {x_rec.shape}") 76 | 77 | ############################################################################### 78 | # We plot the reconstruction of the second image in the batch 79 | from spyrit.misc.disp import imagesc, add_colorbar 80 | 81 | imagesc(x_rec[1, 0]) 82 | # sphinx_gallery_thumbnail_number = 1 83 | 84 | ############################################################################### 85 | # .. note:: 86 | # The measurement operator is chosen as a Hadamard matrix with values in {-1, 1}, but this matrix can be replaced by any other matrix. 87 | 88 | # %% 89 | # LinearSplit measurements with Gaussian noise 90 | # ----------------------------------------------------------------------------- 91 | 92 | ############################################################################### 93 | # We consider a linear operator where the positive and negative components are split, i.e., acquired separately. To do so, we instantiate a :class:`spyrit.core.meas.LinearSplit` operator. 94 | from spyrit.core.meas import LinearSplit 95 | 96 | meas_op = LinearSplit(H, (64, 64)) 97 | 98 | ############################################################################### 99 | # We consider additive Gaussian noise with standard deviation 2. 100 | from spyrit.core.noise import Gaussian 101 | 102 | meas_op.noise_model = Gaussian(2) 103 | 104 | ############################################################################### 105 | # We simulate the measurement vectors, which have shape (7, 1, 8192). 106 | y = meas_op(x) 107 | 108 | print(f"Measurement vectors: {y.shape}") 109 | 110 | ############################################################################### 111 | # We preprocess measurement vectors by computing the difference of the positive and negative components of the measurement vectors.
To do so, we use the :class:`spyrit.core.prep.Unsplit` class. The preprocessed measurements have a shape of (7, 1, 4096). 112 | 113 | from spyrit.core.prep import Unsplit 114 | 115 | prep = Unsplit() 116 | m = prep(y) 117 | 118 | print(f"Preprocessed measurement vectors: {m.shape}") 119 | 120 | ############################################################################### 121 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64). 122 | from spyrit.core.inverse import PseudoInverse 123 | 124 | pinv = PseudoInverse(meas_op) 125 | x_rec = pinv(m) 126 | 127 | print(f"Reconstructed images: {x_rec.shape}") 128 | 129 | ############################################################################### 130 | # We plot the reconstruction of the second image in the batch 131 | from spyrit.misc.disp import imagesc, add_colorbar 132 | 133 | imagesc(x_rec[1, 0]) 134 | 135 | # %% 136 | # HadamSplit2d with x4 subsampling and Poisson noise 137 | # ----------------------------------------------------------------------------- 138 | 139 | ############################################################################### 140 | # We consider the acquisition of the 2D Hadamard transform of an image, where the positive and negative components of the acquisition matrix are acquired separately. To do so, we use the dedicated :class:`spyrit.core.meas.HadamSplit2d` operator. It also allows for subsampling the rows of the Hadamard matrix, using a sampling map.
141 | 142 | from spyrit.core.meas import HadamSplit2d 143 | 144 | # Sampling map with ones in the top-left 32 x 32 corner and zeros elsewhere (low-frequency subsampling) 145 | sampling_map = torch.ones((64, 64)) 146 | sampling_map[:, 64 // 2 :] = 0 147 | sampling_map[64 // 2 :, :] = 0 148 | 149 | # Linear operator with HadamSplit2d, keeping 64**2 // 4 = 1024 of the 4096 Hadamard coefficients 150 | meas_op = HadamSplit2d(64, 64**2 // 4, order=sampling_map, reshape_output=True) 151 | 152 | ############################################################################### 153 | # We consider Poisson noise with an intensity of 100 photons. 154 | from spyrit.core.noise import Poisson 155 | 156 | meas_op.noise_model = Poisson(100) 157 | 158 | 159 | ############################################################################### 160 | # We simulate the measurement vectors, which have a shape of (7, 1, 2048). 161 | 162 | ############################################################################### 163 | # .. note:: 164 | # The :class:`spyrit.core.noise.Poisson` noise model assumes that the images are in the range [0, 1]. 165 | y = meas_op(x) 166 | 167 | print(rf"Reference images with values in {{{x.min()}, {x.max()}}}") 168 | print(f"Measurement vectors: {y.shape}") 169 | 170 | ############################################################################### 171 | # We preprocess measurement vectors by i) computing the difference of the positive and negative components, and ii) normalizing the intensity. To do so, we use the :class:`spyrit.core.prep.UnsplitRescale` class. The preprocessed measurements have a shape of (7, 1, 1024). 172 | 173 | from spyrit.core.prep import UnsplitRescale 174 | 175 | prep = UnsplitRescale(100) # same intensity (100 photons) as the Poisson noise model above 176 | 177 | m = prep(y) # (y+ - y-)/alpha 178 | print(f"Preprocessed measurement vectors: {m.shape}") 179 | 180 | ############################################################################### 181 | # We compute the pseudo inverse solution, which has a shape of (7, 1, 64, 64).
182 | 183 | x_rec = meas_op.fast_pinv(m) 184 | 185 | print(f"Reconstructed images: {x_rec.shape}") 186 | 187 | ############################################################################### 188 | # .. note:: 189 | # There is no need to use the :class:`spyrit.core.inverse.PseudoInverse` class here, as the :class:`spyrit.core.meas.HadamSplit2d` class includes a method (:meth:`fast_pinv`) that returns the pseudo inverse solution. 190 | 191 | ############################################################################### 192 | # We plot the reconstruction of the second image in the batch 193 | from spyrit.misc.disp import imagesc, add_colorbar 194 | 195 | imagesc(x_rec[1, 0]) 196 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_03_pseudoinverse_cnn_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 03. Pseudoinverse solution + CNN denoising 3 | ========================================== 4 | .. _tuto_pseudoinverse_cnn_linear: 5 | 6 | This tutorial shows how to simulate measurements and perform image reconstruction 7 | using PinvNet (pseudoinverse linear network) with CNN denoising as a last layer. 8 | This tutorial is a continuation of the :ref:`Pseudoinverse solution tutorial ` 9 | but uses a CNN denoiser instead of the identity operator in order to remove artefacts. 10 | 11 | The measurement operator is chosen as a Hadamard matrix with positive coefficients, 12 | which can be replaced by any matrix. 13 | 14 | .. image:: ../fig/tuto3.png 15 | :width: 600 16 | :align: center 17 | :alt: Reconstruction and neural network denoising architecture sketch 18 | 19 | These tutorials load image samples from `/images/`. 20 | """ 21 | 22 | # %% 23 | # Load a batch of images 24 | # ----------------------------------------------------------------------------- 25 | 26 | ############################################################################### 27 | # Images :math:`x` used for training are expected to have values in [-1,1].
The images are normalized 28 | # using the :func:`transform_gray_norm` function. 29 | if False: # WIP tutorial: the body below is disabled 30 | 31 | import os 32 | 33 | import torch 34 | import torchvision 35 | import matplotlib.pyplot as plt 36 | 37 | import spyrit.core.torch as spytorch 38 | from spyrit.misc.disp import imagesc 39 | from spyrit.misc.statistics import transform_gray_norm 40 | 41 | # sphinx_gallery_thumbnail_path = 'fig/tuto3.png' 42 | 43 | h = 64 # image size hxh 44 | i = 1 # Image index (modify to change the image) 45 | spyritPath = os.getcwd() 46 | imgs_path = os.path.join(spyritPath, "images/") 47 | 48 | # Create a transform for natural images to normalized grayscale image tensors 49 | transform = transform_gray_norm(img_size=h) 50 | 51 | # Create dataset and loader (expects class folder 'images/test/') 52 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 53 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 54 | 55 | x, _ = next(iter(dataloader)) 56 | print(f"Shape of input images: {x.shape}") 57 | 58 | # Select image i, keeping the batch dimension 59 | x = x[i : i + 1, :, :, :] 60 | x = x.detach().clone() 61 | print(f"Shape of selected image: {x.shape}") 62 | b, c, h, w = x.shape 63 | 64 | # plot 65 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 66 | 67 | # %% 68 | # Define a measurement operator 69 | # ----------------------------------------------------------------------------- 70 | 71 | ############################################################################### 72 | # We consider the case where the measurement matrix is the positive 73 | # component of a Hadamard matrix and the sampling operator preserves only 74 | # the first :attr:`M` low-frequency coefficients 75 | # (see :ref:`Positive Hadamard matrix ` for full explanation).
76 | 77 | import math 78 | 79 | F = spytorch.walsh_matrix_2d(h) 80 | F = torch.max(F, torch.zeros_like(F)) # keep the positive part of the Hadamard matrix 81 | 82 | und = 4 # undersampling factor 83 | M = h**2 // und # number of measurements (undersampling factor = 4) 84 | 85 | Sampling_map = torch.zeros(h, h) 86 | M_xy = math.ceil(M**0.5) 87 | Sampling_map[:M_xy, :M_xy] = 1 88 | 89 | imagesc(Sampling_map, "low-frequency sampling map") 90 | 91 | ############################################################################### 92 | # After permutation of the full Hadamard matrix, we keep only its first 93 | # :attr:`M` rows. 94 | 95 | F = spytorch.sort_by_significance(F, Sampling_map, "rows", False) 96 | H = F[:M, :] 97 | 98 | print(f"Shape of the measurement matrix: {H.shape}") 99 | 100 | ############################################################################### 101 | # Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator. 102 | 103 | from spyrit.core.meas import Linear 104 | 105 | meas_op = Linear(H, pinv=True) 106 | 107 | # %% 108 | # Noiseless case 109 | # ----------------------------------------------------------------------------- 110 | 111 | ############################################################################### 112 | # In the noiseless case, we consider the :class:`spyrit.core.noise.NoNoise` noise 113 | # operator. 114 | 115 | from spyrit.core.noise import NoNoise 116 | 117 | N0 = 1.0 # Noise level (noiseless) 118 | noise = NoNoise(meas_op) 119 | 120 | # Simulate measurements 121 | y = noise(x) 122 | print(f"Shape of raw measurements: {y.shape}") 123 | 124 | ############################################################################### 125 | # We now compute and plot the preprocessed measurements corresponding to an 126 | # image in [-1,1]. 127 | 128 | from spyrit.core.prep import DirectPoisson 129 | 130 | prep = DirectPoisson(N0, meas_op) # "Undo" the NoNoise operator 131 | 132 | m = prep(y) 133 | print(f"Shape of the preprocessed measurements: {m.shape}") 134 | 135 |
############################################################################### 136 | # To display the subsampled measurement vector as an image in the transformed 137 | # domain, we use the :func:`spyrit.core.torch.meas2img` function. 138 | 139 | # plot 140 | m_plot = spytorch.meas2img(m, Sampling_map) 141 | print(f"Shape of the preprocessed measurement image: {m_plot.shape}") 142 | 143 | imagesc(m_plot[0, 0, :, :], "Preprocessed measurements (no noise)") 144 | 145 | # %% 146 | # PinvNet Network 147 | # ----------------------------------------------------------------------------- 148 | 149 | ############################################################################### 150 | # We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an 151 | # image by computing the pseudoinverse solution, which is fed to a neural 152 | # network denoiser. To compute the pseudoinverse solution only, the denoiser 153 | # can be set to the identity operator. 154 | 155 | from spyrit.core.recon import PinvNet 156 | 157 | pinv_net = PinvNet(noise, prep, denoi=torch.nn.Identity()) 158 | 159 | ############################################################################### 160 | # or equivalently, omitting the :attr:`denoi` argument 161 | pinv_net = PinvNet(noise, prep) 162 | 163 | ############################################################################### 164 | # Then, we reconstruct the image from the measurement vector :attr:`y` using the 165 | # :func:`~spyrit.core.recon.PinvNet.reconstruct` method. 166 | 167 | x_rec = pinv_net.reconstruct(y) 168 | 169 | # %% 170 | # Removing artefacts with a CNN 171 | # ----------------------------------------------------------------------------- 172 | 173 | ############################################################################### 174 | # Artefacts can be removed by selecting a neural network denoiser 175 | # (last layer of PinvNet).
We select a simple CNN using the 176 | # :class:`spyrit.core.nnet.ConvNet` class, but this can be replaced by any 177 | # neural network (e.g., UNet from :class:`spyrit.core.nnet.Unet`). 178 | 179 | ############################################################################### 180 | # .. image:: ../fig/pinvnet_cnn.png 181 | # :width: 400 182 | # :align: center 183 | # :alt: Sketch of the PinvNet with CNN architecture 184 | 185 | from spyrit.core.nnet import ConvNet 186 | from spyrit.core.train import load_net 187 | 188 | # Define PinvNet with ConvNet denoising layer 189 | denoi = ConvNet() 190 | pinv_net_cnn = PinvNet(noise, prep, denoi) 191 | 192 | # Send to GPU if available 193 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 194 | print("Using device:", device) 195 | pinv_net_cnn = pinv_net_cnn.to(device) 196 | 197 | ############################################################################### 198 | # As an example, we use a simple ConvNet that has been pretrained on the STL-10 dataset. 199 | # We download the pretrained weights and load them into the network.
200 | 201 | from spyrit.misc.load_data import download_girder 202 | 203 | # Load pretrained model 204 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 205 | dataID = "67221889f03a54733161e963" # unique ID of the file 206 | local_folder = "./model/" 207 | data_name = "tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth" 208 | # download the model and save it in the local folder 209 | model_cnn_path = download_girder(url, dataID, local_folder, data_name) 210 | 211 | # Load model weights 212 | load_net(model_cnn_path, pinv_net_cnn, device, False) 213 | 214 | ############################################################################### 215 | # We now reconstruct the image using PinvNet with pretrained CNN denoising 216 | # and plot results side by side with the PinvNet without denoising 217 | 218 | from spyrit.misc.disp import add_colorbar, noaxis 219 | 220 | with torch.no_grad(): 221 | x_rec_cnn = pinv_net_cnn.reconstruct(y.to(device)) 222 | x_rec_cnn = pinv_net_cnn(x.to(device)) 223 | 224 | # plot 225 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) 226 | 227 | im1 = ax1.imshow(x[0, 0, :, :], cmap="gray") 228 | ax1.set_title("Ground-truth image", fontsize=20) 229 | noaxis(ax1) 230 | add_colorbar(im1, "bottom", size="20%") 231 | 232 | im2 = ax2.imshow(x_rec[0, 0, :, :], cmap="gray") 233 | ax2.set_title("Pinv reconstruction", fontsize=20) 234 | noaxis(ax2) 235 | add_colorbar(im2, "bottom", size="20%") 236 | 237 | im3 = ax3.imshow(x_rec_cnn.cpu()[0, 0, :, :], cmap="gray") 238 | ax3.set_title(f"Pinv + CNN (trained 30 epochs", fontsize=20) 239 | noaxis(ax3) 240 | add_colorbar(im3, "bottom", size="20%") 241 | 242 | plt.show() 243 | 244 | ############################################################################### 245 | # We show the best result again (tutorial thumbnail purpose) 246 | 247 | # Plot 248 | imagesc( 249 | x_rec_cnn.cpu()[0, 0, :, :], f"Pinv + CNN (trained 30 epochs", 
title_fontsize=20 250 | ) 251 | 252 | ############################################################################### 253 | # In the next tutorial, we will show how to train PinvNet with a CNN denoiser. 254 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_04_train_pseudoinverse_cnn_linear.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 04. Train pseudoinverse solution + CNN denoising 3 | ================================================ 4 | .. _tuto_train_pseudoinverse_cnn_linear: 5 | 6 | This tutorial shows how to train PinvNet with a CNN denoiser for 7 | reconstruction of linear measurements (results shown in the 8 | :ref:`previous tutorial `). 9 | As an example, we use a small CNN, which can be replaced by any other network, 10 | for example Unet. Training is performed on the STL-10 dataset. 11 | 12 | You can use TensorBoard for PyTorch for experiment tracking and 13 | for visualizing the training process: losses, network weights, 14 | and intermediate results (reconstructed images at different epochs). 15 | 16 | The linear measurement operator is chosen as the positive part of a Hadamard matrix, 17 | but this matrix can be replaced by any desired matrix. 18 | 19 | These tutorials load image samples from `/images/`. 20 | """ 21 | 22 | # %% 23 | # Load a batch of images 24 | # ----------------------------------------------------------------------------- 25 | 26 | ############################################################################### 27 | # First, we load an image :math:`x` and normalize it to [-1,1], as in previous examples.
28 | if False: 29 | 30 | import os 31 | 32 | import torch 33 | import torchvision 34 | import matplotlib.pyplot as plt 35 | 36 | import spyrit.core.torch as spytorch 37 | from spyrit.misc.disp import imagesc 38 | from spyrit.misc.statistics import transform_gray_norm 39 | 40 | h = 64 # image size hxh 41 | i = 1 # Image index (modify to change the image) 42 | spyritPath = os.getcwd() 43 | imgs_path = os.path.join(spyritPath, "images/") 44 | 45 | # Create a transform for natural images to normalized grayscale image tensors 46 | transform = transform_gray_norm(img_size=h) 47 | 48 | # Create dataset and loader (expects class folder 'images/test/') 49 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 50 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 51 | 52 | x, _ = next(iter(dataloader)) 53 | print(f"Shape of input images: {x.shape}") 54 | 55 | # Select image i, keeping the batch dimension 56 | x = x[i : i + 1, :, :, :] 57 | x = x.detach().clone() 58 | print(f"Shape of selected image: {x.shape}") 59 | b, c, h, w = x.shape 60 | 61 | # plot 62 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 63 | 64 | # %% 65 | # Define a dataloader 66 | # ----------------------------------------------------------------------------- 67 | # We define a dataloader for the STL-10 dataset using :func:`spyrit.misc.statistics.data_loaders_stl10`. 68 | # This will download the dataset to the provided path if it is not already downloaded. 69 | # It is based on the PyTorch dataset :class:`torchvision.datasets.STL10` and 70 | # :class:`torch.utils.data.DataLoader`, which creates a generator that iterates 71 | # through the dataset, returning a batch of images and labels at each iteration. 72 | # 73 | # Set :attr:`mode_run` to True in the script below to download the dataset and train the network; 74 | # otherwise, pretrained weights and results will be downloaded for display.
75 | 76 | from spyrit.misc.statistics import data_loaders_stl10 77 | from pathlib import Path 78 | 79 | # Parameters 80 | h = 64 # image size hxh 81 | data_root = Path("./data") # path to data folder (where the dataset is stored) 82 | batch_size = 512 83 | 84 | # Dataloader for STL-10 dataset 85 | mode_run = False # set to True to download STL-10 and train; `dataloaders` is only defined in that case 86 | if mode_run: 87 | dataloaders = data_loaders_stl10( 88 | data_root, 89 | img_size=h, 90 | batch_size=batch_size, 91 | seed=7, 92 | shuffle=True, 93 | download=True, 94 | ) 95 | 96 | # %% 97 | # Define a measurement operator 98 | # ----------------------------------------------------------------------------- 99 | 100 | ############################################################################### 101 | # We consider the case where the measurement matrix is the positive 102 | # component of a Hadamard matrix, which is often used in single-pixel imaging 103 | # (see :ref:`Hadamard matrix `). 104 | # Then, we simulate an accelerated acquisition by keeping only the first 105 | # :attr:`M` low-frequency coefficients (see :ref:`low frequency sampling `). 106 | 107 | import math 108 | 109 | und = 4 # undersampling factor 110 | M = h**2 // und # number of measurements (undersampling factor = 4) 111 | 112 | F = spytorch.walsh_matrix_2d(h) 113 | F = torch.max(F, torch.zeros_like(F)) # keep the positive part of the Hadamard matrix 114 | 115 | Sampling_map = torch.zeros(h, h) 116 | M_xy = math.ceil(M**0.5) 117 | Sampling_map[:M_xy, :M_xy] = 1 118 | 119 | # imagesc(Sampling_map, 'low-frequency sampling map') 120 | 121 | F = spytorch.sort_by_significance(F, Sampling_map, "rows", False) 122 | H = F[:M, :] 123 | 124 | print(f"Shape of the measurement matrix: {H.shape}") 125 | 126 | ############################################################################### 127 | # Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator, 128 | # a :class:`spyrit.core.noise.NoNoise` noise operator for the noiseless case, 129 | # and a measurement preprocessing operator :class:`spyrit.core.prep.DirectPoisson`.
130 | 131 | from spyrit.core.meas import Linear 132 | from spyrit.core.noise import NoNoise 133 | from spyrit.core.prep import DirectPoisson 134 | 135 | meas_op = Linear(H, pinv=True) 136 | noise = NoNoise(meas_op) 137 | N0 = 1.0 # Mean maximum total number of photons 138 | prep = DirectPoisson(N0, meas_op) # "Undo" the NoNoise operator 139 | 140 | # %% 141 | # PinvNet Network 142 | # ----------------------------------------------------------------------------- 143 | 144 | ############################################################################### 145 | # We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an 146 | # image by computing the pseudoinverse solution and applies a nonlinear 147 | # network denoiser. First, we must define the denoiser. As an example, 148 | # we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class. 149 | # Then, we define the PinvNet network by passing the noise and preprocessing operators 150 | # and the denoiser. 151 | 152 | from spyrit.core.nnet import ConvNet 153 | from spyrit.core.recon import PinvNet 154 | 155 | denoiser = ConvNet() 156 | model = PinvNet(noise, prep, denoi=denoiser) 157 | 158 | # Send to GPU if available 159 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 160 | 161 | # Use multiple GPUs if available 162 | if torch.cuda.device_count() > 1: 163 | print("Let's use", torch.cuda.device_count(), "GPUs!") 164 | model = nn.DataParallel(model) 165 | 166 | print("Using device:", device) 167 | model = model.to(device) 168 | 169 | ############################################################################### 170 | # .. note:: 171 | # 172 | # In the example provided, we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class. 173 | # This can be replaced by any denoiser, for example the :class:`spyrit.core.nnet.Unet` class 174 | # or a custom denoiser. 
175 | 176 | # %% 177 | # Define a loss function, an optimizer and a scheduler 178 | # ----------------------------------------------------------------------------- 179 | 180 | ############################################################################### 181 | # In order to train the network, we need to define a loss function, an optimizer 182 | # and a scheduler. We use the Mean Square Error (MSE) loss function, weight decay 183 | # loss and the Adam optimizer. The scheduler decreases the learning rate 184 | # by a factor of :attr:`gamma` every :attr:`step_size` epochs. 185 | 186 | import torch.nn as nn 187 | import torch.optim as optim 188 | from torch.optim import lr_scheduler 189 | from spyrit.core.train import save_net, Weight_Decay_Loss 190 | 191 | # Parameters 192 | lr = 1e-3 193 | step_size = 10 194 | gamma = 0.5 195 | 196 | loss = nn.MSELoss() 197 | criterion = Weight_Decay_Loss(loss) 198 | optimizer = optim.Adam(model.parameters(), lr=lr) 199 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 200 | 201 | # %% 202 | # Train the network 203 | # ----------------------------------------------------------------------------- 204 | 205 | ############################################################################### 206 | # To train the network, we use the :func:`~spyrit.core.train.train_model` function, 207 | # which handles the training process. It iterates through the dataloader, feeds the inputs to the 208 | # network and optimizes the solution (by computing the loss and its gradients and 209 | # updating the network weights at each iteration). In addition, it computes 210 | # the loss and desired metrics on the training and validation sets at each iteration. 211 | # The training process can be monitored using TensorBoard. 212 | 213 | ############################################################################### 214 | # ..
note:: 215 | # 216 | # To launch TensorBoard, type in a new console: 217 | # 218 | # tensorboard --logdir runs 219 | # 220 | # and open the provided link in a browser. The training process can be monitored 221 | # in real time in the "Scalars" tab. The "Images" tab allows you to visualize the 222 | # reconstructed images every :attr:`tb_freq` iterations. 223 | 224 | ############################################################################### 225 | # In order to train, you must set :attr:`mode_run` to True. It is set to False 226 | # by default to download the pretrained weights and results for display, 227 | # as training takes around 40 min for 30 epochs. 228 | 229 | # We train for 5 epochs only (instead of the 30 used for the pretrained weights) to check that everything works. 230 | 231 | from spyrit.core.train import train_model 232 | from datetime import datetime 233 | 234 | # Parameters 235 | model_root = Path("./model") # path to model saving files 236 | num_epochs = 5 # number of training epochs (the pretrained weights used num_epochs = 30) 237 | checkpoint_interval = 2 # interval between saving model checkpoints 238 | tb_freq = 50 # interval between logging to Tensorboard (iterations through the dataloader) 239 | 240 | # Path for Tensorboard experiment tracking logs 241 | name_run = "stdl10_hadampos" # NOTE(review): "stdl10" is possibly a typo for "stl10" 242 | now = datetime.now().strftime("%Y-%m-%d_%H-%M") 243 | tb_path = f"runs/runs_{name_run}_n{int(N0)}_m{M}/{now}" 244 | 245 | # Train the network 246 | if mode_run: 247 | model, train_info = train_model( 248 | model, 249 | criterion, 250 | optimizer, 251 | scheduler, 252 | dataloaders, 253 | device, 254 | model_root, 255 | num_epochs=num_epochs, 256 | disp=True, 257 | do_checkpoint=checkpoint_interval, 258 | tb_path=tb_path, 259 | tb_freq=tb_freq, 260 | ) 261 | else: 262 | train_info = {} 263 | 264 | # %% 265 | # Save the network and training history 266 | # ----------------------------------------------------------------------------- 267 | 268 |
############################################################################### 269 | # We save the model so that it can be reused later. We save the network's 270 | # architecture, the training parameters and the training history. (NOTE(review): save_net below appears to run even when mode_run is False, in which case the saved weights are untrained.) 271 | 272 | from spyrit.core.train import save_net 273 | 274 | # Training parameters 275 | train_type = "N0_{:g}".format(N0) 276 | arch = "pinv-net" 277 | denoi = "cnn" 278 | data = "stl10" 279 | reg = 1e-7 # Default value 280 | suffix = "N_{}_M_{}_epo_{}_lr_{}_sss_{}_sdr_{}_bs_{}".format( 281 | h, M, num_epochs, lr, step_size, gamma, batch_size 282 | ) 283 | title = model_root / f"{arch}_{denoi}_{data}_{train_type}_{suffix}" 284 | print(title) 285 | 286 | Path(model_root).mkdir(parents=True, exist_ok=True) 287 | 288 | if checkpoint_interval: 289 | Path(title).mkdir(parents=True, exist_ok=True) 290 | 291 | save_net(str(title) + ".pth", model) 292 | 293 | # Save training history 294 | import pickle 295 | 296 | if mode_run: 297 | from spyrit.core.train import Train_par 298 | 299 | params = Train_par(batch_size, lr, h, reg=reg) 300 | params.set_loss(train_info) 301 | 302 | train_path = ( 303 | model_root / f"TRAIN_{arch}_{denoi}_{data}_{train_type}_{suffix}.pkl" 304 | ) 305 | 306 | with open(train_path, "wb") as param_file: 307 | pickle.dump(params, param_file) 308 | torch.cuda.empty_cache() 309 | 310 | else: 311 | from spyrit.misc.load_data import download_girder 312 | 313 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 314 | dataID = "667ebfe4baa5a90007058964" # unique ID of the file 315 | data_name = "tuto4_TRAIN_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pkl" 316 | train_path = os.path.join(model_root, data_name) 317 | # download the pretrained training history (pickle file) from Girder 318 | download_girder(url, dataID, model_root, data_name) 319 | 320 | with open(train_path, "rb") as param_file: # only unpickle files from a trusted source 321 | params = pickle.load(param_file) 322 | train_info["train"] = params.train_loss 323 | train_info["val"] = params.val_loss
324 | 325 | ############################################################################### 326 | # We plot the training and validation losses. 327 | 328 | # Plot 329 | # sphinx_gallery_thumbnail_number = 2 330 | 331 | fig = plt.figure() 332 | plt.plot(train_info["train"], label="train") 333 | plt.plot(train_info["val"], label="val") 334 | plt.xlabel("Epochs", fontsize=20) 335 | plt.ylabel("Loss", fontsize=20) 336 | plt.legend(fontsize=20) 337 | plt.show() 338 | 339 | ############################################################################### 340 | # .. note:: 341 | # 342 | # See the Google Colab notebook `spyrit-examples/tutorial/tuto_train_lin_meas_colab.ipynb `_ 343 | # for training a reconstruction network on GPU. It shows how to train 344 | # using different architectures, denoisers and other hyperparameters of the 345 | # :func:`~spyrit.core.train.train_model` function. 346 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_05_recon_hadamSplit.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | if False: 4 | 5 | # %% 6 | # Split measurement operator and no noise 7 | # ----------------------------------------------------------------------------- 8 | # .. _split_measurements: 9 | 10 | ############################################################################### 11 | # .. math:: 12 | # y = P\tilde{x}= \begin{bmatrix} H_{+} \\ H_{-} \end{bmatrix} \tilde{x}. 13 | 14 | ############################################################################### 15 | # The Hadamard split measurement operator is defined in the :class:`spyrit.core.meas.HadamSplit` class. 16 | # It computes linear measurements from incoming images, where :math:`P` is a 17 | # linear operator (matrix) with positive entries and :math:`\tilde{x}` is an image.
18 | # The class relies on a matrix :math:`H` with 19 | # shape :math:`(M,N)` where :math:`N` represents the number of pixels in the 20 | # image and :math:`M \le N` the number of measurements. The matrix :math:`P` 21 | # is obtained by splitting the matrix :math:`H` as :math:`H = H_{+}-H_{-}` where 22 | # :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. 23 | 24 | # %% 25 | # Measurement and noise operators 26 | # ----------------------------------------------------------------------------- 27 | 28 | ############################################################################### 29 | # We compute the measurement and noise operators and then 30 | # simulate the measurement vector :math:`y`. 31 | 32 | ############################################################################### 33 | # We consider Poisson noise, i.e., a noisy measurement vector given by 34 | # 35 | # .. math:: 36 | # y \sim \mathcal{P}(\alpha P \tilde{x}), 37 | # 38 | # where :math:`\alpha` is a scalar value that represents the maximum image intensity 39 | # (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio. 40 | 41 | ############################################################################### 42 | # We use the :class:`spyrit.core.noise.Poisson` class, set :math:`\alpha` 43 | # to 100 photons, and simulate a noisy measurement vector for the two sampling 44 | # strategies. Subsampling is handled internally by the :class:`~spyrit.core.meas.HadamSplit` class. 
45 | 46 | from spyrit.core.noise import Poisson 47 | from spyrit.core.meas import HadamSplit 48 | 49 | alpha = 100.0 # number of photons 50 | 51 | # "Naive subsampling" 52 | # Measurement and noise operators 53 | meas_nai_op = HadamSplit(M, h, Ord_naive) 54 | noise_nai_op = Poisson(meas_nai_op, alpha) 55 | 56 | # Noisy measurements 57 | y_nai = noise_nai_op(x) # a noisy measurement vector 58 | 59 | # "Variance subsampling" 60 | meas_var_op = HadamSplit(M, h, Ord_variance) 61 | noise_var_op = Poisson(meas_var_op, alpha) 62 | y_var = noise_var_op(x) # a noisy measurement vector 63 | 64 | print(f"Shape of image: {x.shape}") 65 | print(f"Shape of simulated measurements y: {y_var.shape}") 66 | 67 | # %% 68 | # Preprocessing operator for split measurements 69 | # ----------------------------------------------------------------------------- 70 | 71 | ############################################################################### 72 | # We compute the preprocessing operators for the two sampling strategies considered above, 73 | # using the :mod:`spyrit.core.prep` module. As previously introduced, 74 | # a preprocessing operator is applied to the noisy measurements in order to 75 | # compensate for the scaling factors that appear in the measurement or noise operators: 76 | # 77 | # .. math:: 78 | # m = \texttt{Prep}(y), 79 | 80 | ############################################################################### 81 | # We consider the :class:`spyrit.core.prep.SplitPoisson` class that intends 82 | # to "undo" the :class:`spyrit.core.noise.Poisson` class, for split measurements, by compensating for 83 | # 84 | # * the scaling that appears when computing Poisson-corrupted measurements 85 | # 86 | # * the affine transformation to get images in [0,1] from images in [-1,1] 87 | # 88 | # For this, it computes 89 | # 90 | # .. math:: 91 | # m = \frac{2(y_+-y_-)}{\alpha} - H\mathbb{1}, 92 | # 93 | # where :math:`y_+` and :math:`y_-` are the measurements acquired with :math:`H_+` and :math:`H_-`, respectively (without noise, :math:`y_+=\alpha H_+\tilde{x}` and :math:`y_-=\alpha H_-\tilde{x}`).
94 | # This is handled internally by the :class:`spyrit.core.prep.SplitPoisson` class. 95 | 96 | ############################################################################### 97 | # We compute the preprocessing operator and the measurements vectors for 98 | # the two sampling strategies. 99 | 100 | from spyrit.core.prep import SplitPoisson 101 | 102 | # "Naive subsampling" 103 | # 104 | # Preprocessing operator 105 | prep_nai_op = SplitPoisson(alpha, meas_nai_op) 106 | 107 | # Preprocessed measurements 108 | m_nai = prep_nai_op(y_nai) 109 | 110 | # "Variance subsampling" 111 | prep_var_op = SplitPoisson(alpha, meas_var_op) 112 | m_var = prep_var_op(y_var) 113 | 114 | # %% 115 | # Noiseless measurements 116 | # ----------------------------------------------------------------------------- 117 | 118 | ############################################################################### 119 | # We consider now noiseless measurements for the "naive subsampling" strategy. 120 | # We compute the required operators and the noiseless measurement vector. 121 | # For this we use the :class:`spyrit.core.noise.NoNoise` class, which normalizes 122 | # the input image to get an image in [0,1], as explained in 123 | # :ref:`acquisition operators tutorial `. 124 | # For the preprocessing operator, we assign the number of photons equal to one. 
125 | 126 | from spyrit.core.noise import NoNoise 127 | 128 | nonoise_nai_op = NoNoise(meas_nai_op) 129 | y_nai_nonoise = nonoise_nai_op(x) # a noisy measurement vector 130 | 131 | prep_nonoise_op = SplitPoisson(1.0, meas_nai_op) 132 | m_nai_nonoise = prep_nonoise_op(y_nai_nonoise) 133 | 134 | ############################################################################### 135 | # We can now plot the three measurement vectors 136 | 137 | # Plot the three measurement vectors 138 | m_plot = meas2img(m_nai_nonoise, Ord_naive) 139 | m_plot2 = meas2img(m_nai, Ord_naive) 140 | m_plot3 = spytorch.meas2img(m_var, Ord_variance) 141 | 142 | m_plot_max = m_plot[0, 0, :, :].max() 143 | m_plot_min = m_plot[0, 0, :, :].min() 144 | 145 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7)) 146 | im1 = ax1.imshow(m_plot[0, 0, :, :], cmap="gray") 147 | ax1.set_title("Noiseless measurements $m$ \n 'Naive' subsampling", fontsize=20) 148 | noaxis(ax1) 149 | add_colorbar(im1, "bottom", size="20%") 150 | 151 | im2 = ax2.imshow(m_plot2[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max) 152 | ax2.set_title("Measurements $m$ \n 'Naive' subsampling", fontsize=20) 153 | noaxis(ax2) 154 | add_colorbar(im2, "bottom", size="20%") 155 | 156 | im3 = ax3.imshow(m_plot3[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max) 157 | ax3.set_title("Measurements $m$ \n 'Variance' subsampling", fontsize=20) 158 | noaxis(ax3) 159 | add_colorbar(im3, "bottom", size="20%") 160 | 161 | plt.show() 162 | 163 | # %% 164 | # PinvNet network 165 | # ----------------------------------------------------------------------------- 166 | 167 | ############################################################################### 168 | # We use the :class:`spyrit.core.recon.PinvNet` class where 169 | # the pseudo inverse reconstruction is performed by a neural network 170 | 171 | from spyrit.core.recon import PinvNet 172 | 173 | pinvnet_nai_nonoise = PinvNet(nonoise_nai_op, prep_nonoise_op) 174 | pinvnet_nai = 
PinvNet(noise_nai_op, prep_nai_op) 175 | pinvnet_var = PinvNet(noise_var_op, prep_var_op) 176 | 177 | # Reconstruction 178 | z_nai_nonoise = pinvnet_nai_nonoise.reconstruct(y_nai_nonoise) 179 | z_nai = pinvnet_nai.reconstruct(y_nai) 180 | z_var = pinvnet_var.reconstruct(y_var) 181 | 182 | ############################################################################### 183 | # We can now plot the three reconstructed images 184 | from spyrit.misc.disp import add_colorbar, noaxis 185 | 186 | # Plot 187 | f, axs = plt.subplots(2, 2, figsize=(10, 10)) 188 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray") 189 | axs[0, 0].set_title("Ground-truth image") 190 | noaxis(axs[0, 0]) 191 | add_colorbar(im1, "bottom") 192 | 193 | im2 = axs[0, 1].imshow(z_nai_nonoise[0, 0, :, :], cmap="gray") 194 | axs[0, 1].set_title("Reconstruction noiseless") 195 | noaxis(axs[0, 1]) 196 | add_colorbar(im2, "bottom") 197 | 198 | im3 = axs[1, 0].imshow(z_nai[0, 0, :, :], cmap="gray") 199 | axs[1, 0].set_title("Reconstruction \n 'Naive' subsampling") 200 | noaxis(axs[1, 0]) 201 | add_colorbar(im3, "bottom") 202 | 203 | im4 = axs[1, 1].imshow(z_var[0, 0, :, :], cmap="gray") 204 | axs[1, 1].set_title("Reconstruction \n 'Variance' subsampling") 205 | noaxis(axs[1, 1]) 206 | add_colorbar(im4, "bottom") 207 | 208 | plt.show() 209 | 210 | ############################################################################### 211 | # .. note:: 212 | # 213 | # Note that reconstructed images are pixelized when using the "naive subsampling", 214 | # while they are smoother and more similar to the ground-truth image when using the 215 | # "variance subsampling". 216 | # 217 | # Another way to further improve results is to include a nonlinear post-processing step, 218 | # which we will consider in a future tutorial. 
219 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_06_dcnet_split_measurements.py: -------------------------------------------------------------------------------- 1 | r""" 2 | ========================================= 3 | 06. Denoised Completion Network (DCNet) 4 | ========================================= 5 | .. _tuto_dcnet_split_measurements: 6 | This tutorial shows how to perform image reconstruction using the denoised 7 | completion network (DCNet) with a trainable image denoiser. In the next 8 | tutorial, we will plug a denoiser into a DCNet, which requires no training. 9 | 10 | .. figure:: ../fig/tuto3.png 11 | :width: 600 12 | :align: center 13 | :alt: Reconstruction and neural network denoising architecture sketch using split measurements 14 | """ 15 | 16 | ###################################################################### 17 | # .. note:: 18 | # As in the previous tutorials, we consider a split Hadamard operator and 19 | # measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `). 20 | 21 | # %% 22 | # Load a batch of images 23 | # ========================================= 24 | 25 | ###################################################################### 26 | # Update search path 27 | 28 | # sphinx_gallery_thumbnail_path = 'fig/tuto6.png' 29 | if False: 30 | 31 | import os 32 | 33 | import torch 34 | import torchvision 35 | import matplotlib.pyplot as plt 36 | 37 | import spyrit.core.torch as spytorch 38 | from spyrit.misc.disp import imagesc 39 | from spyrit.misc.statistics import transform_gray_norm 40 | 41 | spyritPath = os.getcwd() 42 | imgs_path = os.path.join(spyritPath, "images/") 43 | 44 | ###################################################################### 45 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. 
46 | 47 | h = 64 # image is resized to h x h 48 | transform = transform_gray_norm(img_size=h) 49 | 50 | ###################################################################### 51 | # Create a data loader from some dataset (images must be in the folder `images/test/`) 52 | 53 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 54 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 55 | 56 | x, _ = next(iter(dataloader)) 57 | print(f"Shape of input images: {x.shape}") 58 | 59 | ###################################################################### 60 | # Select the `i`-th image in the batch 61 | i = 1 # Image index (modify to change the image) 62 | x = x[i : i + 1, :, :, :] 63 | x = x.detach().clone() 64 | print(f"Shape of selected image: {x.shape}") 65 | b, c, h, w = x.shape 66 | 67 | ###################################################################### 68 | # Plot the selected image 69 | 70 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 71 | 72 | # %% 73 | # Forward operators for split measurements 74 | # ========================================= 75 | 76 | ###################################################################### 77 | # We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `). 78 | 79 | ###################################################################### 80 | # First, we download the covariance matrix from our warehouse.
81 | 82 | import girder_client 83 | from spyrit.misc.load_data import download_girder 84 | 85 | # Get covariance matrix 86 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 87 | dataId = "672207cbf03a54733161e95d" 88 | data_folder = "./stat/" 89 | cov_name = "Cov_64x64.pt" 90 | # download 91 | file_abs_path = download_girder(url, dataId, data_folder, cov_name) 92 | 93 | try: 94 | Cov = torch.load(file_abs_path, weights_only=True) 95 | print(f"Cov matrix {cov_name} loaded") 96 | except: 97 | Cov = torch.eye(h * h) 98 | print(f"Cov matrix {cov_name} not found! Set to the identity") 99 | 100 | ###################################################################### 101 | # We define the measurement, noise and preprocessing operators and then simulate 102 | # a measurement vector corrupted by Poisson noise. As in the previous tutorials, 103 | # we simulate an accelerated acquisition by subsampling the measurement matrix 104 | # by retaining only the first rows of a Hadamard matrix that is permuted looking 105 | # at the diagonal of the covariance matrix. 
106 | 107 | from spyrit.core.meas import HadamSplit 108 | from spyrit.core.noise import Poisson 109 | from spyrit.core.prep import SplitPoisson 110 | 111 | # Measurement parameters 112 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels) 113 | alpha = 100.0 # number of photons 114 | 115 | # Measurement and noise operators 116 | Ord = spytorch.Cov2Var(Cov) 117 | meas_op = HadamSplit(M, h, Ord) 118 | noise_op = Poisson(meas_op, alpha) 119 | prep_op = SplitPoisson(alpha, meas_op) 120 | 121 | print(f"Shape of image: {x.shape}") 122 | 123 | # Measurements 124 | y = noise_op(x) # a noisy measurement vector 125 | m = prep_op(y) # preprocessed measurement vector 126 | 127 | m_plot = spytorch.meas2img(m, Ord) 128 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$") 129 | 130 | # %% 131 | # Pseudo inverse solution 132 | # ========================================= 133 | 134 | ###################################################################### 135 | # We compute the pseudo inverse solution using :class:`spyrit.core.recon.PinvNet` class as in the previous tutorial. 136 | 137 | # Instantiate a PinvNet (with no denoising by default) 138 | from spyrit.core.recon import PinvNet 139 | 140 | pinvnet = PinvNet(noise_op, prep_op) 141 | 142 | # Use GPU, if available 143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 144 | print("Using device: ", device) 145 | pinvnet = pinvnet.to(device) 146 | y = y.to(device) 147 | 148 | # Reconstruction 149 | with torch.no_grad(): 150 | z_invnet = pinvnet.reconstruct(y) 151 | 152 | # %% 153 | # Denoised completion network (DCNet) 154 | # ========================================= 155 | 156 | ###################################################################### 157 | # .. 
image:: ../fig/dcnet.png 158 | # :width: 400 159 | # :align: center 160 | # :alt: Sketch of the DCNet architecture 161 | 162 | ###################################################################### 163 | # The DCNet is based on four sequential steps: 164 | # 165 | # i) Denoising in the measurement domain. 166 | # 167 | # ii) Estimation of the missing measurements from the denoised ones. 168 | # 169 | # iii) Image-domain mapping. 170 | # 171 | # iv) (Learned) Denoising in the image domain. 172 | # 173 | # Typically, only the last step involves learnable parameters. 174 | 175 | # %% 176 | # Denoised completion 177 | # ========================================= 178 | 179 | ###################################################################### 180 | # The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing 181 | # 182 | # .. math:: 183 | # \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}}, 184 | # 185 | # where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see its documentation for more details). 186 | 187 | ###################################################################### 188 | # In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior.
189 | 190 | from spyrit.core.recon import DCNet 191 | 192 | dcnet = DCNet(noise_op, prep_op, Cov) 193 | 194 | # Use GPU, if available 195 | dcnet = dcnet.to(device) 196 | y = y.to(device) 197 | 198 | with torch.no_grad(): 199 | z_dcnet = dcnet.reconstruct(y) 200 | 201 | ###################################################################### 202 | # .. note:: 203 | # In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction. 204 | 205 | # %% 206 | # (Learned) Denoising in the image domain 207 | # ========================================= 208 | 209 | ###################################################################### 210 | # To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`. 211 | 212 | from spyrit.core.nnet import Unet 213 | 214 | denoi = Unet() 215 | dcnet_unet = DCNet(noise_op, prep_op, Cov, denoi) 216 | dcnet_unet = dcnet_unet.to(device) # Use GPU, if available 217 | 218 | ######################################################################## 219 | # We load pretrained weights for the UNet 220 | 221 | from spyrit.core.train import load_net 222 | 223 | local_folder = "./model/" 224 | # Create model folder 225 | if os.path.exists(local_folder): 226 | print(f"{local_folder} found") 227 | else: 228 | os.mkdir(local_folder) 229 | print(f"Created {local_folder}") 230 | 231 | # Load pretrained model 232 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 233 | dataID = "67221559f03a54733161e960" # unique ID of the file 234 | data_name = "tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth" 235 | model_unet_path = os.path.join(local_folder, data_name) 236 | 237 | if os.path.exists(model_unet_path): 238 | print(f"Model found : {data_name}") 239 | 240 | else: 241 | print(f"Model not found : {data_name}") 242 | print(f"Downloading model... 
", end="") 243 | try: 244 | gc = girder_client.GirderClient(apiUrl=url) 245 | gc.downloadFile(dataID, model_unet_path) 246 | print("Done") 247 | except Exception as e: 248 | print("Failed with error: ", e) 249 | 250 | # Load pretrained model 251 | load_net(model_unet_path, dcnet_unet, device, False) 252 | 253 | ###################################################################### 254 | # We reconstruct the image 255 | with torch.no_grad(): 256 | z_dcnet_unet = dcnet_unet.reconstruct(y) 257 | 258 | # %% 259 | # Results 260 | # ========================================= 261 | 262 | from spyrit.misc.disp import add_colorbar, noaxis 263 | 264 | f, axs = plt.subplots(2, 2, figsize=(10, 10)) 265 | 266 | # Plot the ground-truth image 267 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray") 268 | axs[0, 0].set_title("Ground-truth image", fontsize=16) 269 | noaxis(axs[0, 0]) 270 | add_colorbar(im1, "bottom") 271 | 272 | # Plot the pseudo inverse solution 273 | im2 = axs[0, 1].imshow(z_invnet.cpu()[0, 0, :, :], cmap="gray") 274 | axs[0, 1].set_title("Pseudo inverse", fontsize=16) 275 | noaxis(axs[0, 1]) 276 | add_colorbar(im2, "bottom") 277 | 278 | # Plot the solution obtained from denoised completion 279 | im3 = axs[1, 0].imshow(z_dcnet.cpu()[0, 0, :, :], cmap="gray") 280 | axs[1, 0].set_title(f"Denoised completion", fontsize=16) 281 | noaxis(axs[1, 0]) 282 | add_colorbar(im3, "bottom") 283 | 284 | # Plot the solution obtained from denoised completion with UNet denoising 285 | im4 = axs[1, 1].imshow(z_dcnet_unet.cpu()[0, 0, :, :], cmap="gray") 286 | axs[1, 1].set_title(f"Denoised completion with UNet denoising", fontsize=16) 287 | noaxis(axs[1, 1]) 288 | add_colorbar(im4, "bottom") 289 | 290 | plt.show() 291 | 292 | ###################################################################### 293 | # .. note:: 294 | # While the pseudo inverse reconstrcution is pixelized, the solution obtained by denoised completion is smoother. 
DCNet with UNet denoising in the image domain provides the best reconstruction. 295 | 296 | ###################################################################### 297 | # .. note:: 298 | # We refer to `spyrit-examples tutorials `_ for a comparison of different solutions (PinvNet, DCNet and DRUNet) that can be run in Colab. 299 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_08_lpgd_split_measurements.py: -------------------------------------------------------------------------------- 1 | r""" 2 | ====================================================================== 3 | 08. Learned proximal gradient descent (LPGD) for split measurements 4 | ====================================================================== 5 | .. _tuto_lpgd_split_measurements: 6 | 7 | This tutorial shows how to perform image reconstruction with unrolled Learned Proximal Gradient 8 | Descent (LPGD) for split measurements. 9 | 10 | Unfortunately, its memory consumption is too large for it to be run interactively. 11 | If you want to run it yourself, please remove all the "if False:" statements at 12 | the beginning of each code block. The figures displayed are the ones that would 13 | be generated if the code were run. 14 | 15 | .. figure:: ../fig/lpgd.png 16 | :width: 600 17 | :align: center 18 | :alt: Sketch of the unrolled Learned Proximal Gradient Descent 19 | 20 | """ 21 | 22 | ############################################################################### 23 | # LPGD is an unrolled method, which can be explained as a recurrent network where 24 | # each block corresponds to an unrolled iteration of the proximal gradient descent. 25 | # At each iteration, the network performs a gradient step and a denoising step. 26 | # 27 | # The update rule for the LPGD network is given by: 28 | # 29 | # .. math:: 30 | # x^{(k+1)} = \mathcal{G}_{\theta}(x^{(k)} - \gamma H^T(Hx^{(k)}-m)).
31 | # 32 | # where :math:`x^{(k)}` is the image estimate at iteration :math:`k`, 33 | # :math:`H` is the forward operator, :math:`\gamma` is the step size, 34 | # and :math:`\mathcal{G}_{\theta}` is a denoising network with 35 | # learnable parameters :math:`\theta`. 36 | 37 | # %% 38 | # Load a batch of images 39 | # ----------------------------------------------------------------------------- 40 | # 41 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized 42 | # using the :func:`transform_gray_norm` function. 43 | 44 | # sphinx_gallery_thumbnail_path = 'fig/lpgd.png' 45 | 46 | if False: 47 | import os 48 | 49 | import torch 50 | import torchvision 51 | import matplotlib.pyplot as plt 52 | 53 | import spyrit.core.torch as spytorch 54 | from spyrit.misc.disp import imagesc 55 | from spyrit.misc.statistics import transform_gray_norm 56 | 57 | spyritPath = os.getcwd() 58 | imgs_path = os.path.join(spyritPath, "images/") 59 | 60 | ###################################################################### 61 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. 
62 | 63 | h = 128 # image is resized to h x h 64 | transform = transform_gray_norm(img_size=h) 65 | 66 | ###################################################################### 67 | # Create a data loader from some dataset (images must be in the folder `images/test/`) 68 | 69 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) 70 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) 71 | 72 | x, _ = next(iter(dataloader)) 73 | print(f"Shape of input images: {x.shape}") 74 | 75 | ###################################################################### 76 | # Select the `i`-th image in the batch 77 | i = 1 # Image index (modify to change the image) 78 | x = x[i : i + 1, :, :, :] 79 | x = x.detach().clone() 80 | print(f"Shape of selected image: {x.shape}") 81 | b, c, h, w = x.shape 82 | 83 | ###################################################################### 84 | # Plot the selected image 85 | 86 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") 87 | 88 | ############################################################################### 89 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972abaa5a90007058950/download 90 | # :width: 600 91 | # :align: center 92 | # :alt: Ground-truth image x in [-1, 1] 93 | # %% 94 | # Forward operators for split measurements 95 | # ----------------------------------------------------------------------------- 96 | # 97 | # We consider noisy split measurements for a Hadamard operator and a simple 98 | # rectangular subsampling” strategy 99 | # (for more details, refer to :ref:`Acquisition - split measurements `). 100 | # 101 | # 102 | # We define the measurement, noise and preprocessing operators and then 103 | # simulate a measurement vector :math:`y` corrupted by Poisson noise. As in the previous tutorial, 104 | # we simulate an accelerated acquisition by subsampling the measurement matrix 105 | # by retaining only the first rows of a Hadamard matrix. 
106 | 107 | if False: 108 | import math 109 | 110 | from spyrit.core.meas import HadamSplit 111 | from spyrit.core.noise import Poisson 112 | from spyrit.core.prep import SplitPoisson 113 | 114 | # Measurement parameters 115 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels) 116 | alpha = 10.0 # number of photons 117 | 118 | # Sampling: rectangular matrix 119 | Ord_rec = torch.zeros(h, h) 120 | n_sub = math.ceil(M**0.5) 121 | Ord_rec[:n_sub, :n_sub] = 1 122 | 123 | # Measurement and noise operators 124 | meas_op = HadamSplit(M, h, Ord_rec) 125 | noise_op = Poisson(meas_op, alpha) 126 | prep_op = SplitPoisson(alpha, meas_op) 127 | 128 | print(f"Shape of image: {x.shape}") 129 | 130 | # Measurements 131 | y = noise_op(x) # a noisy measurement vector 132 | m = prep_op(y) # preprocessed measurement vector 133 | 134 | m_plot = spytorch.meas2img(m, Ord_rec) 135 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$") 136 | 137 | ############################################################################### 138 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972bbaa5a90007058953/download 139 | # :width: 600 140 | # :align: center 141 | # :alt: Measurements m 142 | # 143 | # We define the LearnedPGD network by providing the measurement, noise and preprocessing operators, 144 | # the denoiser and other optional parameters to the class :class:`spyrit.core.recon.LearnedPGD`. 145 | # The optional parameters include the number of unrolled iterations (`iter_stop`) 146 | # and the step size decay factor (`step_decay`). 147 | # We choose Unet as the denoiser, as in previous tutorials. 148 | # For the optional parameters, we use three iterations and a step size decay 149 | # factor of 0.9, which worked well on this data (this should match the parameters 150 | # used during training). 151 | # 152 | # .. 
image:: ../fig/lpgd.png 153 | # :width: 600 154 | # :align: center 155 | # :alt: Sketch of the network architecture for LearnedPGD 156 | 157 | if False: 158 | from spyrit.core.nnet import Unet 159 | from spyrit.core.recon import LearnedPGD 160 | 161 | # use GPU, if available 162 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 163 | print("Using device:", device) 164 | # Define UNet denoiser 165 | denoi = Unet() 166 | # Define the LearnedPGD model 167 | lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9) 168 | 169 | ############################################################################### 170 | # Now, we download the pretrained weights and load them into the LPGD network. 171 | # Unfortunately, the pretrained weights are too heavy (2GB) to be downloaded 172 | # here. The last figure is nonetheless displayed to show the results. 173 | 174 | if False: 175 | from spyrit.core.train import load_net 176 | from spyrit.misc.load_data import download_girder 177 | 178 | # Download parameters 179 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" 180 | dataID = "67221f60f03a54733161e96c" # unique ID of the file 181 | local_folder = "./model/" 182 | data_name = "tuto8_model_lpgd_light.pth" 183 | # Download from Girder 184 | model_abs_path = download_girder(url, dataID, local_folder, data_name) 185 | 186 | # Load pretrained weights to the model 187 | load_net(model_abs_path, lpgd_net, device, strict=False) 188 | 189 | lpgd_net.eval() 190 | lpgd_net.to(device) 191 | 192 | ############################################################################### 193 | # We reconstruct by calling the reconstruct method as in previous tutorials 194 | # and display the results. 
195 | 196 | if False: 197 | from spyrit.misc.disp import add_colorbar, noaxis 198 | 199 | with torch.no_grad(): 200 | z_lpgd = lpgd_net.reconstruct(y.to(device)) 201 | 202 | # Plot results 203 | f, axs = plt.subplots(2, 1, figsize=(10, 10)) 204 | 205 | im1 = axs[0].imshow(x.cpu()[0, 0, :, :], cmap="gray") 206 | axs[0].set_title("Ground-truth image", fontsize=16) 207 | noaxis(axs[0]) 208 | add_colorbar(im1, "bottom") 209 | 210 | im2 = axs[1].imshow(z_lpgd.cpu()[0, 0, :, :], cmap="gray") 211 | axs[1].set_title("LPGD", fontsize=16) 212 | noaxis(axs[1]) 213 | add_colorbar(im2, "bottom") 214 | 215 | plt.show() 216 | 217 | ############################################################################### 218 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679853fbaa5a9000705894b/download 219 | # :width: 400 220 | # :align: center 221 | # :alt: Comparison of ground-truth image and LPGD reconstruction 222 | -------------------------------------------------------------------------------- /tutorial/wip/_tuto_bonus_advanced_methods_colab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | r""" 3 | Bonus. Advanced methods - Colab 4 | =============================== 5 | .. _tuto_advanced_methods_colab: 6 | 7 | We refer to `spyrit-examples/tutorial `_ 8 | for a list of tutorials that can be run directly in colab and present more advanced cases than the main spyrit tutorials, 9 | such as comparison of methods for split measurements, or comparison of different denoising networks. 10 | 11 | """ 12 | 13 | ############################################################################### 14 | # The spyrit-examples repository also includes research contributions based on the SPYRIT toolbox. 15 | --------------------------------------------------------------------------------