├── .github
└── workflows
│ └── main.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── AUTHORS
├── CHANGELOG.md
├── Codemeta.json
├── LICENSE.md
├── README.md
├── docs
├── Makefile
├── make.bat
└── source
│ ├── _static
│ └── css
│ │ └── sg_README.css
│ ├── _templates
│ ├── spyrit-class-template.rst
│ ├── spyrit-method-template.rst
│ └── spyrit-module-template.rst
│ ├── build-tuto-local.md
│ ├── conf.py
│ ├── fig
│ ├── dcnet.png
│ ├── direct_net.png
│ ├── drunet.png
│ ├── full.png
│ ├── lpgd.png
│ ├── pinvnet.png
│ ├── pinvnet_cnn.png
│ ├── principle.png
│ ├── spi_principle.png
│ ├── tuto1.png
│ ├── tuto2.png
│ ├── tuto3.png
│ ├── tuto3_pinv.png
│ ├── tuto6.png
│ └── tuto9.png
│ ├── index.rst
│ ├── organisation.rst
│ ├── sg_execution_times.rst
│ └── single_pixel.rst
├── pyproject.toml
├── requirements.txt
├── spyrit
├── __init__.py
├── core
│ ├── __init__.py
│ ├── inverse.py
│ ├── meas.py
│ ├── nnet.py
│ ├── noise.py
│ ├── prep.py
│ ├── recon.py
│ ├── torch.py
│ ├── train.py
│ └── warp.py
├── dev
│ ├── meas.py
│ ├── prep.py
│ └── recon.py
├── external
│ ├── __init__.py
│ └── drunet.py
├── hadamard_matrix
│ ├── __init__.py
│ ├── create_hadamard_matrix_with_sage.py
│ └── download_hadamard_matrix.py
└── misc
│ ├── __init__.py
│ ├── color.py
│ ├── data_visualisation.py
│ ├── disp.py
│ ├── examples.py
│ ├── load_data.py
│ ├── matrix_tools.py
│ ├── metrics.py
│ ├── pattern_choice.py
│ ├── sampling.py
│ ├── statistics.py
│ └── walsh_hadamard.py
└── tutorial
├── README.txt
├── images
└── test
│ ├── ILSVRC2012_test_00000001.jpeg
│ ├── ILSVRC2012_test_00000002.jpeg
│ ├── ILSVRC2012_test_00000003.jpeg
│ ├── ILSVRC2012_test_00000004.jpeg
│ ├── ILSVRC2012_test_00000005.jpeg
│ ├── ILSVRC2012_test_00000006.jpeg
│ └── ILSVRC2012_test_00000007.jpeg
├── tuto_01_a_acquisition_operators.py
├── tuto_01_b_splitting.py
├── tuto_01_c_HadamSplit2d.py
├── tuto_02_noise.py
├── tuto_03_pseudoinverse_linear.py
└── wip
├── _tuto_03_pseudoinverse_cnn_linear.py
├── _tuto_04_train_pseudoinverse_cnn_linear.py
├── _tuto_05_recon_hadamSplit.py
├── _tuto_06_dcnet_split_measurements.py
├── _tuto_07_drunet_split_measurements.py
├── _tuto_08_lpgd_split_measurements.py
├── _tuto_09_dynamic.py
└── _tuto_bonus_advanced_methods_colab.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 |
2 | name: CI
3 |
4 | on:
5 | push:
6 | branches: [ master ]
7 | tags:
8 | - '*'
9 | pull_request:
10 | branches:
11 | - '*'
12 | schedule:
13 | - cron: '0 0 * * 0'
14 | workflow_dispatch:
15 |
16 | jobs:
17 | build_wheel:
18 | runs-on: ${{ matrix.os }}
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [ubuntu-latest]
 23 |         python-version: ["3.9"]
24 |
25 | steps:
26 | - name: Checkout github repo
27 | uses: actions/checkout@v4
28 | - name: Checkout submodules
29 | run: git submodule update --init --recursive
30 | - name: Set up Python ${{ matrix.python-version }}
31 | uses: actions/setup-python@v5
32 | with:
33 | python-version: ${{ matrix.python-version }}
34 | architecture: 'x64'
35 | - name: Create Wheel
36 | run: |
37 | pip install build
38 | python -m build
39 | mkdir wheelhouse
40 | cp dist/spyrit-* wheelhouse/
41 | ls wheelhouse
42 | rm -r dist
43 | mv wheelhouse dist
44 | - name: Upload wheels
45 | uses: actions/upload-artifact@v4
46 | with:
47 | name: dist
48 | path: dist/
49 |
50 | test_install:
51 | runs-on: ${{ matrix.os }}
52 | strategy:
53 | fail-fast: false
54 | matrix:
55 | os: [ubuntu-latest, windows-latest, macos-13, macos-14]
 56 |         python-version: ["3.9", "3.10", "3.11", "3.12"]
57 | exclude:
58 | - os: macos-13
59 | python-version: '3.10'
60 | - os: macos-13
61 | python-version: '3.11'
62 | - os: macos-13
63 | python-version: '3.12'
64 | - os: macos-14
 65 |             python-version: '3.9'
66 | - os: macos-14
67 | python-version: '3.10'
68 |
69 | steps:
70 | - name: Checkout github repo
71 | uses: actions/checkout@v4
72 | - name: Checkout submodules
73 | run: git submodule update --init --recursive
74 | - name: Set up Python ${{ matrix.python-version }}
75 | uses: actions/setup-python@v5
76 | with:
77 | python-version: ${{ matrix.python-version }}
78 | architecture: 'x64'
79 | - name: Run the tests on Mac and Linux
80 | if: matrix.os != 'windows-latest'
81 | run: |
82 | pip install pytest
83 | pip install -e .
84 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit/dev --ignore=spyrit/hadamard_matrix || exit -1
85 | - name: Run the tests on Windows
86 | if: matrix.os == 'windows-latest'
87 | shell: cmd
88 | run: |
89 | pip install pytest
90 | pip install -e .
91 | python -m pytest --doctest-modules --ignore=tutorial --ignore=docs --ignore=spyrit\dev --ignore=spyrit\hadamard_matrix || exit /b -1
92 |
93 | test_wheel:
94 | runs-on: ${{ matrix.os }}
95 | needs: [build_wheel]
96 | strategy:
97 | fail-fast: false
98 | matrix:
99 | os: [ubuntu-latest, windows-latest, macos-13]
100 |         python-version: ["3.9", "3.10", "3.11", "3.12"]
101 |
102 | steps:
103 | - name: Checkout github repo
104 | uses: actions/checkout@v4
105 | - name: Checkout submodules
106 | run: git submodule update --init --recursive
107 | - name: Set up Python ${{ matrix.python-version }}
108 | uses: actions/setup-python@v5
109 | with:
110 | python-version: ${{ matrix.python-version }}
111 | architecture: 'x64'
112 | - uses: actions/download-artifact@v4
113 | with:
114 | pattern: dist*
115 | merge-multiple: true
116 | path: dist/
117 | - name: Run tests on Mac and Linux
118 | if: matrix.os != 'windows-latest'
119 | run: |
120 | cd dist
121 | pip install spyrit-*.whl
122 | - name: Run the tests on Windows
123 | if: matrix.os == 'windows-latest'
124 | run: |
125 | cd dist
126 | $package=dir -Path . -Filter spyrit*.whl | %{$_.FullName}
127 | echo $package
128 | pip install $package
129 |
130 | publish_wheel:
131 | runs-on: ubuntu-latest
132 | needs: [build_wheel, test_wheel, test_install]
133 | steps:
134 | - name: Checkout github repo
135 | uses: actions/checkout@v4
136 | - name: Checkout submodules
137 | run: git submodule update --init --recursive
138 | - uses: actions/download-artifact@v4
139 | with:
140 | pattern: dist*
141 | merge-multiple: true
142 | path: dist/
143 | - name: Publish to PyPI
144 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/')
145 | uses: pypa/gh-action-pypi-publish@release/v1
146 | with:
147 | user: __token__
148 | password: ${{ secrets.PYPI }}
149 |           skip-existing: true
150 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.py[co]
3 | *.swp
4 | *.pdf
5 | *.png
6 | *.sh
7 | *.ipynb
8 | *.npy
9 | *.npz
10 |
11 | #folders
12 | data
13 | img
14 | models
15 | dist
16 | build
17 | spyrit.egg-info
18 | stats_walsh
19 | **/.ipynb_checkpoints/*
20 | docs/source/_build*
21 | model/
22 | runs/
23 | spyrit/drunet/
24 | !spyrit/images/tuto/*.png
25 | docs/source/html
26 | docs/source/_autosummary
27 | docs/source/_templates
28 | docs/source/api
29 | docs/source/gallery
30 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: trailing-whitespace
7 | - repo: https://github.com/psf/black
8 | rev: 25.1.0
9 | hooks:
10 | - id: black
11 | ci:
12 | autofix_commit_msg: |
13 | [pre-commit.ci] Automatic python formatting
14 | autofix_prs: true
15 | submodules: false
16 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Optionally build your docs in additional formats such as PDF
  9 | formats: all # pdf, epub and htmlzip
10 |
11 | # Optionally set the version of Python and requirements required to build your docs
12 | build:
13 | os: ubuntu-22.04
14 | tools:
15 | python: "3.11"
16 | python:
17 | install:
18 | - requirements: requirements.txt
19 |
20 | # Build documentation in the docs/ directory with Sphinx
21 | sphinx:
22 | configuration: docs/source/conf.py
23 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | Nicolas Ducros
2 | Romain Phan
3 | Thomas Baudier
4 | Juan Abascal
5 | Fadoua Taia-Alaoui
6 | Claire Mouton
7 | Guilherme Beneti
8 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ### Notations
4 |
5 | \- removals
6 | / changes
7 | \+ additions
8 |
9 | ---
10 |
11 |
12 |
13 |
14 | ## Changes to come in a future version
15 |
16 |
17 |
18 |
19 |
20 |
21 | ---
22 |
23 |
24 |
25 | ## v2.3.4
26 |
27 |
28 | ### spyrit.core
29 | * #### General changes
 30 | 	* The input and output shapes have been standardized across operators. All still images (i.e. not videos) have shape `(*, h, w)`, where `*` is any batch dimension (e.g. batch size and number of channels), and `h` and `w` are the height and width of the image. All measurements have shape `(*, M)`, where `*` is the same batch dimension as the images they come from. Videos have shape `(*, t, c, h, w)` where `t` is the time dimension, representing the number of frames in the video, `c` is the number of channels. Dynamic measurements from videos will thus have shape `(*, c, M)`.
31 | * The overall use of gpu has been improved. Every class of the `core` module now has a method `self.device` that allows to track the device on which its parameters are.
32 | * #### spyrit.core.meas
33 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization.
34 | * / Fixed .pinv() output shape (it was transposed with some regularisation methods)
35 | * / Fixed some device errors when using cuda with .pinv()
36 | * / The measurement matrix H is now stored with the data type it is given to the constructor (it was previously converted to torch.float32 for memory reasons)
37 | * \+ added in the .pinv() method a diff parameter enabling differentiated reconstructions (subtracting negative patterns/measurements to the positive patterns/measurements), only available for dynamic operators.
38 | * / For HadamSplit, the pinv has been overwritten to use a fast Walsh-Hadamard transform, zero-padding the measurements if necessary (in the case of subsampling). The inverse() method has been deprecated and will be removed in a future release.
39 | * #### spyrit.core.recon
40 | * \- The class core.recon.Denoise_layer is deprecated and will be removed in a future version
41 | * / The class TikhonovMeasurementPriorDiag no longer uses Denoise_layer and uses instead an internal method to handle the denoising.
42 | * #### spyrit.core.train
43 | * / load_net() uses the weights_only=True parameter in the torch.load() function. Documentation updated.
44 | * #### spyrit.core.warp
45 | * / The warping operation (forward method) now has to be performed on (b,c,h,w) input tensors, and returns (b, time, c, h, w) output tensors.
46 | * / The AffineDeformationField does not store anymore the field as an attribute, but is rather generated on the fly. This allows for more efficient memory management.
47 | * / In AffineDeformationField the image size can be changed.
48 | * \+ It is now possible to use biquintic (5th-order) warping. This uses scikit-image's (skimage) warp function, which relies on numpy arrays.
49 |
50 | ### Tutorials
51 | * Tutorial 2 integrated the change from 'L1' to 'rcond'
52 | * All Tutorials have been updated to include the above mentioned changes.
53 |
54 |
55 |
56 | ---
57 |
58 |
59 |
60 | ## v2.3.3
61 |
62 |
63 | ### spyrit.core
64 | * #### spyrit.core.meas
65 | * / The regularization value 'L1' has been changed to 'rcond'. The behavior is unchanged but the reconstruction did not correspond to L1 regularization.
66 | * #### spyrit.core.recon
67 | * / The documentation for the class core.recon.Denoise_layer has been clarified.
68 |
69 | ### Tutorials
70 |
71 | * Tutorial 2 integrated the change from 'L1' to 'rcond'
72 |
73 |
74 |
75 | ---
76 |
77 |
78 |
79 | ## v2.3.2
80 |
81 |
82 | ### spyrit.core
83 | * #### spyrit.core.meas
84 | * / The method forward_H has been optimized for the HadamSplit class
85 | * #### spyrit.core.torch
 86 | 	* \+ Added spyrit.core.torch.fwht that implements in Pytorch the fast Walsh-Hadamard transform for natural and Walsh ordered transforms.
 87 | 	* \+ Added spyrit.core.torch.fwht_2d that implements in Pytorch the fast Walsh-Hadamard transform in 2 dimensions for natural and Walsh ordered transforms.
88 |
89 | ### spyrit.misc
90 |
91 | * #### spyrit.misc.statistics
92 | * / The function spyrit.misc.statistics.Cov2Var has been sped up and now supports an output shape for non-square images
93 | * #### spyrit.misc.walsh_hadamard
 94 | 	* / The function spyrit.misc.walsh_hadamard.fwht has been significantly sped up, especially for sequency-ordered walsh-hadamard transforms.
95 | * \- fwht_torch is now deprecated. Use spyrit.core.torch.fwht instead.
96 | * \- walsh_torch is now deprecated. Use spyrit.core.torch.fwht instead.
97 | * \- walsh2_torch is now deprecated. Use spyrit.core.torch.fwht_2d instead.
98 | * #### spyrit.misc.load_data
99 | * \+ New function download_girder that downloads files identified by their hexadecimal ID from a url server
100 |
101 | ### Tutorials
102 |
103 | * Tutorials 3, 4, 6, 7, 8 now download data from our own servers instead of using google drive and the gdown library. Dependency on gdown library will be fully removed in a future version.
104 |
105 |
106 |
107 | ---
108 |
109 |
110 |
111 | ## v2.3.1
112 |
113 |
114 | ### spyrit.core
115 |
116 | * #### spyrit.core.meas
117 | * \+ For static classes, self.set_H_pinv has been renamed to self.build_H_pinv to match with the dynamic classes.
118 | * \+ The dynamic classes now support bicubic dynamic reconstruction (spyrit.core.meas.DynamicLinear.build_h_dyn()). This uses cubic B-splines.
119 | * #### spyrit.core.train
120 | * load_net() must take the full path, **with** the extension name (xyz.pth).
121 |
122 | ### Tutorials
123 |
124 | * Tutorial 6 has been changed accordingly to the modification of spyrit.core.train.load_net().
125 | * Tutorial 8 is now available.
126 |
127 |
128 |
129 | ---
130 |
131 |
132 |
133 | ## v2.3.0
134 |
135 |
136 |
137 |
138 | ### spyrit.core
139 |
140 |
141 | * / no longer supports numpy.array as input, must use torch.tensor
142 | * #### spyrit.core.meas
143 | * \- class LinearRowSplit (use LinearSplit instead)
144 | * \+ 3 dynamic classes: DynamicLinear, DynamicLinearSplit, DynamicHadamSplit that allow measurements over time
145 | * spyrit.core.meas.Linear
146 | * \- self.get_H() deprecated (use self.H)
147 | * \- self.H_adjoint (you might want to use self.H.T)
148 | * / constructor argument 'reg' renamed to 'rtol'
149 | * / self.H no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable)
150 | * / self.H_pinv no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable)
151 | * \+ self.__init__() has 'Ord' and 'meas_shape' optional arguments
152 | * \+ self.pinv() now supports lstsq image reconstruction if self.H_pinv is not defined
153 | * \+ self.set_H_pinv(), self.reindex() inherited from spyrit.misc.torch
154 | * \+ self.meas_shape, self.indices, self.Ord, self.H_static
155 | * spyrit.core.meas.LinearSplit
156 | * / [includes changes from Linear]
157 | * / self.P no longer refers to a torch.nn.Linear, but to a torch.tensor (not callable)
158 | * spyrit.core.meas.HadamSplit
159 | * / [includes changes from LinearSplit]
160 | * \- self.__init__() does not need 'meas_shape' argument, it is taken as (h,h)
161 | * \- self.Perm (use self.reindex() instead)
162 | * #### spyrit.core.noise
163 | * spyrit.core.noise.NoNoise
164 | * \+ self.reindex() inherited from spyrit.core.meas.Linear.reindex()
165 | * #### spyrit.core.prep
166 | * \- class SplitRowPoisson (was used with LinearRowSplit)
167 | * #### spyrit.core.recon
168 | * spyrit.core.recon.PseudoInverse
169 | * / self.forward() now has **kwargs that are passed to meas_op.pinv(), useful for lstsq image reconstruction
170 | * #### \+ spyrit.core.torch
171 | contains torch-specific functions that are commonly used in spyrit.core. Mirrors some spyrit.misc functions that are numpy-specific
172 | * #### \+ spyrit.core.warp
173 | * \+ class AffineDeformationField
174 | warps an image using an affine transformation matrix
175 | * \+ class DeformationField
176 | warps an image using a deformation field
177 |
178 |
179 |
180 |
181 | ### spyrit.misc
182 |
183 |
184 | * #### spyrit.misc.matrix_tools
185 | * \- Permutation_Matrix() is deprecated (already defined in spyrit.misc.sampling.Permutation_Matrix())
186 | * #### spyrit.misc.sampling
187 | * \- meas2img2() is deprecated (use meas2img() instead)
188 | * / meas2img() can now handle batch of images
189 | 	* \+ sort_by_significance() & reindex() to speed up permutation matrix multiplication
190 |
191 |
--------------------------------------------------------------------------------
/Codemeta.json:
--------------------------------------------------------------------------------
1 | {
2 | "@context": "https://doi.org/10.5063/schema/codemeta-2.0",
3 | "type": "SoftwareSourceCode",
4 | "applicationCategory": "Single-pixel imaging",
5 | "codeRepository": "https://github.com/openspyrit/spyrit",
6 | "dateCreated": "2020-12-10",
7 | "datePublished": "2021-03-11",
8 | "description": "SPyRiT is a PyTorch-based deep image reconstruction package primarily designed for single-pixel imaging.",
9 | "keywords": [
10 | "Single-pixel imaging",
11 | "pytorch"
12 | ],
13 | "license": "https://spdx.org/licenses/LGPL-3.0",
14 | "name": "SPyRiT",
15 | "operatingSystem": [
16 | "Linux",
17 | "Windows",
18 | "MacOS"
19 | ],
20 | "programmingLanguage": "Python 3",
21 | "contIntegration": "https://github.com/openspyrit/spyrit/actions",
22 | "codemeta:continuousIntegration": {
23 | "id": "https://github.com/openspyrit/spyrit/actions"
24 | },
25 | "issueTracker": "https://github.com/openspyrit/spyrit/issues"
26 | }
27 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://github.com/openspyrit/spyrit/blob/master/LICENSE.md)
3 | [](https://pypi.python.org/pypi/spyrit/)
4 | [](https://spyrit.readthedocs.io/en/master/)
5 |
6 | # SPyRiT
  7 | SPyRiT is a [PyTorch](https://pytorch.org/)-based deep image reconstruction package primarily designed for single-pixel imaging.
8 |
9 | # Installation
 10 | The spyrit package is available for Linux, macOS and Windows. We recommend using a virtual environment.
 11 | ## Linux and macOS
12 | (user mode)
13 | ```
14 | pip install spyrit
15 | ```
 16 | (developer mode)
17 | ```
18 | git clone https://github.com/openspyrit/spyrit.git
19 | cd spyrit
20 | pip install -e .
21 | ```
22 |
23 | ## Windows
24 | On Windows you may need to install PyTorch first. It may also be necessary to run the following commands using administrator rights (e.g., starting your Python environment with administrator rights).
25 |
26 | Adapt the two examples below to your configuration (see [here](https://pytorch.org/get-started/locally/) for the latest instructions)
27 |
28 | (CPU version using `pip`)
29 |
30 | ```
31 | pip3 install torch torchvision torchaudio
32 | ```
33 |
34 | (GPU version using `conda`)
35 |
36 | ``` shell
37 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
38 | ```
39 |
40 | Then, install SPyRiT using `pip`:
41 |
42 | (user mode)
43 | ```
44 | pip install spyrit
45 | ```
 46 | (developer mode)
47 | ```
48 | git clone https://github.com/openspyrit/spyrit.git
49 | cd spyrit
50 | pip install -e .
51 | ```
52 |
53 |
54 | ## Test
55 | To check the installation, run in your python terminal:
56 | ```
57 | import spyrit
58 | ```
59 |
60 | ## Get started - Examples
 61 | To start, check the [documentation tutorials](https://spyrit.readthedocs.io/en/master/gallery/index.html). These tutorials must be run from the `tutorial` folder (they load image samples from `spyrit/images/`):
62 | ```
63 | cd spyrit/tutorial/
64 | ```
65 |
66 | More advanced reconstruction examples can be found in [spyrit-examples/tutorial](https://github.com/openspyrit/spyrit-examples/tree/master/tutorial). Run advanced tutorial in colab: [](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_core_2d_drunet.ipynb)
67 |
68 |
69 | # API Documentation
70 | https://spyrit.readthedocs.io/
71 |
72 | # Contributors (alphabetical order)
73 | * Juan Abascal - [Website](https://juanabascal78.wixsite.com/juan-abascal-webpage)
74 | * Thomas Baudier
75 | * Sebastien Crombez
76 | * Nicolas Ducros - [Website](https://www.creatis.insa-lyon.fr/~ducros/WebPage/index.html)
77 | * Antonio Tomas Lorente Mur - [Website]( https://sites.google.com/view/antonio-lorente-mur/)
78 | * Romain Phan
79 | * Fadoua Taia-Alaoui
80 |
81 | # How to cite?
82 | When using SPyRiT in scientific publications, please cite the following paper:
83 |
84 | * G. Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," Optics Express, Vol. 31, No. 10, (2023). https://doi.org/10.1364/OE.483937.
85 |
86 | When using SPyRiT specifically for the denoised completion network, please cite the following paper:
87 |
88 | * A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," Opt. Express 29, 17097-17110 (2021). https://doi.org/10.1364/OE.424228.
89 |
90 | # License
91 | This project is licensed under the LGPL-3.0 license - see the [LICENSE.md](LICENSE.md) file for details
92 |
93 | # Acknowledgments
94 | * [Jin LI](https://github.com/happyjin/ConvGRU-pytorch) for his implementation of Convolutional Gated Recurrent Units for PyTorch
95 | * [Erik Lindernoren](https://github.com/eriklindernoren/Action-Recognition) for his processing of the UCF-101 Dataset.
96 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/css/sg_README.css:
--------------------------------------------------------------------------------
1 | .sphx-glr-thumbnails {
2 | width: 100%;
3 | margin: 0px 0px 0px 0px;
4 |
5 | /* align thumbnails on a grid */
6 | justify-content: space-between;
7 | display: grid;
8 | /* each grid column should be at least 160px (this will determine
9 | the actual number of columns) and then take as much of the
10 | remaining width as possible */
11 | grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)) !important;
12 | gap: 20px;
13 | }
14 | .sphx-glr-thumbcontainer {
15 | width: 100% !important;
16 | min-height: 210px !important;
17 | margin: 0px !important;
18 | }
19 | .sphx-glr-thumbcontainer .figure {
20 | min-width: 100px !important;
21 | height: 100px !important;
22 | }
23 | .sphx-glr-thumbcontainer img {
24 | display: inline !important;
25 | object-fit: cover !important;
26 | max-height: 150px !important;
27 | min-width: 300px !important;
28 | }
29 |
--------------------------------------------------------------------------------
/docs/source/_templates/spyrit-class-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 |
8 | {% block methods %}
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | :toctree:
14 | :template: spyrit-method-template.rst
15 | {% for item in methods %}
16 | {%- if item is in members %}
17 | ~{{ name }}.{{ item }}
18 | {%- endif %}
19 | {%- endfor %}
20 | {% endif %}
21 | {% endblock %}
22 |
--------------------------------------------------------------------------------
/docs/source/_templates/spyrit-method-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. automethod:: {{ objname }}
6 |
--------------------------------------------------------------------------------
/docs/source/_templates/spyrit-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 | :show-inheritance:
5 |
6 | {% block attributes %}
7 | {% if attributes %}
8 | .. rubric:: {{ _('Module Attributes') }}
9 |
10 | .. autosummary::
11 | :toctree:
12 | {% for item in attributes %}
13 | {{ item }}
14 | {%- endfor %}
15 | {% endif %}
16 | {% endblock %}
17 |
18 | {% block functions %}
19 | {% if functions %}
20 | .. rubric:: {{ _('Functions') }}
21 |
22 | .. autosummary::
23 | :toctree:
24 | {% for item in functions %}
25 | {{ item }}
26 | {%- endfor %}
27 | {% endif %}
28 | {% endblock %}
29 |
30 | {% block classes %}
31 | {% if classes %}
32 | .. rubric:: {{ _('Classes') }}
33 |
34 | .. autosummary::
35 | :toctree:
36 | :template: spyrit-class-template.rst
37 | {% for item in classes %}
38 | {{ item }}
39 | {%- endfor %}
40 | {% endif %}
41 | {% endblock %}
42 |
43 | {% block exceptions %}
44 | {% if exceptions %}
45 | .. rubric:: {{ _('Exceptions') }}
46 |
47 | .. autosummary::
48 | :toctree:
49 | {% for item in exceptions %}
50 | {{ item }}
51 | {%- endfor %}
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block modules %}
56 | {% if modules %}
57 | .. rubric:: Modules
58 |
59 | .. autosummary::
60 | :toctree:
61 | :template: spyrit-module-template.rst
62 | :recursive:
63 | {% for item in modules %}
64 | {{ item }}
65 | {%- endfor %}
66 | {% endif %}
67 | {% endblock %}
68 |
--------------------------------------------------------------------------------
/docs/source/build-tuto-local.md:
--------------------------------------------------------------------------------
1 |
2 | ``` shell
3 | git clone --no-single-branch --depth 50 https://github.com/openspyrit/spyrit .
4 | git checkout --force origin/gallery
5 | git clean -d -f -f
6 | cat .readthedocs.yml
7 | ```
8 |
9 | # Linux
10 | ``` shell
11 | python3.7 -mvirtualenv $READTHEDOCS_VIRTUALENV_PATH
12 | python -m pip install --upgrade --no-cache-dir pip setuptools
13 | python -m pip install --upgrade --no-cache-dir pillow==5.4.1 mock==1.0.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext<2.3
14 | python -m pip install --exists-action=w --no-cache-dir -r requirements.txt
15 | cat docs/source/conf.py
16 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . $READTHEDOCS_OUTPUT/html
17 | ```
18 |
19 | # Windows using conda
20 | ``` powershell
21 | conda create --name readthedoc
22 | conda activate readthedoc
23 | conda install pip
24 | python.exe -m pip install --upgrade --no-cache-dir pip setuptools
25 | pip install --upgrade --no-cache-dir pillow==10.0.0 mock==1.0.1 alabaster==0.7.13 commonmark==0.9.1 recommonmark==0.5.0 sphinx sphinx-rtd-theme readthedocs-sphinx-ext==2.2.2
26 | cd .\myenv\spyrit\ # replace myenv by the environment in which spyrit is installed
27 | pip install --exists-action=w --no-cache-dir -r requirements.txt
28 | cd .\docs\source\
29 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . html
30 | ```
31 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 |
13 | import os
14 | import sys
15 | from sphinx_gallery.sorting import ExampleTitleSortKey
16 |
17 | # paths relative to this file
18 | sys.path.insert(0, os.path.abspath("../.."))
19 |
20 | # -- Project information -----------------------------------------------------
21 | project = "spyrit"
22 | copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan"
23 | author = "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan"
24 |
25 | # The full version, including alpha/beta/rc tags
26 | release = "2.4.0"
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | "sphinx.ext.intersphinx",
35 | "sphinx.ext.autodoc",
36 | "sphinx.ext.mathjax",
37 | "sphinx.ext.todo",
38 | "sphinx.ext.autosummary",
39 | "sphinx.ext.napoleon",
40 | "sphinx.ext.viewcode",
41 | "sphinx_gallery.gen_gallery",
42 | "sphinx.ext.coverage",
43 | ]
44 |
45 | # Napoleon settings
46 | napoleon_google_docstring = False
47 | napoleon_numpy_docstring = True
48 | napoleon_include_private_with_doc = False
49 | napoleon_include_special_with_doc = False
50 | napoleon_use_admonition_for_examples = False
51 | napoleon_use_admonition_for_notes = False
52 | napoleon_use_admonition_for_references = False
53 | napoleon_use_ivar = True
54 | napoleon_use_param = False
55 | napoleon_use_rtype = False
56 |
57 | autodoc_member_order = "bysource"
58 | autosummary_generate = True
59 | todo_include_todos = True
60 |
61 | # Add any paths that contain templates here, relative to this directory.
62 | templates_path = ["_templates"]
63 |
64 | # List of patterns, relative to source directory, that match files and
65 | # directories to ignore when looking for source files.
66 | # This pattern also affects html_static_path and html_extra_path.
67 | exclude_patterns = []
68 |
69 | sphinx_gallery_conf = {
70 | # path to your examples scripts
71 | "examples_dirs": [
72 | "../../tutorial",
73 | ],
74 | # path where to save gallery generated examples
75 | "gallery_dirs": ["gallery"],
76 | "filename_pattern": "/tuto_",
77 | "ignore_pattern": "/_",
78 | # resize the thumbnails, original size = 400x280
79 | "thumbnail_size": (400, 280),
80 | # Remove the "Download all examples" button from the top level gallery
81 | "download_all_examples": False,
82 | # Sort gallery example by file name instead of number of lines (default)
83 | "within_subsection_order": ExampleTitleSortKey,
84 | # directory where function granular galleries are stored
85 | "backreferences_dir": "api/generated/backreferences",
86 | # Modules for which function level galleries are created.
87 | "doc_module": "spyrit",
88 | # Insert links to documentation of objects in the examples
89 | "reference_url": {"spyrit": None},
90 | }
91 |
92 | # -- Options for HTML output -------------------------------------------------
93 |
94 | # The theme to use for HTML and HTML Help pages. See the documentation for
95 | # a list of builtin themes.
96 | html_theme = "sphinx_rtd_theme"
97 |
98 | # directory containing custom CSS file (used to produce bigger thumbnails)
99 |
100 | # on_rtd is whether we are on readthedocs.org
101 | on_rtd = os.environ.get("READTHEDOCS", None) == "True"
102 |
103 | # Add any paths that contain custom static files (such as style sheets) here,
104 | # relative to this directory. They are copied after the builtin static files,
105 | # so a file named "default.css" will overwrite the builtin "default.css".
106 | # By default, this is set to include the _static path.
107 | html_static_path = ["_static"]
108 | html_css_files = ["css/sg_README.css"]
109 |
110 | # The master toctree document.
111 | master_doc = "index"
112 |
113 | html_sidebars = {
114 | "**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"]
115 | }
116 |
117 | # http://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports
118 | # autodoc_mock_imports incompatible with autosummary somehow
119 | # autodoc_mock_imports = "numpy matplotlib mpl_toolkits scipy torch torchvision Pillow opencv-python imutils PyWavelets pywt wget imageio".split()
120 |
121 |
122 | # exclude all torch.nn.Module members (except forward method) from the docs:
123 | import torch
124 |
125 |
126 | def skip_member_handler(app, what, name, obj, skip, options):
127 | always_document = [ # complete this list if needed by adding methods
128 | "forward", # you *always* want to see documented
129 | ]
130 | if name in always_document:
131 | return None
132 | if name in dir(torch.nn.Module): # used for most of the classes in spyrit
133 | return True
134 | if name in dir(torch.nn.Sequential): # used for FullNet and child classes
135 | return True
136 | return None
137 |
138 |
139 | def setup(app):
140 | app.connect("autodoc-skip-member", skip_member_handler)
141 |
--------------------------------------------------------------------------------
/docs/source/fig/dcnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/dcnet.png
--------------------------------------------------------------------------------
/docs/source/fig/direct_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/direct_net.png
--------------------------------------------------------------------------------
/docs/source/fig/drunet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/drunet.png
--------------------------------------------------------------------------------
/docs/source/fig/full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/full.png
--------------------------------------------------------------------------------
/docs/source/fig/lpgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/lpgd.png
--------------------------------------------------------------------------------
/docs/source/fig/pinvnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/pinvnet.png
--------------------------------------------------------------------------------
/docs/source/fig/pinvnet_cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/pinvnet_cnn.png
--------------------------------------------------------------------------------
/docs/source/fig/principle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/principle.png
--------------------------------------------------------------------------------
/docs/source/fig/spi_principle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/spi_principle.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto1.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto2.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto3.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto3_pinv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto3_pinv.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto6.png
--------------------------------------------------------------------------------
/docs/source/fig/tuto9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/docs/source/fig/tuto9.png
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. spyrit documentation master file, created by
2 | sphinx-quickstart on Fri Mar 12 11:04:59 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | SPyRiT
7 | #####################################################################
8 |
9 | SPyRiT is a `PyTorch <https://pytorch.org/>`_-based image reconstruction
10 | package designed for `single-pixel imaging `_. SPyRiT has a `modular organisation `_ and may be useful for other inverse problems.
11 |
12 | Github repository: `openspyrit/spyrit <https://github.com/openspyrit/spyrit>`_
13 |
14 |
15 | Installation
16 | ==================================
17 |
18 | SPyRiT is available for Linux, macOS and Windows::
19 |
20 | pip install spyrit
21 |
22 | See `here `_ for advanced installation guidelines.
23 |
24 |
25 | Getting started
26 | ==================================
27 |
28 | Please check our `tutorials `_ as well as the `examples `_ on GitHub.
29 |
30 | Cite us
31 | ==================================
32 |
33 | When using SPyRiT in scientific publications, please cite [v3]_ for SPyRiT v3, [v2]_ for SPyRiT v2, and [v1]_ for DC-Net.
34 |
35 | .. [v3] JFJP Abascal, T Baudier, R Phan, A Repetti, N Ducros, "SPyRiT 3.0: an open source package for single-pixel imaging based on deep learning," Preprint (2024).
36 | .. [v2] G Beneti-Martin, L Mahieu-Williame, T Baudier, N Ducros, "OpenSpyrit: an Ecosystem for Reproducible Single-Pixel Hyperspectral Imaging," *Optics Express*, Vol. 31, Issue 10, (2023). `DOI <https://doi.org/10.1364/OE.483937>`_.
37 | .. [v1] A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," *Opt. Express*, Vol. 29, Issue 11, 17097-17110 (2021). `DOI <https://doi.org/10.1364/OE.424228>`_.
38 |
39 |
40 | Join the project
41 | ==================================
42 |
43 | The list of contributors can be found `here `_. Feel free to contact us by `e-mail `_ for any question. Direct contributions via pull requests (PRs) are welcome.
44 |
45 | .. toctree::
46 | :maxdepth: 2
47 | :hidden:
48 |
49 | single_pixel
50 | organisation
51 |
52 |
53 | .. toctree::
54 | :maxdepth: 2
55 | :caption: Tutorials
56 | :hidden:
57 |
58 | gallery/index
59 |
60 |
61 | Contents
62 | ========
63 |
64 | .. autosummary::
65 | :toctree: _autosummary
66 | :template: spyrit-module-template.rst
67 | :recursive:
68 | :caption: Contents
69 |
70 | spyrit.core
71 | spyrit.misc
72 | spyrit.external
73 |
--------------------------------------------------------------------------------
/docs/source/organisation.rst:
--------------------------------------------------------------------------------
1 | Organisation of the package
2 | ==================================
3 |
4 | .. figure:: fig/direct_net.png
5 | :width: 600
6 | :align: center
7 |
8 |
9 | SPyRiT's typical pipeline.
10 |
11 | SPyRiT allows to simulate measurements and perform image reconstruction using
12 | a full network. A full network includes a measurement operator
13 | :math:`A`, a noise operator :math:`\mathcal{N}`, a preprocessing
14 | operator :math:`B`, a reconstruction operator :math:`\mathcal{R}`,
15 | and a learnable neural network :math:`\mathcal{G}_{\theta}`. All operators
16 | inherit from :class:`torch.nn.Module`.
17 |
18 |
19 | Submodules
20 | -----------------------------------
21 |
22 | SPyRiT has a modular structure with the core functionality organised in the 8 submodules of
23 | :mod:`spyrit.core`.
24 |
25 | 1. :mod:`spyrit.core.meas` provides measurement operators that compute linear measurements corresponding to :math:`A` in Eq. :eq:`eq_acquisition`. It also provides the adjoint and the pseudoinverse of :math:`A`, which are the basis of any reconstruction algorithm.
26 |
27 | 2. :mod:`spyrit.core.noise` provides noise operators corresponding to :math:`\mathcal{N}` in Eq. :eq:`eq_acquisition`.
28 |
29 | 3. :mod:`spyrit.core.prep` provides preprocessing operators for the operator :math:`B` introduced in Eq. :eq:`eq_prep`.
30 |
31 | 4. :mod:`spyrit.core.nnet` provides known neural networks corresponding to :math:`\mathcal{G}` in Eq. :eq:`eq_recon_direct` or Eq. :eq:`eq_pgd_no_Gamma`.
32 |
33 | 5. :mod:`spyrit.core.recon` returns the reconstruction operator corresponding to :math:`\mathcal{R}`.
34 |
35 | 6. :mod:`spyrit.core.train` provides the functionality to solve the minimisation problem of Eq. :eq:`eq_train`.
36 |
37 | 7. :mod:`spyrit.core.warp` contains the operators used for dynamic acquisitions.
38 |
39 | 8. :mod:`spyrit.core.torch` contains utility functions.
40 |
41 | In addition, :mod:`spyrit.misc` contains various utility functions for Numpy / PyTorch that can be used independently of the core functions.
42 |
43 | Finally, :mod:`spyrit.external` provides access to `DR-UNet `_.
44 |
--------------------------------------------------------------------------------
/docs/source/sg_execution_times.rst:
--------------------------------------------------------------------------------
1 |
2 | :orphan:
3 |
4 | .. _sphx_glr_sg_execution_times:
5 |
6 |
7 | Computation times
8 | =================
9 | **00:00.000** total execution time for 7 files **from all galleries**:
10 |
11 | .. container::
12 |
13 | .. raw:: html
14 |
15 |
19 |
20 |
21 |
22 |
27 |
28 | .. list-table::
29 | :header-rows: 1
30 | :class: table table-striped sg-datatable
31 |
32 | * - Example
33 | - Time
34 | - Mem (MB)
35 | * - :ref:`sphx_glr_gallery_tuto_01_acquisition_operators.py` (``..\..\tutorial\tuto_01_acquisition_operators.py``)
36 | - 00:00.000
37 | - 0.0
38 | * - :ref:`sphx_glr_gallery_tuto_02_pseudoinverse_linear.py` (``..\..\tutorial\tuto_02_pseudoinverse_linear.py``)
39 | - 00:00.000
40 | - 0.0
41 | * - :ref:`sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_03_pseudoinverse_cnn_linear.py``)
42 | - 00:00.000
43 | - 0.0
44 | * - :ref:`sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py` (``..\..\tutorial\tuto_04_train_pseudoinverse_cnn_linear.py``)
45 | - 00:00.000
46 | - 0.0
47 | * - :ref:`sphx_glr_gallery_tuto_05_acquisition_split_measurements.py` (``..\..\tutorial\tuto_05_acquisition_split_measurements.py``)
48 | - 00:00.000
49 | - 0.0
50 | * - :ref:`sphx_glr_gallery_tuto_06_dcnet_split_measurements.py` (``..\..\tutorial\tuto_06_dcnet_split_measurements.py``)
51 | - 00:00.000
52 | - 0.0
53 | * - :ref:`sphx_glr_gallery_tuto_bonus_advanced_methods_colab.py` (``..\..\tutorial\tuto_bonus_advanced_methods_colab.py``)
54 | - 00:00.000
55 | - 0.0
56 |
--------------------------------------------------------------------------------
/docs/source/single_pixel.rst:
--------------------------------------------------------------------------------
1 | Single-pixel imaging
2 | ==================================
3 | .. _principle:
4 | .. figure:: fig/spi_principle.png
5 | :width: 800
6 | :align: center
7 |
8 | Overview of the principle of single-pixel imaging.
9 |
10 |
11 | Simulation of the measurements
12 | -----------------------------------
13 | Single-pixel imaging aims to recover an unknown image :math:`x\in\mathbb{R}^N` from a few noisy observations
14 |
15 | .. math::
16 | m \approx Hx,
17 |
18 | where :math:`H\in \mathbb{R}^{M\times N}` is a linear measurement operator, :math:`M` is the number of measurements and :math:`N` is the number of pixels in the image.
19 |
20 | In practice, measurements are obtained by uploading a set of light patterns onto a spatial light modulator (e.g., a digital micromirror device (DMD), see :ref:`principle`). Therefore, only positive patterns can be implemented. We model the actual acquisition process as
21 |
22 |
23 | .. math::
24 | :label: eq_acquisition
25 |
26 | y = \mathcal{N}(Ax)
27 |
28 | where :math:`\mathcal{N} \colon \mathbb{R}^J \to \mathbb{R}^J` represents a noise operator (e.g., Poisson or Poisson-Gaussian), :math:`A \in \mathbb{R}_+^{J\times N}` is the actual acquisition operator that models the (positive) DMD patterns, and :math:`J` is the number of DMD patterns.
29 |
30 | Handling non negativity with pre-processing
31 | ----------------------------------------------------------------------
32 | We may preprocess the measurements before reconstruction to transform the actual measurements into the target measurements
33 |
34 | .. math::
35 | :label: eq_prep
36 |
37 | m = By \approx Hx
38 |
39 |
40 | where :math:`B\colon\mathbb{R}^{J}\to \mathbb{R}^{M}` is the preprocessing operator chosen such that :math:`BA=H`. Note that the noise of the preprocessed measurements :math:`m=By` is not the same as that of the actual measurements :math:`y`.
41 |
42 | Data-driven image reconstruction
43 | -----------------------------------
44 | Data-driven methods based on deep learning aim to find an estimate :math:`x^*\in \mathbb{R}^N` of the unknown image :math:`x` from the preprocessed measurements :math:`By`, using a reconstruction operator :math:`\mathcal{R}_{\theta^*} \colon \mathbb{R}^M \to \mathbb{R}^N`
45 |
46 | .. math::
47 | \mathcal{R}_{\theta^*}(m) = x^* \approx x,
48 |
49 | where :math:`\theta^*` represents the parameters learned during a training procedure.
50 |
51 | Learning phase
52 | -----------------------------------
53 | In the case of supervised learning, it is assumed that a training dataset :math:`\{x^{(i)},y^{(i)}\}_{1 \le i \le I}` of :math:`I` pairs of ground truth images in :math:`\mathbb{R}^N` and measurements in :math:`\mathbb{R}^M` is available. :math:`\theta^*` is then obtained by solving
54 |
55 | .. math::
56 | :label: eq_train
57 |
58 | \min_{\theta}\,{\sum_{i =1}^I \mathcal{L}\left(x^{(i)},\mathcal{R}_\theta(By^{(i)})\right)},
59 |
60 |
61 | where :math:`\mathcal{L}` is the training loss (e.g., squared error). In the case where only ground truth images :math:`\{x^{(i)}\}_{1 \le i \le I}` are available, the associated measurements are simulated as :math:`y^{(i)} = \mathcal{N}(Ax^{(i)})`, :math:`1 \le i \le I`.
62 |
63 |
64 | Reconstruction operator
65 | -----------------------------------
66 | A simple yet efficient method consists in correcting a traditional (e.g. linear) reconstruction by a data-driven nonlinear step
67 |
68 | .. math::
69 | :label: eq_recon_direct
70 |
71 | \mathcal{R}_\theta = \mathcal{G}_\theta \circ \mathcal{R},
72 |
73 | where :math:`\mathcal{R}\colon\mathbb{R}^{M}\to\mathbb{R}^N` is a traditional hand-crafted (e.g., regularized) reconstruction operator and :math:`\mathcal{G}_\theta\colon\mathbb{R}^{N}\to\mathbb{R}^N` is a nonlinear neural network that acts in the image domain.
74 |
75 | Algorithm unfolding consists in defining :math:`\mathcal{R}_\theta` from an iterative scheme
76 |
77 | .. math::
78 | :label: eq_pgd_no_Gamma
79 |
80 | \mathcal{R}_\theta = \mathcal{R}_{\theta_K} \circ ... \circ \mathcal{R}_{\theta_1},
81 |
82 | where :math:`\mathcal{R}_{\theta_k}` can be interpreted as the computation of the :math:`k`-th iteration of the iterative scheme and :math:`\theta = \bigcup_{k} \theta_k`.
83 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=67",
4 | "wheel",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [tool.setuptools]
9 | include-package-data = true
10 | zip-safe = false
11 | script-files = [
12 | "tutorial/tuto_01_a_acquisition_operators.py",
13 | "tutorial/tuto_01_b_splitting.py",
14 | "tutorial/tuto_01_c_HadamSplit2d.py",
15 | "tutorial/tuto_02_noise.py",
16 | "tutorial/tuto_03_pseudoinverse_linear.py"
17 | ]
18 |
19 | [tool.setuptools.dynamic]
20 | readme = {file = ["README.md"]}
21 |
22 | [tool.setuptools.packages]
23 | find = {} # Scanning implicit namespaces is active by default
24 |
25 | [project]
26 | name = "spyrit"
27 | version = "3.0.1"
28 | dynamic = ["readme"]
29 | authors = [{name = "Nicolas Ducros", email = "Nicolas.Ducros@insa-lyon.fr"}]
30 | description = "Toolbox for deep image reconstruction"
31 | license = {file = "LICENSE.md"}
32 | classifiers = [
33 | "Programming Language :: Python",
34 | "Programming Language :: Python :: 3.9",
35 | "Programming Language :: Python :: 3.10",
36 | "Programming Language :: Python :: 3.11",
37 | "Programming Language :: Python :: Implementation :: PyPy",
38 | "Operating System :: OS Independent",
39 | ]
40 | dependencies = [
41 | "numpy",
42 | "matplotlib",
43 | "scipy",
44 | "torch",
45 | "torchvision",
46 | "Pillow",
47 | "PyWavelets",
48 | "wget",
49 | "sympy",
50 | "imageio",
51 | "astropy",
52 | "requests",
53 | "tqdm",
54 | "girder-client",
55 | ]
56 | requires-python = ">=3.9"
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | scipy
4 | torch
5 | torchvision
6 | Pillow
7 | opencv-python
8 | imutils
9 | PyWavelets
10 | wget
11 | sympy
12 | imageio
13 | astropy
14 | tensorboard
15 | sphinx_gallery
16 | sphinx_rtd_theme
17 | girder-client
18 | gdown==v4.6.3
19 |
--------------------------------------------------------------------------------
/spyrit/__init__.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | # from __future__ import division, print_function, absolute_import
8 | # from distutils.version import LooseVersion
9 | #
10 | #
11 | # from . import spyritest
12 | #
13 | #
14 | # __all__ = [s for s in dir() if not s.startswith('_')]
15 |
16 | # from . import core
17 | # from . import misc
18 | # from . import external
19 |
--------------------------------------------------------------------------------
/spyrit/core/__init__.py:
--------------------------------------------------------------------------------
1 | """Core module for Spyrit package, containing the main classes and functions."""
2 |
--------------------------------------------------------------------------------
/spyrit/dev/meas.py:
--------------------------------------------------------------------------------
1 | # ==================================================================================
2 | class Linear_shift(Linear):
3 | # ==================================================================================
4 | r"""Linear with shifted pattern matrix of size :math:`(M+1,N)` and :math:`Perm` matrix of size :math:`(N,N)`.
5 |
6 | Args:
7 | - Hsub: subsampled Hadamard matrix
8 |         - Perm: Permutation matrix
9 |
10 | Shape:
11 | - Input1: :math:`(M, N)`
12 | - Input2: :math:`(N, N)`
13 |
14 | Example:
15 | >>> Hsub = np.array(np.random.random([400,32*32]))
16 | >>> Perm = np.array(np.random.random([32*32,32*32]))
17 | >>> FO_Shift = Linear_shift(Hsub, Perm)
18 |
19 | """
20 |
21 | def __init__(self, Hsub, Perm):
22 | super().__init__(Hsub)
23 |
24 | # Todo: Use index rather than permutation (see misc.walsh_hadamard)
25 | self.Perm = nn.Linear(self.N, self.N, False)
26 | self.Perm.weight.data = torch.from_numpy(Perm.T)
27 | self.Perm.weight.data = self.Perm.weight.data.float()
28 | self.Perm.weight.requires_grad = False
29 |
30 | H_shift = torch.cat((torch.ones((1, self.N)), (self.Hsub.weight.data + 1) / 2))
31 |
32 | self.H_shift = nn.Linear(self.N, self.M + 1, False)
33 | self.H_shift.weight.data = H_shift # include the all-one pattern
34 | self.H_shift.weight.data = self.H_shift.weight.data.float() # keep ?
35 | self.H_shift.weight.requires_grad = False
36 |
37 | def forward(self, x: torch.tensor) -> torch.tensor:
38 | r"""Applies Linear transform such that :math:`y = \begin{bmatrix}{1}\\{H_{sub}}\end{bmatrix}x`.
39 |
40 | Args:
41 | :math:`x`: batch of images.
42 |
43 | Shape:
44 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image.
45 | - Output: :math:`(b*c, M+1)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M+1` the number of measurements + 1.
46 |
47 | Example:
48 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float)
49 | >>> y = FO_Shift(x)
50 | >>> print(y.shape)
51 | torch.Size([10, 401])
52 | """
53 | # input x is a set of images with shape (b*c, N)
54 | # output input is a set of measurement vector with shape (b*c, M+1)
55 | x = self.H_shift(x)
56 | return x
57 |
58 | # x_shift = super().forward(x) - x_dark.expand(x.shape[0],self.M) # (H-1/2)x
59 |
60 |
61 | # ==================================================================================
62 | class Linear_pos(Linear):
63 | # ==================================================================================
64 | r"""Linear with Permutation Matrix :math:`Perm` of size :math:`(N,N)`.
65 |
66 | Args:
67 | - Hsub: subsampled Hadamard matrix
 68 |         - Perm: Permutation matrix
69 |
70 | Shape:
71 | - Input1: :math:`(M, N)`
72 | - Input2: :math:`(N, N)`
73 |
74 | Example:
75 | >>> Hsub = np.array(np.random.random([400,32*32]))
76 | >>> Perm = np.array(np.random.random([32*32,32*32]))
77 | >>> meas_op_pos = Linear_pos(Hsub, Perm)
78 | """
79 |
80 | def __init__(self, Hsub, Perm):
81 | super().__init__(Hsub)
82 |
83 | # Todo: Use index rather than permutation (see misc.walsh_hadamard)
84 | self.Perm = nn.Linear(self.N, self.N, False)
85 | self.Perm.weight.data = torch.from_numpy(Perm.T)
86 | self.Perm.weight.data = self.Perm.weight.data.float()
87 | self.Perm.weight.requires_grad = False
88 |
89 | def forward(self, x: torch.tensor) -> torch.tensor:
90 | r"""Computes :math:`y` according to :math:`y=0.5(H_{sub}x+\sum_{j=1}^{N}x_{j})` where :math:`j` is the pixel (column) index of :math:`x`.
91 |
92 | Args:
93 | :math:`x`: Batch of images.
94 |
95 | Shape:
96 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of pixels in the image.
97 | - Output: :math:`(b*c, M)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`M` the number of measurements.
98 |
99 | Example:
100 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float)
101 | >>> y = meas_op_pos(x)
102 | >>> print(y.shape)
 103 |             torch.Size([10, 400])
104 | """
105 | # input x is a set of images with shape (b*c, N)
106 | # output is a set of measurement vectors with shape (b*c, M)
107 |
108 | # compute 1/2(H+1)x = 1/2 HX + 1/2 1x
109 | x = super().forward(x) + x.sum(dim=1, keepdim=True).expand(-1, self.M)
110 | x *= 0.5
111 |
112 | return x
113 |
114 |
115 | # ==================================================================================
116 | class Linear_shift_had(Linear_shift):
117 | # ==================================================================================
118 | r"""Linear_shift operator with inverse method.
119 |
120 | Args:
121 | - Hsub: subsampled Hadamard matrix
 122 |         - Perm: Permutation matrix
123 |
124 | Shape:
125 | - Input1: :math:`(M, N)`
126 | - Input2: :math:`(N, N)`.
127 |
128 | Example:
129 | >>> Hsub = np.array(np.random.random([400,32*32]))
130 | >>> Perm = np.array(np.random.random([32*32,32*32]))
131 | >>> FO_Shift_Had = Linear_shift_had(Hsub, Perm)
132 | """
133 |
134 | def __init__(self, Hsub, Perm):
135 | super().__init__(Hsub, Perm)
136 |
137 | def inverse(self, x: torch.tensor, n: Union[None, int] = None) -> torch.tensor:
138 | r"""Inverse transform such that :math:`x = \frac{1}{N}H_{sub}y`.
139 |
140 | Args:
141 | :math:`x`: Batch of completed measurements.
142 |
143 | Shape:
144 | - Input: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of measurements.
 145 |             - Output: :math:`(b*c, N)` with :math:`b` the batch size, :math:`c` the number of channels, and :math:`N` the number of reconstructed pixels.
146 |
147 | Example:
148 | >>> x = torch.tensor(np.random.random([10,32*32]), dtype=torch.float)
149 | >>> x_reconstruct = FO_Shift_Had.inverse(y_pad)
150 | >>> print(x_reconstruct.shape)
151 | torch.Size([10, 1024])
152 | """
153 | # rearrange the terms + inverse transform
154 | # maybe needs to be initialised with a permutation matrix as well!
155 | # Permutation matrix may be sparsified when sparse tensors are no longer in
156 | # beta (as of pytorch 1.11, it is still in beta).
157 |
158 | # --> Use index rather than permutation (see misc.walsh_hadamard)
159 |
160 | # input x is a set of **measurements** with shape (b*c, N)
161 | # output is a set of **images** with shape (b*c, N)
162 | bc, N = x.shape
163 | x = self.Perm(x)
164 |
165 | if n is None:
166 | n = int(np.sqrt(N))
167 |
168 | # Inverse transform
169 | x = x.reshape(bc, 1, n, n)
170 | x = (
171 | 1 / self.N * walsh2_torch(x)
172 | ) # todo: initialize with 1D transform to speed up
173 | x = x.reshape(bc, N)
174 | return x
175 |
--------------------------------------------------------------------------------
/spyrit/dev/prep.py:
--------------------------------------------------------------------------------
1 | # ==================================================================================
2 | class Preprocess_shift_poisson(nn.Module): # header needs to be updated!
3 | # ==================================================================================
4 | r"""Preprocess the measurements acquired using shifted patterns corrupted
5 | by Poisson noise
6 |
7 | Computes:
 8 |         m = (2 m_shift - m_offset)/alpha
9 | var = 4*Diag(m_shift + m_offset)/alpha**2
10 | Warning: dark measurement is assumed to be the 0-th entry of raw measurements
11 |
12 | Args:
13 | - :math:`alpha`: noise level
14 | - :math:`M`: number of measurements
15 | - :math:`N`: number of image pixels
16 |
17 | Shape:
18 | - Input1: scalar
19 | - Input2: scalar
20 | - Input3: scalar
21 |
22 | Example:
23 | >>> PSP = Preprocess_shift_poisson(9, 400, 32*32)
24 | """
25 |
26 | def __init__(self, alpha, M, N):
27 | super().__init__()
28 | self.alpha = alpha
29 | self.N = N
30 | self.M = M
31 |
32 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor:
33 | r"""
34 |
35 | Warning:
36 | - The offset measurement is the 0-th entry of the raw measurements.
37 |
38 | Args:
39 | - :math:`x`: Batch of images in Hadamard domain shifted by 1
40 | - :math:`meas_op`: Forward_operator
41 |
42 | Shape:
43 | - Input: :math:`(b*c, M+1)`
44 | - Output: :math:`(b*c, M)`
45 |
46 | Example:
47 | >>> Hsub = np.array(np.random.random([400,32*32]))
48 | >>> FO = Forward_operator(Hsub)
49 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float)
50 | >>> y_PSP = PSP(x, FO)
51 | >>> print(y_PSP.shape)
52 | torch.Size([10, 400])
53 |
54 | """
55 | y = self.offset(x)
56 | x = 2 * x[:, 1:] - y.expand(
57 | x.shape[0], self.M
58 | ) # Warning: dark measurement is the 0-th entry
59 | x = x / self.alpha
60 | x = 2 * x - meas_op.H(
61 | torch.ones(x.shape[0], self.N).to(x.device)
62 | ) # to shift images in [-1,1]^N
63 | return x
64 |
65 | def sigma(self, x):
66 | r"""
67 | Args:
68 | - :math:`x`: Batch of images in Hadamard domain shifted by 1
69 |
70 | Shape:
71 | - Input: :math:`(b*c, M+1)`
72 |
73 | Example:
74 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float)
75 | >>> sigma_PSP = PSP.sigma(x)
76 | >>> print(sigma_PSP.shape)
77 | torch.Size([10, 400])
78 | """
79 | # input x is a set of measurement vectors with shape (b*c, M+1)
80 | # output is a set of measurement vectors with shape (b*c,M)
81 | y = self.offset(x)
82 | x = 4 * x[:, 1:] + y.expand(x.shape[0], self.M)
83 | x = x / (self.alpha**2)
84 | x = 4 * x # to shift images in [-1,1]^N
85 | return x
86 |
87 | def cov(self, x): # return a full matrix ? It is such that Diag(a) + b
88 | return x
89 |
90 | def sigma_from_image(self, x, meas_op): # should check this!
91 | # input x is a set of images with shape (b*c, N)
92 | # input meas_op is a Forward_operator
93 | x = meas_op.H(x)
94 | y = self.offset(x)
95 | x = x[:, 1:] + y.expand(x.shape[0], self.M)
96 | x = x / (self.alpha) # here the alpha contribution is not squared.
97 | return x
98 |
99 | def offset(self, x):
 100 |         r"""Get offset component from batch of shifted images.
101 |
102 | Args:
103 | - :math:`x`: Batch of shifted images
104 |
105 | Shape:
106 | - Input: :math:`(bc, M+1)`
107 | - Output: :math:`(bc, 1)`
108 |
109 | Example:
110 | >>> x = torch.tensor(np.random.random([10, 400+1]), dtype=torch.float)
111 | >>> y = PSP.offset(x)
112 | >>> print(y.shape)
113 | torch.Size([10, 1])
114 |
115 | """
116 | y = x[:, 0, None]
117 | return y
118 |
119 |
120 | # ==================================================================================
121 | class Preprocess_pos_poisson(nn.Module): # header needs to be updated!
122 | # ==================================================================================
123 | r"""Preprocess the measurements acquired using positive (shifted) patterns
124 | corrupted by Poisson noise
125 |
126 | The output value of the layer with input size :math:`(B*C, M)` can be
127 | described as:
128 |
129 | .. math::
 130 |         \text{out}((B*C)_i, M_j) = 2 \cdot \text{input}((B*C)_i, M_j) -
 131 |         \sum_{k = 1}^{M-1} \text{input}((B*C)_i, M_k)
132 |
 133 |     The output size of the layer is :math:`(B*C, M)`, which is the input size
134 |
135 |
136 | Warning:
137 | dark measurement is assumed to be the 0-th entry of raw measurements
138 |
139 | Args:
140 | - :math:`alpha`: noise level
141 | - :math:`M`: number of measurements
142 | - :math:`N`: number of image pixels
143 |
144 | Shape:
145 | - Input1: scalar
146 | - Input2: scalar
147 | - Input3: scalar
148 |
149 | Example:
150 | >>> PPP = Preprocess_pos_poisson(9, 400, 32*32)
151 |
152 | """
153 |
154 | def __init__(self, alpha, M, N):
155 | super().__init__()
156 | self.alpha = alpha
157 | self.N = N
158 | self.M = M
159 |
160 | def forward(self, x: torch.tensor, meas_op: Linear) -> torch.tensor:
161 | r"""
162 | Args:
163 | - :math:`x`: noise level
164 | - :math:`meas_op`: Forward_operator
165 |
166 | Shape:
167 | - Input1: :math:`(bc, M)`
168 | - Input2: None
169 | - Output: :math:`(bc, M)`
170 |
171 | Example:
172 | >>> Hsub = np.array(np.random.random([400,32*32]))
173 | >>> meas_op = Forward_operator(Hsub)
174 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float)
175 | >>> y = PPP(x, meas_op)
176 | torch.Size([10, 400])
177 |
178 | """
179 | y = self.offset(x)
180 | x = 2 * x - y.expand(-1, self.M)
181 | x = x / self.alpha
182 | x = 2 * x - meas_op.H(
183 | torch.ones(x.shape[0], self.N).to(x.device)
184 | ) # to shift images in [-1,1]^N
185 | return x
186 |
187 | def offset(self, x):
 188 |         r"""Get offset component from batch of shifted images.
189 |
190 | Args:
191 | - :math:`x`: Batch of shifted images
192 |
193 | Shape:
194 | - Input: :math:`(bc, M)`
195 | - Output: :math:`(bc, 1)`
196 |
197 | Example:
198 | >>> x = torch.tensor(np.random.random([10, 400]), dtype=torch.float)
199 | >>> y = PPP.offset(x)
200 | >>> print(y.shape)
201 | torch.Size([10, 1])
202 |
203 | """
204 | y = 2 / (self.M - 2) * x[:, 1:].sum(dim=1, keepdim=True)
205 | return y
206 |
--------------------------------------------------------------------------------
/spyrit/external/__init__.py:
--------------------------------------------------------------------------------
1 | """This module uses a modified version of the Unet presented in https://github.com/cszn/DPIR/blob/master/models/network_unet.py"""
2 |
3 | # from . import drunet
4 |
--------------------------------------------------------------------------------
/spyrit/hadamard_matrix/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/spyrit/hadamard_matrix/__init__.py
--------------------------------------------------------------------------------
/spyrit/hadamard_matrix/create_hadamard_matrix_with_sage.py:
--------------------------------------------------------------------------------
1 | from sage.all import *
2 | from sage.combinat.matrices.hadamard_matrix import (
3 | hadamard_matrix,
4 | skew_hadamard_matrix,
5 | is_hadamard_matrix,
6 | is_skew_hadamard_matrix,
7 | )
8 | import numpy as np
9 | import glob
10 |
11 | # Get all Hadamard matrices of order 4*n for Sage
12 | # https://github.com/sagemath/sage/
13 | # run in conda env with:
14 | # sage create_hadamard_matrix_with_sage.py
15 |
16 | k = Integer(2000)
17 | for n in range(Integer(1), k + Integer(1)):
18 | try:
19 | H = hadamard_matrix(Integer(4) * n, check=False)
20 |
21 | if is_hadamard_matrix(H):
22 | print(n * 4)
23 | a = np.array(H)
24 | a[a == -1] = 0
25 | a = a.astype(bool)
26 |
27 | # find the files with that order
28 | files = glob.glob("had." + str(n * 4) + "*.npz")
29 | already_saved = False
30 | for file in files:
31 | b = np.load(file)
32 | if a == b:
33 | already_saved = True
34 | if already_saved:
35 | break
36 |
37 | if not already_saved:
38 | name = "had." + str(n * 4) + ".sage.npz"
39 | np.savez_compressed(name, a)
40 | except ValueError as e:
41 | pass
42 |
--------------------------------------------------------------------------------
/spyrit/hadamard_matrix/download_hadamard_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import requests
3 | import os
4 | import glob
5 | import importlib.util
6 | import tqdm
7 |
8 |
9 | def download_from_girder():
10 | """
11 | Download Hadamard matrices from the Girder repository into hadamard_matrix folder.
12 | """
13 |
14 | hadamard_matrix_path = os.path.dirname(__file__)
15 | if os.path.isfile(
16 | os.path.join(hadamard_matrix_path, "had.236.sage.cooper-wallis.npz")
17 | ):
18 | return
19 | print("Downloading Hadamard matrices (>2300) from Girder repository...")
20 | print(
21 | "The matrices were downloaded from http://neilsloane.com/hadamard/ Sloane et al."
22 | )
23 | import girder_client
24 |
25 | gc = girder_client.GirderClient(
26 | apiUrl="https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
27 | )
28 |
29 | collection_id = "66796d3cbaa5a90007058946"
30 | folder_id = "6800c6891240141f6aa53845"
31 | limit = 50 # Number of items to retrieve per request
32 | offset = 0 # Starting point
33 | pbar = tqdm.tqdm(total=0)
34 |
35 | while True:
36 | items = gc.get(
37 | "item",
38 | parameters={
39 | "parentType": "collection",
40 | "parentId": collection_id,
41 | "folderId": folder_id,
42 | "limit": limit,
43 | "offset": offset,
44 | },
45 | )
46 | if not items:
47 | break
48 | pbar.total += len(items)
49 | pbar.refresh()
50 | for item in items:
51 | files = gc.get(f'item/{item["_id"]}/files')
52 | for file in files:
53 | pbar.update(1)
54 | gc.downloadFile(
55 | file["_id"], os.path.join(hadamard_matrix_path, file["name"])
56 | )
57 | offset += limit
58 | pbar.close()
59 |
60 |
61 | def read_text_file_from_url(url):
62 | response = requests.get(url)
63 | content = response.text
64 | return content
65 |
66 |
67 | def download_from_sloane():
68 | from selenium import webdriver
69 | from selenium.webdriver.common.by import By
70 | from selenium.webdriver.chrome.service import Service
71 | from webdriver_manager.chrome import ChromeDriverManager
72 |
73 | # Set up the WebDriver
74 | driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()))
75 |
76 | # Open the website
77 | driver.get("http://neilsloane.com/hadamard/")
78 |
79 | # Find all links to Hadamard matrices
80 | links = driver.find_elements(By.XPATH, "//a[contains(@href, 'had.')]")
81 |
82 | # Extract the URLs
83 | hadamard_urls = set([link.get_attribute("href") for link in links])
84 |
85 | # Print the URLs
86 | for url in hadamard_urls:
87 | print(url)
88 | # Read the text file from the URL
89 | file_content = read_text_file_from_url(url)
90 | # Split the content into lines
91 | lines = file_content.splitlines()
92 |
93 | # Print the content of the file
94 | if "+" in file_content or "0" in file_content or "-1" in file_content:
95 | if len(lines) > 1:
96 | size = len(lines[1])
97 | else:
98 | size = len(lines[0])
99 | array = []
100 | for line in lines:
101 | if len(line) == size:
102 | line = line.replace("-1", "0")
103 | tmp = []
104 | for e in line:
105 | if e == "+" or e == "1":
106 | tmp += [1]
107 | elif e == "-" or e == "0":
108 | tmp += [0]
109 | elif e == " ":
110 | pass
111 | else:
112 | print("Error during reading of " + url)
113 | array += [tmp]
114 | np_array = np.array(array, dtype=bool)
115 |
116 | name = url.split("/")[-1][:-4]
117 | order = int(name.split(".")[1])
118 |
119 | # Check if the file already exists
120 | files = glob.glob("had." + str(order) + "*.npz")
121 | already_saved = False
122 | for file in files:
123 | b = np.load(file)
124 | if np.all(np_array == b):
125 | already_saved = True
126 | if already_saved:
127 | break
128 |
129 | if not already_saved:
130 | np.savez_compressed(name + ".npz", np_array)
131 | else:
132 | print("no ok for " + url)
133 | # print(file_content)
134 |
135 | # Close the WebDriver
136 | driver.quit()
137 |
138 |
139 | if __name__ == "__main__":
140 | download_from_sloane()
141 |
--------------------------------------------------------------------------------
/spyrit/misc/__init__.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | """Contains miscellaneous Numpy / Pytorch functions useful for spyrit.core."""
8 |
9 | # from . import color
10 | # from . import data_visualisation
11 | # from . import disp
12 | # from . import examples
13 | # from . import matrix_tools
14 | # from . import metrics
15 | # from . import pattern_choice
16 | # from . import sampling
17 | # from . import statistics
18 | # from . import walsh_hadamard
19 |
--------------------------------------------------------------------------------
/spyrit/misc/color.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Dec 2 20:53:59 2024
4 |
5 | @author: ducros
6 | """
7 | import numpy as np
8 | import warnings
9 | from typing import Tuple
10 |
11 | import warnings
12 | from typing import Tuple
13 |
14 | import numpy as np
15 |
16 |
17 | # %%
18 | def wavelength_to_rgb(
19 | wavelength: float, gamma: float = 0.8
20 | ) -> Tuple[float, float, float]:
21 | """Converts wavelength to RGB.
22 |
23 | Based on https://gist.github.com/friendly/67a7df339aa999e2bcfcfec88311abfc.
24 | Itself based on code by Dan Bruton:
25 | http://www.physics.sfasu.edu/astro/color/spectra.html
26 |
27 | Args:
28 | wavelength (float):
29 | Single wavelength to be converted to RGB.
30 | gamma (float, optional):
31 | Gamma correction. Defaults to 0.8.
32 |
33 | Returns:
34 | Tuple[float, float, float]:
35 | RGB value.
36 | """
37 |
38 | if np.min(wavelength) < 380 or np.max(wavelength) > 750:
39 | warnings.warn("Some wavelengths are not in the visible range [380-750] nm")
40 |
41 | if wavelength >= 380 and wavelength <= 440:
42 | attenuation = 0.3 + 0.7 * (wavelength - 380) / (440 - 380)
43 | R = ((-(wavelength - 440) / (440 - 380)) * attenuation) ** gamma
44 | G = 0.0
45 | B = (1.0 * attenuation) ** gamma
46 |
47 | elif wavelength >= 440 and wavelength <= 490:
48 | R = 0.0
49 | G = ((wavelength - 440) / (490 - 440)) ** gamma
50 | B = 1.0
51 |
52 | elif wavelength >= 490 and wavelength <= 510:
53 | R = 0.0
54 | G = 1.0
55 | B = (-(wavelength - 510) / (510 - 490)) ** gamma
56 |
57 | elif wavelength >= 510 and wavelength <= 580:
58 | R = ((wavelength - 510) / (580 - 510)) ** gamma
59 | G = 1.0
60 | B = 0.0
61 |
62 | elif wavelength >= 580 and wavelength <= 645:
63 | R = 1.0
64 | G = (-(wavelength - 645) / (645 - 580)) ** gamma
65 | B = 0.0
66 |
67 | elif wavelength >= 645 and wavelength <= 750:
68 | attenuation = 0.3 + 0.7 * (750 - wavelength) / (750 - 645)
69 | R = (1.0 * attenuation) ** gamma
70 | G = 0.0
71 | B = 0.0
72 |
73 | else:
74 | R = 0.0
75 | G = 0.0
76 | B = 0.0
77 |
78 | return R, G, B
79 |
80 |
81 | def wavelength_to_rgb_mat(wav_range, gamma=1):
82 |
83 | rgb_mat = np.zeros((len(wav_range), 3))
84 |
85 | for i, wav in enumerate(wav_range):
86 | rgb_mat[i, :] = wavelength_to_rgb(wav, gamma)
87 |
88 | return rgb_mat
89 |
90 |
91 | def spectral_colorization(M_gray, wav, axis=None):
92 | """
93 | Colorize the last dimension of an array
94 |
95 | Args:
96 | M_gray (np.ndarray): Grayscale array where the last dimension is the
97 | spectral dimension. This is an A-by-C array, where A can indicate multiple
98 | dimensions (e.g., 4-by-3-by-7) and C is the number of spectral channels.
99 |
 100 |         wav (np.ndarray): Wavelength. This is a 1D array of size C.
101 |
102 | axis (None or int or tuple of ints, optional): Axis or axes along which
103 | the grayscale input is normalized. By default, global normalization
104 | across all axes is considered.
105 |
106 | Returns:
107 | M_color (np.ndarray): Color array with an extra dimension. This is an A-by-C-by-3 array.
108 |
109 | """
110 |
111 | # Normalize to adjust contrast
112 | M_gray_min = M_gray.min(keepdims=True, axis=axis)
113 | M_gray_max = M_gray.max(keepdims=True, axis=axis)
114 | M_gray = (M_gray - M_gray_min) / (M_gray_max - M_gray_min)
115 |
116 | #
117 | rgb_mat = wavelength_to_rgb_mat(wav, gamma=1)
118 | M_red = M_gray @ np.diag(rgb_mat[:, 0])
119 | M_green = M_gray @ np.diag(rgb_mat[:, 1])
120 | M_blue = M_gray @ np.diag(rgb_mat[:, 2])
121 |
122 | M_color = np.stack((M_red, M_green, M_blue), axis=-1)
123 |
124 | return M_color
125 |
126 |
127 | def colorize(im, color, clip_percentile=0.1):
128 | """
129 | Helper function to create an RGB image from a single-channel image using a
130 | specific color.
131 | """
132 | # Check that we just have a 2D image
133 | if im.ndim > 2 and im.shape[2] != 1:
134 | raise ValueError("This function expects a single-channel image!")
135 |
136 | # Rescale the image according to how we want to display it
137 | im_scaled = im.astype(np.float32) - np.percentile(im, clip_percentile)
138 | im_scaled = im_scaled / np.percentile(im_scaled, 100 - clip_percentile)
139 | print(
140 | f"Norm: min={np.percentile(im, clip_percentile)}, max={np.percentile(im_scaled, 100 - clip_percentile)}"
141 | )
142 | print(f"New: min={im_scaled.min()}, max={im_scaled.max()}")
143 | im_scaled = np.clip(im_scaled, 0, 1)
144 |
145 | # Need to make sure we have a channels dimension for the multiplication to work
146 | im_scaled = np.atleast_3d(im_scaled)
147 |
148 | # Reshape the color (here, we assume channels last)
149 | color = np.asarray(color).reshape((1, 1, -1))
150 | return im_scaled * color
151 |
--------------------------------------------------------------------------------
/spyrit/misc/data_visualisation.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | #!/usr/bin/env python3
8 | # -*- coding: utf-8 -*-
9 | """
10 | Created on Thu Jan 16 08:56:13 2020
11 |
12 | @author: crombez
13 | """
14 |
15 | from astropy.io import fits
16 | import matplotlib.pyplot as plt
17 |
18 |
19 | # Show basic information of a fits image acquired
20 | # with the andor zyla and plot the image
21 | def show_image_and_infos(path, file):
22 | hdul = fits.open(path + file)
23 | show_images_infos(path, file)
24 | plt.figure()
25 | plt.imshow(hdul[0].data[0])
26 | plt.show()
27 |
28 |
29 | def show_images_infos(path, file): # Show basic information of a fits image acquired
30 | hdul = fits.open(path + file)
31 | print("***** Name file : " + file + " *****")
32 | print("Type de données : " + hdul[0].header["DATATYPE"])
33 | print("Mode d'acquisition : " + hdul[0].header["ACQMODE"])
34 | print("Temps d'exposition : " + str(hdul[0].header["EXPOSURE"]))
35 | print("Temps de lecture : " + str(hdul[0].header["READTIME"]))
36 | print("Longeur d'onde de Rayleigh : " + str(hdul[0].header["RAYWAVE"]))
37 | print("Longeur d'onde détectée : " + str(hdul[0].header["DTNWLGTH"]))
38 | print("***********************************" + "\n")
39 |
40 |
 41 | # Plot the resulting function of two sets of 1D data with the same dimension
42 | def simple_plot_2D(
43 | Lx, Ly, fig=None, title=None, xlabel=None, ylabel=None, style_color="b"
44 | ):
45 | plt.figure(fig)
46 | plt.clf()
47 | plt.title(title)
48 | plt.xlabel(xlabel)
49 | plt.ylabel(ylabel)
50 | plt.plot(Lx, Ly, style_color)
51 | plt.show()
52 |
53 |
54 | # Plot a 2D matrix
55 | def plot_im2D(Im, fig=None, title=None, xlabel=None, ylabel=None, cmap="viridis"):
56 | plt.figure(fig)
57 | plt.clf()
58 | plt.title(title)
59 | plt.xlabel(xlabel)
60 | plt.ylabel(ylabel)
61 | plt.imshow(Im, cmap=cmap)
62 | plt.colorbar()
63 | plt.show()
64 |
--------------------------------------------------------------------------------
/spyrit/misc/disp.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.axes_grid1 import make_axes_locatable
9 | from PIL import Image
10 | import numpy as np
11 | from numpy import linalg as LA
12 | import time
13 | from scipy import signal
14 | from scipy import misc
15 | from scipy import sparse
16 | import torch
17 | import math
18 |
19 |
20 | def display_vid(video, fps, title="", colormap=plt.cm.gray):
21 | """
22 | video is a numpy array of shape [nb_frames, 1, nx, ny]
23 | """
24 | plt.ion()
25 | (nb_frames, channels, nx, ny) = video.shape
26 | fig = plt.figure()
27 | ax = fig.add_subplot(1, 1, 1)
28 | for i in range(nb_frames):
29 | current_frame = video[i, 0, :, :]
30 | plt.imshow(current_frame, cmap=colormap)
31 | plt.title(title)
32 | divider = make_axes_locatable(ax)
33 | cax = plt.axes([0.85, 0.1, 0.075, 0.8])
34 | plt.colorbar(cax=cax)
35 | plt.show()
36 | plt.pause(fps)
37 | plt.ioff()
38 |
39 |
40 | def display_rgb_vid(video, fps, title=""):
41 | """
42 | video is a numpy array of shape [nb_frames, 3, nx, ny]
43 | """
44 | plt.ion()
45 | (nb_frames, channels, nx, ny) = video.shape
46 | fig = plt.figure()
47 | ax = fig.add_subplot(1, 1, 1)
48 | for i in range(nb_frames):
49 | current_frame = video[i, :, :, :]
50 | current_frame = np.moveaxis(current_frame, 0, -1)
51 | plt.imshow(current_frame)
52 | plt.title(title)
53 | plt.show()
54 | plt.pause(fps)
55 | plt.ioff()
56 |
57 |
58 | def fitPlots(N, aspect=(16, 9)):
59 | width = aspect[0]
60 | height = aspect[1]
61 | area = width * height * 1.0
62 | factor = (N / area) ** (1 / 2.0)
63 | cols = math.floor(width * factor)
64 | rows = math.floor(height * factor)
65 | rowFirst = width < height
66 | while rows * cols < N:
67 | if rowFirst:
68 | rows += 1
69 | else:
70 | cols += 1
71 | rowFirst = not (rowFirst)
72 | return rows, cols
73 |
74 |
75 | def Multi_plots(
76 | img_list,
77 | title_list,
78 | shape,
79 | suptitle="",
80 | colormap=plt.cm.gray,
81 | axis_off=True,
82 | aspect=(16, 9),
83 | savefig="",
84 | fontsize=14,
85 | ):
86 | [rows, cols] = shape
87 | plt.figure()
88 | plt.suptitle(suptitle, fontsize=16)
89 | if (len(img_list) < rows * cols) or (len(title_list) < rows * cols):
90 | for k in range(max(rows * cols - len(img_list), rows * cols - len(title_list))):
91 | img_list.append(np.zeros((64, 64)))
92 | title_list.append("")
93 |
94 | for k in range(rows * cols):
95 | ax = plt.subplot(rows, cols, k + 1)
96 | ax.imshow(img_list[k], cmap=colormap)
97 | ax.set_title(title_list[k], fontsize=fontsize)
98 | if axis_off:
99 | plt.axis("off")
100 | if savefig:
101 | plt.savefig(savefig, bbox_inches="tight")
102 | plt.show()
103 |
104 |
105 | def compare_video_frames(
106 | vid_list,
107 | nb_disp_frames,
108 | title_list,
109 | suptitle="",
110 | colormap=plt.cm.gray,
111 | aspect=(16, 9),
112 | savefig="",
113 | fontsize=14,
114 | ):
115 | rows = len(vid_list)
116 | cols = nb_disp_frames
117 | plt.figure(figsize=aspect)
118 | plt.suptitle(suptitle, fontsize=16)
119 | for i in range(rows):
120 | for j in range(cols):
121 | k = (j + 1) + (i) * (cols)
122 | i
123 | # print(k)
124 | ax = plt.subplot(rows, cols, k)
125 | # print("i = {}, j = {}".format(i,j))
126 | ax.imshow(vid_list[i][0, j, 0, :, :], cmap=colormap)
127 | ax.set_title(title_list[i][j], fontsize=fontsize)
128 | plt.axis("off")
129 | if savefig:
130 | plt.savefig(savefig, bbox_inches="tight")
131 | plt.show()
132 |
133 |
134 | def torch2numpy(torch_tensor):
135 | return torch_tensor.cpu().detach().numpy()
136 |
137 |
138 | def uint8(dsp):
139 | x = (dsp - np.amin(dsp)) / (np.amax(dsp) - np.amin(dsp)) * 255
140 | x = x.astype("uint8")
141 | return x
142 |
143 |
144 | def imagesc(
145 | Img,
146 | title="",
147 | colormap=plt.cm.gray,
148 | show=True,
149 | figsize=None,
150 | cbar_pos=None,
151 | title_fontsize=16,
152 | ):
153 | """
154 | imagesc(IMG) Display image Img with scaled colors with greyscale
155 | colormap and colorbar
156 | imagesc(IMG, title=ttl) Display image Img with scaled colors with
157 | greyscale colormap and colorbar, with the title ttl
158 | imagesc(IMG, title=ttl, colormap=cmap) Display image Img with scaled colors
159 | with colormap and colorbar specified by cmap (choose between 'plasma',
160 | 'jet', and 'grey'), with the title ttl
161 | """
162 | fig = plt.figure(figsize=figsize)
163 | ax = fig.add_subplot(1, 1, 1)
164 | plt.imshow(Img, cmap=colormap)
165 | plt.title(title, fontsize=title_fontsize)
166 | divider = make_axes_locatable(ax)
167 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes
168 |
169 | if cbar_pos == "bottom":
170 | cax = inset_axes(
171 | ax, width="100%", height="5%", loc="lower center", borderpad=-5
172 | )
173 | plt.colorbar(cax=cax, orientation="horizontal")
174 | else:
175 | cax = plt.axes([0.85, 0.1, 0.075, 0.8])
176 | plt.colorbar(cax=cax, orientation="vertical")
177 |
178 | # fig.tight_layout() # it raises warnings in some cases
179 | if show is True:
180 | plt.show()
181 |
182 |
def imagecomp(
    Img1,
    Img2,
    suptitle="",
    title1="",
    title2="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
):
    """Display two images side by side, each with its own colorbar.

    Args:
        Img1, Img2: 2D arrays shown left and right respectively.
        suptitle (str): overall figure title.
        title1, title2 (str): per-image titles.
        colormap1, colormap2: per-image matplotlib colormaps.

    Note: relies on pyplot's implicit current-figure state; the colorbar
    axes rectangles are hand-tuned figure coordinates.
    """
    f, (ax1, ax2) = plt.subplots(1, 2)
    im1 = ax1.imshow(Img1, cmap=colormap1)
    ax1.set_title(title1)
    # colorbar for the left image (figure-coordinate axes)
    cax = plt.axes([0.43, 0.3, 0.025, 0.4])
    plt.colorbar(im1, cax=cax)
    plt.suptitle(suptitle, fontsize=16)
    # colorbar for the right image
    im2 = ax2.imshow(Img2, cmap=colormap2)
    ax2.set_title(title2)
    cax = plt.axes([0.915, 0.3, 0.025, 0.4])
    plt.colorbar(im2, cax=cax)
    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
205 |
206 |
def imagepanel(
    Img1,
    Img2,
    Img3,
    Img4,
    suptitle="",
    title1="",
    title2="",
    title3="",
    title4="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
    colormap3=plt.cm.gray,
    colormap4=plt.cm.gray,
):
    """Display four images in a 2x2 grid, each with its own colorbar.

    Args:
        Img1..Img4: 2D arrays shown top-left, top-right, bottom-left,
            bottom-right respectively.
        suptitle (str): overall figure title.
        title1..title4 (str): per-image titles.
        colormap1..colormap4: per-image matplotlib colormaps.

    Note: relies on pyplot's implicit current-figure state; the colorbar
    axes rectangles are hand-tuned figure coordinates.
    """
    fig, axarr = plt.subplots(2, 2, figsize=(20, 10))
    plt.suptitle(suptitle, fontsize=16)

    # top-left image and its colorbar
    im1 = axarr[0, 0].imshow(Img1, cmap=colormap1)
    axarr[0, 0].set_title(title1)
    cax = plt.axes([0.4, 0.54, 0.025, 0.35])
    plt.colorbar(im1, cax=cax)

    # top-right
    im2 = axarr[0, 1].imshow(Img2, cmap=colormap2)
    axarr[0, 1].set_title(title2)
    cax = plt.axes([0.90, 0.54, 0.025, 0.35])
    plt.colorbar(im2, cax=cax)

    # bottom-left
    im3 = axarr[1, 0].imshow(Img3, cmap=colormap3)
    axarr[1, 0].set_title(title3)
    cax = plt.axes([0.4, 0.12, 0.025, 0.35])
    plt.colorbar(im3, cax=cax)

    # bottom-right
    im4 = axarr[1, 1].imshow(Img4, cmap=colormap4)
    axarr[1, 1].set_title(title4)
    cax = plt.axes([0.9, 0.12, 0.025, 0.35])
    plt.colorbar(im4, cax=cax)

    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
247 |
248 |
def plot(x, y, title="", xlabel="", ylabel="", color="black"):
    """Plot y against x on a fresh figure, annotate it and show it.

    Args:
        x, y: data to plot.
        title, xlabel, ylabel (str): figure annotations.
        color (str): matplotlib line color.
    """
    # cleanup: removed unused `fig`/`ax` locals; plt.plot creates the
    # default axes on the new figure itself
    plt.figure()
    plt.plot(x, y, color=color)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()
257 |
258 |
def add_colorbar(mappable, position="right", size="5%"):
    """Attach a colorbar to the axes of a mappable, preserving the
    current axes.

    Args:
        mappable: artist returned by e.g. ``ax.imshow``.
        position (str): side on which to append the colorbar axes
            ("right", "left", "top", "bottom"); "bottom" gives a
            horizontal colorbar.
        size (str): width (or height) of the colorbar axes as a
            percentage of the parent axes.

    Returns:
        The created matplotlib Colorbar.

    Example:
        f, axs = plt.subplots(1, 2)
        im = axs[0].imshow(img1, cmap='gray')
        add_colorbar(im)
        im = axs[1].imshow(img2, cmap='gray')
        add_colorbar(im)
    """
    if position == "bottom":
        orientation = "horizontal"
    else:
        orientation = "vertical"

    last_axes = plt.gca()
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)
    # bug fix: the `size` argument was previously ignored (hard-coded "5%")
    cax = divider.append_axes(position, size=size, pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax, orientation=orientation)
    plt.sca(last_axes)  # restore the caller's current axes
    return cbar
281 |
282 |
def noaxis(axs):
    """Hide the x- and y-axes of one axes object or an array of axes.

    Args:
        axs: a single matplotlib axes, or a numpy array of axes of any
            shape (e.g. as returned by ``plt.subplots``).
    """
    if isinstance(axs, np.ndarray):
        # bug fix: ravel() supports 2D grids from plt.subplots(n, m);
        # iterating a 2D object array directly yields rows, not axes
        for ax in axs.ravel():
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    else:
        axs.get_xaxis().set_visible(False)
        axs.get_yaxis().set_visible(False)
291 |
292 |
def string_mean_std(x, prec=3):
    """Format 'mean +/- std' of x with prec decimal places."""
    mean_val, std_val = np.mean(x), np.std(x)
    return "{:.{p}f} +/- {:.{p}f}".format(mean_val, std_val, p=prec)
295 |
296 |
def print_mean_std(x, tag="", prec=3):
    """Print '<tag> = mean +/- std' of x with prec decimal places."""
    line = "{} = {:.{p}f} +/- {:.{p}f}".format(tag, np.mean(x), np.std(x), p=prec)
    print(line)
299 |
300 |
def histogram(s):
    """Plot a 30-bin, density-normalized histogram of s and show it."""
    count, bins, ignored = plt.hist(s, 30, density=True)
    plt.show()
304 |
305 |
def vid2batch(root, img_dim, start_frame, end_frame):
    """Read a video file and return frames [start_frame, end_frame) as a tensor.

    Args:
        root (str): path to a video file readable by OpenCV.
        img_dim (int): frames are resized to img_dim x img_dim.
        start_frame (int): first frame to keep (1-based, inclusive).
        end_frame (int): last frame (exclusive).

    Returns:
        torch.Tensor: shape (1, end_frame - start_frame, 1, img_dim, img_dim)
        holding channel 1 of each frame after BGR->LAB conversion
        (presumably the 'a' chrominance channel — TODO confirm). Slots for
        frames never reached (short video) stay zero.
    """
    # imports kept local so the module can be used without imutils/cv2
    from imutils.video import FPS
    import imutils
    import cv2

    stream = cv2.VideoCapture(root)
    fps = FPS().start()  # NOTE(review): fps is started but never stopped or read
    frame_nb = 0
    output_batch = torch.zeros(1, end_frame - start_frame, 1, img_dim, img_dim)
    while True:
        (grabbed, frame) = stream.read()
        if not grabbed:
            # end of stream
            break

        frame_nb += 1
        # bitwise & works here because both operands are Python bools
        if (frame_nb >= start_frame) & (frame_nb < end_frame):
            frame = cv2.resize(frame, (img_dim, img_dim))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
            output_batch[0, frame_nb - start_frame, 0, :, :] = torch.Tensor(
                frame[:, :, 1]
            )

    return output_batch
329 |
330 |
def pre_process_video(video, crop_patch, kernel_size):
    """Zero-out a patch in every frame and apply a median blur.

    Args:
        video (torch.Tensor): shape (batch, seq, channels, h, w).
        crop_patch: index expression selecting the region to zero in each
            2D frame (e.g. a tuple of slices).
        kernel_size (int): aperture size for cv2.medianBlur.

    Returns:
        torch.Tensor: filtered video with the same shape as the input.
    """
    # import kept local so the module can be used without OpenCV
    import cv2

    batch_size, seq_length, c, h, w = video.shape
    batched_frames = video.reshape(batch_size * seq_length * c, h, w)
    output_batch = torch.zeros(batched_frames.shape)

    for i in range(batch_size * seq_length * c):
        img = torch2numpy(batched_frames[i, :, :])
        # NOTE(review): for CPU tensors this numpy array shares memory with
        # `video`, so the zeroing below mutates the input — confirm intended
        img[crop_patch] = 0
        median_frame = cv2.medianBlur(img, kernel_size)
        output_batch[i, :, :] = torch.Tensor(median_frame)
    output_batch = output_batch.reshape(batch_size, seq_length, c, h, w)
    return output_batch
345 |
--------------------------------------------------------------------------------
/spyrit/misc/examples.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | import numpy as np
8 |
9 |
def translation_matrix(img_size, nb_pixels):
    """Permutation matrix that circularly shifts image columns by nb_pixels.

    Args:
        img_size (int): image side length (image is img_size x img_size).
        nb_pixels (int): number of columns of the circular shift.

    Returns:
        np.ndarray: (img_size**2, img_size**2) permutation matrix.
    """
    source = np.arange(img_size**2).reshape(img_size, img_size)
    shifted = np.zeros((img_size, img_size))
    split = img_size - nb_pixels
    shifted[:, :split] = source[:, nb_pixels:]
    shifted[:, split:] = source[:, :nb_pixels]

    return permutation_matrix(
        shifted.reshape(img_size**2, 1), source.reshape(img_size**2, 1)
    )
20 |
21 |
def permutation_matrix(A, B):
    """Build the permutation matrix P such that P @ A == B (row-wise).

    For each row A[i], the row index of B holding the same row is located,
    and that row of P is set to the i-th canonical basis vector.

    Args:
        A (np.ndarray): (N, k) array of rows.
        B (np.ndarray): (N, k) array containing the same rows, permuted.

    Returns:
        np.ndarray: (N, N) permutation matrix.
    """
    N = A.shape[0]
    I = np.eye(N)
    P = np.zeros((N, N))

    for i in range(N):
        # bug fix: np.matlib.repmat requires `import numpy.matlib`, which
        # this module never performs; broadcasting does the same comparison.
        # A full row match is now required (previously any single matching
        # element selected the row).
        ind = np.where((B == A[i, :]).all(axis=1))
        P[ind, :] = I[i, :]

    return P
33 |
34 |
def circle(img_size, R, x_max):
    """Binary disk image: 1.0 where X**2 + Y**2 < R, 0.0 elsewhere.

    Args:
        img_size (int): output side length.
        R (float): threshold applied to the squared radius.
        x_max (float): coordinates span [-x_max, x_max] on both axes.

    Returns:
        np.ndarray: (img_size, img_size) array of 0.0/1.0.
    """
    coords = np.linspace(-x_max, x_max, img_size)
    X, Y = np.meshgrid(coords, coords)
    inside = X**2 + Y**2 < R
    return inside.astype(float)
39 |
--------------------------------------------------------------------------------
/spyrit/misc/load_data.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | #!/usr/bin/env python3
8 | # -*- coding: utf-8 -*-
9 | """
10 | Created on Wed Jan 15 17:06:19 2020
11 |
12 | @author: crombez
13 | """
14 |
15 | import os
16 | import sys
17 | import glob
18 | import numpy as np
19 | import PIL
20 | from typing import Union
21 |
22 |
def Files_names(Path, name_type):
    """List file basenames matching a glob pattern, sorted by mtime.

    Args:
        Path (str): directory prefix (including the trailing separator).
        name_type (str): glob pattern appended to Path, e.g. "*.tif".

    Returns:
        list[str]: basenames sorted by increasing modification time.
    """
    files = glob.glob(Path + name_type)
    # bug fix: removed a stray bare `print` expression (a no-op leftover)
    files.sort(key=os.path.getmtime)
    return [os.path.basename(x) for x in files]
28 |
29 |
def load_data_recon_3D(Path_files, list_files, Nl, Nc, Nh):
    """Load differential image pairs into an (Nl, Nc, Nh) cube.

    Files come in (positive, negative) pairs; slice k of the output is the
    difference of the two rotated images of pair k.
    """
    Data = np.zeros((Nl, Nc, Nh))

    for k in range(Nh):
        positive = np.array(PIL.Image.open(Path_files + list_files[2 * k]))
        negative = np.array(PIL.Image.open(Path_files + list_files[2 * k + 1]))
        Data[:, :, k] = np.rot90(positive) - np.rot90(negative)

    return Data
39 |
40 |
41 | # Load the data of the hSPIM and compresse the spectrale dimensions to do the reconstruction for every lambda
42 | # odl convention the set of data has to be arranged in such way that the positive part of the hadamard motifs comes first
def load_data_Comp_1D_old(Path_files, list_files, Nh, Nl, Nc):
    """Load hSPIM data, summing the spectral dimension (old convention).

    Old convention: the positive part of each Hadamard pattern comes first,
    so column k of the output is colsum(frame 2k) - colsum(frame 2k+1),
    each frame being rotated by 270 degrees before the column sum.

    Args:
        Path_files (str): directory prefix (including trailing separator).
        list_files (list[str]): ordered file names (2*Nh of them).
        Nh (int): number of Hadamard patterns.
        Nl (int): number of lines per frame.
        Nc (int): number of columns per frame.

    Returns:
        np.ndarray: (Nl, Nh) array of differential, column-summed frames.

    NOTE(review): Sum_coll is neither defined nor imported in this module
    (it lives in spyrit.misc.matrix_tools) — calling this as-is raises
    NameError; confirm the intended import.
    """
    Data = np.zeros((Nl, Nh))

    for i in range(0, 2 * Nh, 2):
        Data[:, i // 2] = Sum_coll(
            np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc
        ) - Sum_coll(
            np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3),
            Nl,
            Nc,
        )

    return Data
56 |
57 |
58 | # Load the data of the hSPIM and compresse the spectrale dimensions to do the reconstruction for every lambda
59 | # new convention the set of data has to be arranged in such way that the negative part of the hadamard motifs comes first
def load_data_Comp_1D_new(Path_files, list_files, Nh, Nl, Nc):
    """Load hSPIM data, summing the spectral dimension (new convention).

    New convention: the negative part of each Hadamard pattern comes first,
    so column k of the output is colsum(frame 2k+1) - colsum(frame 2k),
    each frame being rotated by 270 degrees before the column sum.

    Args:
        Path_files (str): directory prefix (including trailing separator).
        list_files (list[str]): ordered file names (2*Nh of them).
        Nh (int): number of Hadamard patterns.
        Nl (int): number of lines per frame.
        Nc (int): number of columns per frame.

    Returns:
        np.ndarray: (Nl, Nh) array of differential, column-summed frames.

    NOTE(review): Sum_coll is neither defined nor imported in this module
    (it lives in spyrit.misc.matrix_tools) — calling this as-is raises
    NameError; confirm the intended import.
    """
    Data = np.zeros((Nl, Nh))

    for i in range(0, 2 * Nh, 2):
        Data[:, i // 2] = Sum_coll(
            np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3),
            Nl,
            Nc,
        ) - Sum_coll(
            np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc
        )

    return Data
73 |
74 |
def download_girder(
    server_url: str,
    hex_ids: Union[str, list[str]],
    local_folder: str,
    file_names: Union[str, list[str]] = None,
):
    """
    Downloads data from a Girder server and saves it locally.

    This function first creates the local folder if it does not exist. Then, it
    connects to the Girder server and gets the file names for the files
    whose name are not provided. For each file, it checks if it already exists
    by checking if the file name is already in the local folder. If not, it
    downloads the file.

    Args:
        server_url (str): The URL of the Girder server.

        hex_ids (str or list[str]): The hexadecimal id of the file(s) to download.
        If a list is provided, the files are downloaded in the same order and
        are saved in the same folder.

        local_folder (str): The path to the local folder where the files will
        be saved. If it does not exist, it will be created.

        file_names (str or list[str], optional): The name of the file(s) to save.
        If a list is provided, it must have the same length as hex_ids. Each
        element equal to `None` will be replaced by the name of the file on the
        server. If None, all the names will be obtained from the server.
        Default is None. All names include the extension.

    Raises:
        ValueError: If the number of file names provided does not match the
        number of files to download.

    Returns:
        list[str]: The absolute paths to the downloaded files. If a single
        file was requested, the single path is returned un-wrapped.
    """
    # leave import in function, so that the module can be used without
    # girder_client
    import girder_client

    # check the local folder exists
    if not os.path.exists(local_folder):
        print("Local folder not found, creating it... ", end="")
        os.makedirs(local_folder)
        print("done.")

    # connect to the server
    gc = girder_client.GirderClient(apiUrl=server_url)

    # create lists if strings are provided
    if type(hex_ids) is str:
        hex_ids = [hex_ids]
    if file_names is None:
        file_names = [None] * len(hex_ids)
    elif type(file_names) is str:
        file_names = [file_names]

    if len(file_names) != len(hex_ids):
        raise ValueError("There must be as many file names as hex ids.")

    abs_paths = []

    # for each file, check if it exists and download if necessary
    # (note: the loop variable `id` shadows the builtin of the same name)
    for id, name in zip(hex_ids, file_names):

        if name is None:
            # get the file name from the server metadata
            name = gc.getFile(id)["name"]

        # check the file exists locally before downloading
        if not os.path.exists(os.path.join(local_folder, name)):
            # connect to the server to download the file
            print(f"Downloading {name}... ", end="\r")
            gc.downloadFile(id, os.path.join(local_folder, name))
            print(f"Downloading {name}... done.")

        else:
            print("File already exists at", os.path.join(local_folder, name))

        abs_paths.append(os.path.abspath(os.path.join(local_folder, name)))

    return abs_paths[0] if len(abs_paths) == 1 else abs_paths
159 |
--------------------------------------------------------------------------------
/spyrit/misc/matrix_tools.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | #!/usr/bin/env python3
8 | # -*- coding: utf-8 -*-
9 | """
10 | Created on Wed Jan 15 16:37:27 2020
11 |
12 | @author: crombez
13 | """
14 | import warnings
15 |
16 | warnings.simplefilter("always", DeprecationWarning)
17 |
18 | import numpy as np
19 |
20 | import spyrit.misc.sampling as samp
21 |
22 |
def Permutation_Matrix(mat):
    r"""Deprecated alias of :func:`spyrit.misc.sampling.Permutation_Matrix`.

    Returns the permutation matrix associated with a sampling matrix.

    Args:
        mat (np.ndarray): N-by-N sampling matrix, where high values
            indicate high significance.

    Returns:
        np.ndarray: N^2-by-N^2 boolean permutation matrix.

    .. warning::
        Duplicate of :func:`spyrit.misc.sampling.Permutation_Matrix`;
        it will be removed in a future release.

    .. note::
        For better performance, consider :func:`sort_by_significance`
        when reordering a matrix as ``y = Permutation_Matrix(Ord) @ Mat``.
    """
    warnings.warn(
        "\nspyrit.misc.matrix_tools.Permutation_Matrix is deprecated and will"
        " be removed in a future release. Use\n"
        "spyrit.misc.sampling.Permutation_Matrix instead.",
        DeprecationWarning,
    )
    return samp.Permutation_Matrix(mat)
52 |
53 |
def expend_vect(Vect, N1, N2):
    """Expand a vector of size N1 to size N2 by repeating each entry.

    Each of the first N1 entries of Vect is written S = N2 // N1 times in
    a row; when N2 is not a multiple of N1 the trailing entries stay 0.
    """
    V_out = np.zeros(N2)
    S = int(N2 / N1)
    for i in range(N1):
        for rep in range(S):
            # positions i*S .. i*S + S - 1 hold copies of Vect[i]
            V_out[i * S + rep] = Vect[i]
    return V_out
64 |
65 |
def data_conv_hadamard(H, Data, N):
    """Multiply (in place) the first N depth slices of H element-wise
    by the 2D array Data, and return the mutated H."""
    H[:, :, :N] = H[:, :, :N] * Data[:, :, np.newaxis]
    return H
70 |
71 |
def Sum_coll(Mat, N_lin, N_coll):
    """Sum the first N_coll columns of Mat into a length-N_lin vector.

    Args:
        Mat (np.ndarray): matrix with N_lin rows and at least N_coll columns.
        N_lin (int): length of the output vector.
        N_coll (int): number of columns to accumulate.

    Returns:
        np.ndarray: float vector of row-wise sums.
    """
    acc = np.zeros(N_lin)
    for col in range(N_coll):
        acc = acc + Mat[:, col]
    return acc
79 |
80 |
def compression_1D(H, Nl, Nc, Nh):
    """Compress an Nl x Nc x Nh cube to Nl x Nh by summing columns.

    Each depth slice H[:, :, k] is reduced to a length-Nl vector with
    Sum_coll, collected as column k of the output.
    """
    H_1D = np.zeros((Nl, Nh))
    for depth in range(Nh):
        H_1D[:, depth] = Sum_coll(H[:, :, depth], Nl, Nc)
    return H_1D
89 |
90 |
def normalize_mat_2D(Mat):
    """Normalize a matrix by its maximum value (the peak becomes 1)."""
    peak = np.amax(Mat)
    return Mat * (1 / peak)
94 |
95 |
def normalize_by_median_mat_2D(Mat):
    """Normalize a matrix by its median value (the median becomes 1)."""
    med = np.median(Mat)
    return Mat * (1 / med)
99 |
100 |
def remove_offset_mat_2D(Mat):
    """Subtract the mean value of the matrix (zero-mean output)."""
    return Mat - np.mean(Mat)
104 |
105 |
def resize(Mat, Nl, Nc, Nh):
    """Expand the first Nh entries of each of the first Nl rows of Mat
    into rows of length Nc (see expend_vect)."""
    out = np.zeros((Nl, Nc))
    for row in range(Nl):
        out[row] = expend_vect(Mat[row, :], Nh, Nc)
    return out
111 |
112 |
def stack_depth_matrice(Mat, Nl, Nc, Nd):
    """Sum the first Nd depth slices of a 3D array into an Nl x Nc matrix."""
    total = np.zeros((Nl, Nc))
    for depth in range(Nd):
        total = total + Mat[:, :, depth]
    return total
120 |
121 |
# functions below that need to be better defined
123 |
124 |
def smooth(y, box_pts):
    """Moving-average smoothing of a vector with a box kernel of
    box_pts samples (same-length output)."""
    kernel = np.ones(box_pts) / box_pts
    return np.convolve(y, kernel, mode="same")
129 |
130 |
def reject_outliers(data, m=2):
    """Replace entries farther than m standard deviations from the mean
    with 0; all other entries are kept unchanged."""
    deviation = abs(data - np.mean(data))
    keep = deviation < m * np.std(data)
    return np.where(keep, data, 0)
133 |
134 |
def clean_out(Data, Nl, Nc, Nh, m=2):
    """Apply reject_outliers to every depth slice of an Nl x Nc x Nh cube.

    Args:
        Data (np.ndarray): (Nl, Nc, Nh) data cube.
        Nl, Nc (int): slice dimensions.
        Nh (int): number of depth slices to process.
        m (float): outlier threshold in standard deviations.

    Returns:
        np.ndarray: cube with the outliers of each slice replaced by 0.
    """
    Mout = np.zeros((Nl, Nc, Nh))
    for i in range(Nh):
        Mout[:, :, i] = reject_outliers(Data[:, :, i], m)
    # bug fix: previously returned the unmodified input `Data`,
    # silently discarding the cleaned result in `Mout`
    return Mout
140 |
--------------------------------------------------------------------------------
/spyrit/misc/metrics.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # This software is distributed under the terms
3 | # of the GNU Lesser General Public Licence (LGPL)
4 | # See LICENSE.md for further details
5 | # -----------------------------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torch.optim import lr_scheduler
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | import torch.nn.functional as F
15 | import imageio
16 | import matplotlib.pyplot as plt
17 |
18 | # import skimage.metrics as skm
19 |
20 |
def batch_psnr(torch_batch, output_batch):
    """Per-image PSNR between two (B, C, H, W) batches, using channel 0.

    Returns a list of B PSNR values.
    """
    return [
        psnr(
            torch_batch[k, 0, :, :].cpu().detach().numpy(),
            output_batch[k, 0, :, :].cpu().detach().numpy(),
        )
        for k in range(torch_batch.shape[0])
    ]
30 |
31 |
def batch_psnr_(torch_batch, output_batch, r=2):
    """Per-image PSNR with a fixed dynamic range r, using channel 0.

    Returns a list of B PSNR values for (B, C, H, W) batches.
    """
    return [
        psnr_(
            torch_batch[k, 0, :, :].cpu().detach().numpy(),
            output_batch[k, 0, :, :].cpu().detach().numpy(),
            r=r,
        )
        for k in range(torch_batch.shape[0])
    ]
41 |
42 |
def batch_ssim(torch_batch, output_batch):
    """Per-image SSIM between two (B, C, H, W) batches, using channel 0.

    Returns a list of B SSIM values.
    """
    return [
        ssim(
            torch_batch[k, 0, :, :].cpu().detach().numpy(),
            output_batch[k, 0, :, :].cpu().detach().numpy(),
        )
        for k in range(torch_batch.shape[0])
    ]
52 |
53 |
def dataset_meas(dataloader, model, device):
    """Collect raw measurements over a dataloader.

    For each batch, model.acquire is called and channel 0 of its output is
    accumulated, one numpy array per input sample.

    Returns:
        list of numpy arrays of measurements.
    """
    measurements = []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        b, c, h, w = inputs.shape
        acquired = model.acquire(inputs, b, c, h, w)
        raw = acquired[:, 0, :].cpu().detach().numpy()
        measurements.extend(raw)
    return measurements
65 |
66 |
67 | #
68 | # def dataset_psnr_different_measures(dataloader, model, model_2, device):
69 | # psnr = [];
70 | # #psnr_fc = [];
71 | # for inputs, labels in dataloader:
72 | # inputs = inputs.to(device)
73 | # m = model_2.normalized measure(inputs);
74 | # net_output = model.forward_reconstruct(inputs);
75 | # #net_output2 = model.evaluate_fcl(inputs);
76 | #
77 | # psnr += batch_psnr(inputs, net_output);
78 | # #psnr_fc += batch_psnr(inputs, net_output2);
79 | # psnr = np.array(psnr);
80 | # #psnr_fc = np.array(psnr_fc);
81 | # return psnr;
82 | #
83 |
84 |
def dataset_psnr(dataloader, model, device):
    """PSNR of the full model and of its fully connected layer over a
    dataloader.

    Returns:
        (np.ndarray, np.ndarray): per-image PSNRs for model.evaluate and
        model.evaluate_fcl respectively.
    """
    psnr_full, psnr_fc = [], []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        rec_full = model.evaluate(inputs)
        rec_fc = model.evaluate_fcl(inputs)
        psnr_full.extend(batch_psnr(inputs, rec_full))
        psnr_fc.extend(batch_psnr(inputs, rec_fc))
    return np.array(psnr_full), np.array(psnr_fc)
101 |
102 |
def dataset_ssim(dataloader, model, device):
    """SSIM of the full model and of its fully connected layer over a
    dataloader.

    Returns:
        (np.ndarray, np.ndarray): per-image SSIMs for model.evaluate and
        model.evaluate_fcl respectively.
    """
    ssim_full, ssim_fc = [], []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        # evaluate full model and fully connected layer
        rec_full = model.evaluate(inputs)
        rec_fc = model.evaluate_fcl(inputs)
        ssim_full.extend(batch_ssim(inputs, rec_full))
        ssim_fc.extend(batch_ssim(inputs, rec_fc))
    return np.array(ssim_full), np.array(ssim_fc)
117 |
118 |
def dataset_psnr_ssim(dataloader, model, device):
    """PSNR and SSIM of the full model over a dataloader.

    Returns:
        (np.ndarray, np.ndarray): per-image PSNRs and SSIMs for
        model.evaluate.
    """
    psnr_vals, ssim_vals = [], []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        reconstruction = model.evaluate(inputs)
        psnr_vals.extend(batch_psnr(inputs, reconstruction))
        ssim_vals.extend(batch_ssim(inputs, reconstruction))
    return np.array(psnr_vals), np.array(ssim_vals)
136 |
137 |
def dataset_psnr_ssim_fcl(dataloader, model, device):
    """PSNR and SSIM of the fully connected layer over a dataloader.

    Returns:
        (np.ndarray, np.ndarray): per-image PSNRs and SSIMs for
        model.evaluate_fcl.
    """
    psnr_vals, ssim_vals = [], []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        reconstruction = model.evaluate_fcl(inputs)
        psnr_vals.extend(batch_psnr(inputs, reconstruction))
        ssim_vals.extend(batch_ssim(inputs, reconstruction))
    return np.array(psnr_vals), np.array(ssim_vals)
155 |
156 |
def psnr(I1, I2):
    """PSNR between two images; I1 is the reference and sets the dynamic.

    The dynamic range is max(I1) - min(I1); the MSE is averaged over all
    elements of I1.
    """
    dynamic = np.amax(I1) - np.amin(I1)
    mse = np.square(I2 - I1).sum() / I1.size
    return 10 * np.log10(dynamic**2 / mse)
166 |
167 |
def psnr_(img1, img2, r=2):
    """PSNR between two images with a known dynamic range.

    Args:
        img1, img2 (np.ndarray): images
        r (float): image range

    Returns:
        Psnr (float): Peak signal-to-noise ratio
    """
    mse = np.mean((img1 - img2) ** 2)
    return 10 * np.log10(r**2 / mse)
183 |
184 |
def psnr_torch(img_gt, img_rec, mask=None, dim=(-2, -1), img_dyn=None):
    r"""
    Computes the Peak Signal-to-Noise Ratio (PSNR) between two images.

    .. math::

        \text{PSNR} = 20 \, \log_{10} \left( \frac{\text{d}}{\sqrt{\text{MSE}}} \right), \\
        \text{MSE} = \frac{1}{L}\sum_{\ell=1}^L \|I_\ell - \tilde{I}_\ell\|^2_2,

    where :math:`d` is the image dynamic and :math:`\{I_\ell\}` (resp. :math:`\{\tilde{I}_\ell\}`) is the set of ground truth (resp. reconstructed) images.

    Args:
        :attr:`img_gt`: Tensor containing the *ground-truth* image.

        :attr:`img_rec`: Tensor containing the reconstructed image.

        :attr:`mask`: Mask where the squared error is computed. Defaults :attr:`None`, i.e., no mask is considered.

        :attr:`dim`: Dimensions where the squared error is computed. Defaults to :attr:`(-2,-1)` (i.e., the last two dimensions). When a mask is given, the masked entries are flattened and :attr:`dim` is forced to :attr:`-1`.

        :attr:`img_dyn`: Image dynamic range (e.g., 1.0 for normalized images, 255 for 8-bit images). When :attr:`img_dyn` is :attr:`None`, the dynamic range is computed from the ground-truth image.

    Returns:
        PSNR value.

    .. note::
        :attr:`psnr_torch(img_gt, img_rec)` is different from :attr:`psnr_torch(img_rec, img_gt)`. The first expression assumes :attr:`img_gt` is the ground truth while the second assumes that this is :attr:`img_rec`. This leads to different dynamic ranges.

    Example 1: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise
        >>> x = torch.rand(10,1,64,64)
        >>> n = x + 0.05*torch.randn(x.shape)
        >>> out = psnr_torch(x,n)
        >>> print(out.shape)
        torch.Size([10, 1])

    Example 2: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise
        >>> psnr_torch(n,x)
        tensor(...)
        >>> psnr_torch(x,n)
        tensor(...)
        >>> psnr_torch(n,x,img_dyn=1.0)
        tensor(...)

    """
    if mask is not None:
        # boolean selection flattens both images, hence the 1D mean below
        # (bug fix: a leftover debug `print("mask")` was removed here)
        dim = -1
        img_gt = img_gt[mask > 0]
        img_rec = img_rec[mask > 0]

    mse = (img_gt - img_rec) ** 2
    mse = torch.mean(mse, dim=dim)

    if img_dyn is None:
        img_dyn = torch.amax(img_gt, dim=dim) - torch.amin(img_gt, dim=dim)

    return 10 * torch.log10(img_dyn**2 / mse)
242 |
243 |
def ssim(I1, I2):
    """Structural similarity index between two images.

    Global (single-window) SSIM with the standard stabilizing constants
    c1 = (0.01 L)^2 and c2 = (0.03 L)^2, where the dynamic range L is
    taken from I1.
    """
    L = np.amax(I1) - np.amin(I1)
    c1 = (0.01 * L) ** 2
    c2 = (0.03 * L) ** 2
    mu1, mu2 = np.mean(I1), np.mean(I2)
    s1, s2 = np.std(I1), np.std(I2)
    s12 = np.mean((I1 - mu1) * (I2 - mu2))
    numerator = (2 * mu1 * mu2 + c1) * (2 * s12 + c2)
    denominator = (mu1**2 + mu2**2 + c1) * (s1**2 + s2**2 + c2)
    return numerator / denominator
260 |
261 |
262 | # def ssim_sk(x_gt, x, img_dyn=None):
263 | # """
264 | # SSIM from skimage
265 |
266 | # Args:
267 | # torch tensors
268 |
269 | # Returns:
270 | # torch tensor
271 | # """
272 | # if not isinstance(x, np.ndarray):
273 | # x = x.cpu().detach().numpy().squeeze()
274 | # x_gt = x_gt.cpu().detach().numpy().squeeze()
275 | # ssim_val = np.zeros(x.shape[0])
276 | # for i in range(x.shape[0]):
277 | # ssim_val[i] = skm.structural_similarity(x_gt[i], x[i], data_range=img_dyn)
278 | # return torch.tensor(ssim_val)
279 |
280 |
def batch_psnr_vid(input_batch, output_batch):
    """Per-frame PSNR between two video batches of shape (B, T, C, H, W).

    Returns a list of B*T*C PSNR values, one per 2D frame.
    """
    b, t, c, h, w = input_batch.shape
    flat_in = input_batch.reshape(b * t * c, 1, h, w)
    flat_out = output_batch.reshape(b * t * c, 1, h, w)
    return [
        psnr(
            flat_in[k, 0, :, :].cpu().detach().numpy(),
            flat_out[k, 0, :, :].cpu().detach().numpy(),
        )
        for k in range(flat_in.shape[0])
    ]
293 |
294 |
def batch_ssim_vid(input_batch, output_batch):
    """Per-frame SSIM between two video batches of shape (B, T, C, H, W).

    Returns a list of B*T*C SSIM values, one per 2D frame.
    """
    b, t, c, h, w = input_batch.shape
    flat_in = input_batch.reshape(b * t * c, 1, h, w)
    flat_out = output_batch.reshape(b * t * c, 1, h, w)
    return [
        ssim(
            flat_in[k, 0, :, :].cpu().detach().numpy(),
            flat_out[k, 0, :, :].cpu().detach().numpy(),
        )
        for k in range(flat_in.shape[0])
    ]
307 |
308 |
def compare_video_nets_supervised(net_list, testloader, device):
    """Per-frame PSNR/SSIM of several networks against the labels.

    Returns:
        (list[list], list[list]): for each network, the per-frame PSNRs
        and SSIMs accumulated over the whole test loader.
    """
    psnr = [[] for _ in net_list]
    ssim = [[] for _ in net_list]
    for batch_idx, (inputs, labels) in enumerate(testloader):
        batch_size, seq_length, c, h, w = inputs.shape
        print("Batch :{}/{}".format(batch_idx + 1, len(testloader)))
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            for net_idx, net in enumerate(net_list):
                outputs = net.evaluate(inputs)
                psnr[net_idx].extend(batch_psnr_vid(labels, outputs))
                ssim[net_idx].extend(batch_ssim_vid(labels, outputs))
    return psnr, ssim
323 |
324 |
def compare_nets_unsupervised(net_list, testloader, device):
    """Per-frame PSNR/SSIM of several networks, outputs as the reference.

    Same as compare_video_nets_supervised except the network output is
    passed as the first (reference) argument of the metrics.

    Returns:
        (list[list], list[list]): per-network PSNRs and SSIMs.
    """
    psnr = [[] for _ in net_list]
    ssim = [[] for _ in net_list]
    for batch_idx, (inputs, labels) in enumerate(testloader):
        batch_size, seq_length, c, h, w = inputs.shape
        print("Batch :{}/{}".format(batch_idx + 1, len(testloader)))
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            for net_idx, net in enumerate(net_list):
                outputs = net.evaluate(inputs)
                psnr[net_idx].extend(batch_psnr_vid(outputs, labels))
                ssim[net_idx].extend(batch_ssim_vid(outputs, labels))
    return psnr, ssim
339 |
340 |
def print_mean_std(x, tag=""):
    """Print mean +/- standard deviation of x, prefixed by tag."""
    mean_val = np.mean(x)
    std_val = np.std(x)
    print("{}psnr = {} +/- {}".format(tag, mean_val, std_val))
343 |
--------------------------------------------------------------------------------
/tutorial/README.txt:
--------------------------------------------------------------------------------
1 | Tutorials
2 | =========
3 |
4 | This series of tutorials should guide you through the use of the SPyRiT pipeline.
5 |
6 | .. figure:: ../fig/direct_net.png
7 | :width: 600
8 | :align: center
9 | :alt: SPyRiT pipeline
10 |
11 | |
12 |
13 | Each tutorial focuses on a specific submodule of the full pipeline.
14 |
15 | * :ref:`Tutorial 1 `.a introduces the basics of measurement operators.
16 |
17 | * :ref:`Tutorial 1 `.b introduces the splitting of measurement operators.
18 |
19 | * :ref:`Tutorial 1 `.c introduces the 2d Hadamard transform with subsampling.
20 |
21 | * :ref:`Tutorial 2 ` introduces the noise operators.
22 |
23 | * :ref:`Tutorial 3 ` demonstrates pseudo-inverse reconstructions from Hadamard measurements.
24 |
25 |
26 | .. note::
27 |
28 | The Python script (*.py*) or Jupyter notebook (*.ipynb*) corresponding to each tutorial can be downloaded at the bottom of the page. The images used in these files can be found on `GitHub`_.
29 |
30 | The tutorials below will gradually be updated to be compatible with SPyRiT 3 (work in progress, in the meantime see SPyRiT `2.4.0`_).
31 |
32 | * :ref:`Tutorial 3 ` uses a CNN to denoise the image if necessary
33 |
34 | * :ref:`Tutorial 4 ` is used to train the CNN introduced in Tutorial 3
35 |
36 | * :ref:`Tutorial 5 ` introduces a new type of measurement operator ('split') that simulates positive and negative measurements
37 |
38 | * :ref:`Tutorial 6 ` uses a Denoised Completion Network with a trainable image denoiser to improve the results obtained in Tutorial 5
39 |
40 | * :ref:`Tutorial 7 ` shows how to perform image reconstruction using a pretrained plug-and-play denoising network.
41 |
42 | * :ref:`Tutorial 8 ` shows how to perform image reconstruction using a learnt proximal gradient descent.
43 |
44 | * :ref:`Tutorial 9 ` explains motion simulation from an image, dynamic measurements and reconstruction.
45 |
* Explore :ref:`Bonus Tutorial <tuto_bonus_advanced_methods_colab>` if you want to go deeper into SPyRiT's capabilities
47 |
48 |
49 | .. _GitHub: https://github.com/openspyrit/spyrit/tree/3895b5e61fb6d522cff5e8b32a36da89b807b081/tutorial/images/test
50 |
51 | .. _2.4.0: https://spyrit.readthedocs.io/en/2.4.0/gallery/index.html
52 |
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000001.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000001.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000002.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000002.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000003.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000003.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000004.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000004.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000005.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000005.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000006.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000006.jpeg
--------------------------------------------------------------------------------
/tutorial/images/test/ILSVRC2012_test_00000007.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openspyrit/spyrit/6a9a5847cb7b69f2d11459bc7fffcef630126f75/tutorial/images/test/ILSVRC2012_test_00000007.jpeg
--------------------------------------------------------------------------------
/tutorial/tuto_01_a_acquisition_operators.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 01.a. Acquisition operators (basic)
3 | ====================================================
4 | .. _tuto_acquisition_operators:
5 |
6 | This tutorial shows how to simulate measurements using the :mod:`spyrit.core.meas` submodule.
7 |
8 |
9 | .. image:: ../fig/tuto1.png
10 | :width: 600
11 | :align: center
12 | :alt: Reconstruction architecture sketch
13 |
14 | |
15 |
16 | All simulations are based on :class:`spyrit.core.meas.Linear` base class that simulates linear measurements
17 |
18 | .. math::
19 | m = Hx,
20 |
21 | where :math:`H\in\mathbb{R}^{M\times N}` is the acquisition matrix, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`M` is the number of measurements, and :math:`N` is the dimension of the signal.
22 |
23 | .. important::
24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g., an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization, i.e., :math:`x = \texttt{vec}(X)`.
25 |
26 | """
27 |
28 | # %%
29 | # 1D Measurements
30 | # -----------------------------------------------------------------------------
31 |
32 | ###############################################################################
33 | # We instantiate a measurement operator from a matrix of shape (10, 15).
34 | import torch
35 | from spyrit.core.meas import Linear
36 |
37 | H = torch.randn(10, 15)
38 | meas_op = Linear(H)
39 |
40 | ###############################################################################
41 | # We consider 3 signals of length 15
42 | x = torch.randn(3, 15)
43 |
44 | ###############################################################################
45 | # We apply the operator to the batch of images, which produces 3 measurements
46 | # of length 10
47 | m = meas_op(x)
48 | print(m.shape)
49 |
50 | ###############################################################################
51 | # We now plot the matrix-vector products
52 |
53 | from spyrit.misc.disp import add_colorbar, noaxis
54 | import matplotlib.pyplot as plt
55 |
56 | f, axs = plt.subplots(1, 3, figsize=(10, 5))
57 | axs[0].set_title("Forward matrix H")
58 | im = axs[0].imshow(H, cmap="gray")
59 | add_colorbar(im, "bottom")
60 |
61 | axs[1].set_title("Signals x")
62 | im = axs[1].imshow(x.T, cmap="gray")
63 | add_colorbar(im, "bottom")
64 |
65 | axs[2].set_title("Measurements m")
66 | im = axs[2].imshow(m.T, cmap="gray")
67 | add_colorbar(im, "bottom")
68 |
69 | noaxis(axs)
70 | # sphinx_gallery_thumbnail_number = 1
71 |
72 | # %%
73 | # 2D Measurements
74 | # -----------------------------------------------------------------------------
75 |
76 | ###############################################################################
77 | # We load a batch of images from the :attr:`/images/` folder. Using the
78 | # :func:`transform_gray_norm` function with the :attr:`normalize=False`
79 | # argument returns images with values in (0,1).
80 | import os
81 | import torchvision
82 | from spyrit.misc.statistics import transform_gray_norm
83 |
84 | spyritPath = os.getcwd()
85 | imgs_path = os.path.join(spyritPath, "images/")
86 |
87 | # Grayscale images of size (32, 32), no normalization to keep values in (0,1)
88 | transform = transform_gray_norm(img_size=32, normalize=False)
89 |
90 | # Create dataset and loader (expects class folder :attr:`images/test/`)
91 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
92 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
93 |
94 | x, _ = next(iter(dataloader))
95 |
96 | ###############################################################################
97 | # We crop the batch to get images of shape (9, 25).
98 | x = x[:, :, :9, :25]
99 | print(f"Shape of input images: {x.shape}")
100 |
101 | ###############################################################################
102 | # We plot the second image.
103 | from spyrit.misc.disp import imagesc
104 |
105 | imagesc(x[1, 0, :, :], "Image X")
106 |
107 |
108 | ###############################################################################
109 | # We instantiate a measurement operator from a random matrix with shape (10, 9*25). To indicate that the operator works in 2D, we use the :attr:`meas_shape` argument.
110 | H = torch.randn(10, 9 * 25)
111 | meas_op = Linear(H, meas_shape=(9, 25))
112 |
113 | ###############################################################################
114 | # We apply the operator to the batch of images, which produces a batch of measurement vectors of length 10.
115 | m = meas_op(x)
116 | print(m.shape)
117 |
118 |
119 | ###############################################################################
120 | # We now plot the matrix-vector products corresponding to the second image in the batch.
121 |
122 | ###############################################################################
123 | # We first select the second image and the second measurement vector in the batch.
124 | x_plot = x[1, 0, :, :]
125 | m_plot = m[1]
126 |
127 | ###############################################################################
128 | # Then we vectorize the image to get a 1D array of length 9*25.
129 | x_plot = x_plot.reshape(1, -1)
130 |
131 | print(f"Vectorised image with shape: {x_plot.shape}")
132 |
133 | ###############################################################################
134 | # We finally plot the matrix-vector products :math:`m = H x = H \texttt{vec}(X)`.
135 |
136 | from spyrit.misc.disp import add_colorbar, noaxis
137 | import matplotlib.pyplot as plt
138 |
139 | f, axs = plt.subplots(1, 3)
140 | axs[0].set_title("Forward matrix H")
141 | im = axs[0].imshow(H, cmap="gray")
142 | # add_colorbar(im, "bottom")
143 |
144 | axs[1].set_title("x = vec(X)")
145 | im = axs[1].imshow(x_plot.mT, cmap="gray")
146 | # add_colorbar(im, "bottom")
147 |
148 | axs[2].set_title("Measurements m")
149 | im = axs[2].imshow(m_plot.mT, cmap="gray")
150 | # add_colorbar(im, "bottom")
151 |
152 | noaxis(axs)
153 |
--------------------------------------------------------------------------------
/tutorial/tuto_01_b_splitting.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 01.b. Acquisition operators (splitting)
3 | ====================================================
4 | .. _tuto_acquisition_operators_splitting:
5 |
6 | This tutorial shows how to simulate linear measurements by splitting an acquisition matrix :math:`H\in \mathbb{R}^{M\times N}` that contains negative values. It is based on the :class:`spyrit.core.meas.LinearSplit` class of the :mod:`spyrit.core.meas` submodule.
7 |
8 |
9 | .. image:: ../fig/tuto1.png
10 | :width: 600
11 | :align: center
12 | :alt: Reconstruction architecture sketch
13 |
14 | |
15 |
16 | In practice, only positive values can be implemented using a digital micromirror device (DMD). Therefore, we acquire
17 |
18 | .. math::
19 | y =Ax,
20 |
21 | where :math:`A \colon\, \mathbb{R}_+^{2M\times N}` is the acquisition matrix that contains positive DMD patterns, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`2M` is the number of DMD patterns, and :math:`N` is the dimension of the signal.
22 |
23 | .. important::
24 | The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g., an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`). Both variables are related through vectorization, i.e., :math:`x = \texttt{vec}(X)`.
25 |
26 | Given a matrix :math:`H` with negative values, we define the positive DMD patterns :math:`A` from the positive and negative components :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while odd rows of :math:`A` contain the negative components of :math:`H`.
27 |
28 | .. math::
29 | \begin{cases}
30 | A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\
31 | A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H).
32 | \end{cases}
33 |
34 | """
35 |
36 | # %%
37 | # Splitting in 1D
38 | # -----------------------------------------------------------------------------
39 |
40 | ###############################################################################
41 | # We instantiate a measurement operator from a matrix of shape (10, 15).
42 | import torch
43 | from spyrit.core.meas import LinearSplit
44 |
45 | H = torch.randn(10, 15)
46 | meas_op = LinearSplit(H)
47 |
48 | ###############################################################################
49 | # We consider 3 signals of length 15.
50 | x = torch.randn(3, 15)
51 |
52 | ###############################################################################
53 | # We apply the operator to the batch of images, which produces 3 measurements
54 | # of length 10*2 = 20.
55 | y = meas_op(x)
56 | print(y.shape)
57 |
58 | ###############################################################################
59 | # .. note::
60 | # The number of measurements is twice the number of rows of the matrix H that contains negative values.
61 |
62 | # %%
63 | # Illustration
64 | # -----------------------------------------------------------------------------
65 |
66 | ###############################################################################
67 | # We plot the positive and negative components of H that are concatenated in the matrix A.
68 |
69 | A = meas_op.A
70 | H_pos = meas_op.A[::2, :] # Even rows
71 | H_neg = meas_op.A[1::2, :] # Odd rows
72 |
73 | from spyrit.misc.disp import add_colorbar, noaxis
74 | import matplotlib.pyplot as plt
75 |
76 | fig = plt.figure(figsize=(10, 5))
77 | gs = fig.add_gridspec(2, 2)
78 |
79 | ax1 = fig.add_subplot(gs[:, 0])
80 | ax2 = fig.add_subplot(gs[0, 1])
81 | ax3 = fig.add_subplot(gs[1, 1])
82 |
83 | ax1.set_title("Forward matrix A")
84 | im = ax1.imshow(A, cmap="gray")
85 | add_colorbar(im)
86 |
87 | ax2.set_title("Forward matrix H_pos")
88 | im = ax2.imshow(H_pos, cmap="gray")
89 | add_colorbar(im)
90 |
91 | ax3.set_title("Measurements H_neg")
92 | im = ax3.imshow(H_neg, cmap="gray")
93 | add_colorbar(im)
94 |
95 | noaxis(ax1)
96 | noaxis(ax2)
97 | noaxis(ax3)
98 | # sphinx_gallery_thumbnail_number = 1
99 |
100 | ###############################################################################
101 | # We can verify numerically that H = H_pos - H_neg
102 |
103 | H = meas_op.H
104 | diff = torch.linalg.norm(H - (H_pos - H_neg))
105 |
106 | print(f"|| H - (H_pos - H_neg) || = {diff}")
107 |
108 | ###############################################################################
109 | # We now plot the matrix-vector products between A and x.
110 |
111 | f, axs = plt.subplots(1, 3, figsize=(10, 5))
112 | axs[0].set_title("Forward matrix A")
113 | im = axs[0].imshow(A, cmap="gray")
114 | add_colorbar(im, "bottom")
115 |
116 | axs[1].set_title("Signals x")
117 | im = axs[1].imshow(x.T, cmap="gray")
118 | add_colorbar(im, "bottom")
119 |
120 | axs[2].set_title("Split measurements y")
121 | im = axs[2].imshow(y.T, cmap="gray")
122 | add_colorbar(im, "bottom")
123 |
124 | noaxis(axs)
125 |
126 | # %%
127 | # Simulations with noise and using the matrix H
128 | # --------------------------------------------------------------------
129 |
130 | ######################################################################
131 | # The operators in the :mod:`spyrit.core.meas` submodule allow for simulating noisy measurements
132 | #
133 | # .. math::
134 | # y =\mathcal{N}\left(Ax\right),
135 | #
136 | # where :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` represents a noise operator (e.g., Gaussian). By default, no noise is applied to the measurement, i.e., :math:`\mathcal{N}` is the identity. We can consider noise by setting the :attr:`noise_model` attribute of the :class:`spyrit.core.meas.LinearSplit` class.
137 |
138 | #####################################################################
139 | # For instance, we can consider additive Gaussian noise with standard deviation 2.
140 |
141 | from spyrit.core.noise import Gaussian
142 |
143 | meas_op.noise_model = Gaussian(2)
144 |
145 | #####################################################################
146 | # .. note::
147 | # To learn more about noise models, please refer to :ref:`tutorial 2 `.
148 |
149 | #####################################################################
150 | # We simulate the noisy measurement vectors
151 | y_noise = meas_op(x)
152 |
153 | #####################################################################
154 | # Noiseless measurements can be simulated using the :meth:`spyrit.core.LinearSplit.measure` method.
155 | y_nonoise = meas_op.measure(x)
156 |
157 | #####################################################################
158 | # The :meth:`spyrit.core.LinearSplit.measure_H` method simulates noiseless measurements using the matrix H, i.e., :math:`m = Hx`.
159 | m_nonoise = meas_op.measure_H(x)
160 |
161 | #####################################################################
162 | # We now plot the noisy and noiseless measurements
163 | f, axs = plt.subplots(1, 3, figsize=(8, 5))
164 | axs[0].set_title("Split measurements y \n with noise")
165 | im = axs[0].imshow(y_noise.mT, cmap="gray")
166 | add_colorbar(im)
167 |
168 | axs[1].set_title("Split measurements y \n without noise")
169 | im = axs[1].imshow(y_nonoise.mT, cmap="gray")
170 | add_colorbar(im)
171 |
172 | axs[2].set_title("Measurements m \n without noise")
173 | im = axs[2].imshow(m_nonoise.mT, cmap="gray")
174 | add_colorbar(im)
175 |
176 | noaxis(axs)
177 |
--------------------------------------------------------------------------------
/tutorial/tuto_01_c_HadamSplit2d.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 01.c. Acquisition operators (HadamSplit2d)
3 | ====================================================
4 | .. _tuto_acquisition_operators_HadamSplit2d:
5 |
6 | This tutorial shows how to simulate measurements that correspond to the 2D Hadamard transform of an image. It is based on the :class:`spyrit.core.meas.HadamSplit2d` class of the :mod:`spyrit.core.meas` submodule.
7 |
8 |
9 | .. image:: ../fig/tuto1.png
10 | :width: 600
11 | :align: center
12 | :alt: Reconstruction architecture sketch
13 |
14 | |
15 |
16 | In practice, only positive values can be implemented using a digital micromirror device (DMD). Therefore, we acquire
17 |
18 | .. math::
19 | y = \texttt{vec}\left(AXA^T\right),
20 |
21 | where :math:`A \in \mathbb{R}_+^{2h\times h}` is the acquisition matrix that contains the positive and negative components of a Hadamard matrix and :math:`X \in \mathbb{R}^{h\times h}` is the (2D) image.
22 |
23 | We define the positive DMD patterns :math:`A` from the positive and negative components a Hadamard matrix :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while odd rows of :math:`A` contain the negative components of :math:`H`.
24 |
25 | .. math::
26 | \begin{cases}
27 | A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\
28 | A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H).
29 | \end{cases}
30 |
31 | """
32 |
33 | # %%
34 | # Loads images
35 | # -----------------------------------------------------------------------------
36 |
37 | ###############################################################################
38 | # We load a batch of images from the :attr:`/images/` folder with values in (-1,1).
39 | import os
40 | import torchvision
41 | import torch.nn
42 |
43 | import matplotlib.pyplot as plt
44 |
45 | from spyrit.misc.disp import imagesc
46 | from spyrit.misc.statistics import transform_gray_norm
47 |
48 | spyritPath = os.getcwd()
49 | imgs_path = os.path.join(spyritPath, "images/")
50 |
51 | # Grayscale images of size 64 x 64, values in (-1,1)
52 | transform = transform_gray_norm(img_size=64)
53 |
54 | # Create dataset and loader (expects class folder 'images/test/')
55 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
56 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
57 |
58 | x, _ = next(iter(dataloader))
59 | print(f"Ground-truth images: {x.shape}")
60 |
61 | ###############################################################################
62 | # We select the second image in the batch and plot it.
63 |
64 | i_plot = 1
65 | imagesc(x[i_plot, 0, :, :], r"$64\times 64$ image $X$")
66 |
67 | # %%
68 | # Basic example
69 | # -----------------------------------------------------------------------------
70 |
71 | ######################################################################
72 | # We instantiate an HadamSplit2d object and simulate the 2D Hadamard transform of the input images. As measurements are split, this produces vectors of size :math:`64 \times 64 \times 2 = 8192`.
73 | from spyrit.core.meas import HadamSplit2d
74 |
75 | meas_op = HadamSplit2d(64)
76 | y = meas_op(x)
77 |
78 | print(y.shape)
79 |
80 | ######################################################################
81 | # As with :class:`spyrit.core.meas.LinearSplit`, the :meth:`spyrit.core.HadamSplit2d.measure_H` method simulates measurements using the matrix :math:`H`, i.e., it computes :math:`m = \texttt{vec}\left(HXH^\top\right)`. This produces vectors of size :math:`64 \times 64 = 4096`.
82 | meas_op = HadamSplit2d(64)
83 | m = meas_op.measure_H(x)
84 |
85 | print(m.shape)
86 |
87 | ######################################################################
88 | # We plot the components of the positive and negative Hadamard transform that are concatenated in the measurement vector :math:`y` as well as the measurement vector :math:`m`.
89 |
90 | from spyrit.misc.disp import add_colorbar, noaxis
91 |
92 | y_pos = y[:, :, 0::2]
93 | y_neg = y[:, :, 1::2]
94 |
95 | f, axs = plt.subplots(1, 3, figsize=(10, 5))
96 | axs[0].set_title(r"$H_+XH_+^\top$")
97 | im = axs[0].imshow(y_pos[1, 0].reshape(64, 64), cmap="gray")
98 | add_colorbar(im, "bottom")
99 |
100 | axs[1].set_title(r"$H_-XH_-^\top$")
101 | im = axs[1].imshow(
102 | y_neg[
103 | 1,
104 | 0,
105 | ].reshape(64, 64),
106 | cmap="gray",
107 | )
108 | add_colorbar(im, "bottom")
109 |
110 | axs[2].set_title(r"$HXH^\top$")
111 | im = axs[2].imshow(m[1, 0].reshape(64, 64), cmap="gray")
112 | add_colorbar(im, "bottom")
113 |
114 | noaxis(axs)
115 | # sphinx_gallery_thumbnail_number = 2
116 |
117 |
118 | # %%
119 | # Subsampling
120 | # ----------------------------------------------------------------------
121 |
122 | ######################################################################
123 | # To reduce the acquisition time, only a few of the measurements can be acquired. In this case, we simulate:
124 | #
125 | # .. math::
126 | # y = \mathcal{S}\left(AXA^T\right),
127 | #
128 | # where :math:`\mathcal{S} \colon\, \mathbb{R}^{2h\times 2h} \to \mathbb{R}^{2M}` is a subsampling operator and :math:`2M < 2h` represents the number of DMD patterns that are displayed on the DMD.
129 |
130 | ######################################################################
131 | # The subsampling operator :math:`\mathcal{S}` is defined by an order matrix :math:`O\in\mathbb{R}^{h\times h}` that ranks the measurements by decreasing significance, before retaining only the first :math:`M`.
132 |
133 | ######################################################################
134 | # .. note::
135 | # This process applies to both :math:`H_{+}XH_{+}^T` and :math:`H_{-}XH_{-}^T` the same way, independently.
136 | #
137 | # We consider two subsampling strategies:
138 | #
139 | # * The "naive" subsampling, which uses the linear (row-major) indexing order. This is the default subsampling strategy.
140 | #
141 | # * The variance subsampling, which sorts the Hadamard coefficient by decreasing variance. The motivation is that low variance coefficients are less informative than the others. This can be supported by principal component analysis, which states that preserving the components with largest variance leads to the best linear predictor.
142 |
143 | # %%
144 | # Naive subsampling
145 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
146 |
147 | ###############################################################################
148 | # The order matrix corresponding to the "naive" subsampling is given by linear values.
149 | Ord_naive = torch.arange(64 * 64, 0, step=-1).reshape(64, 64)
150 | print(Ord_naive)
151 |
152 |
153 | # %%
154 | # Variance subsampling
155 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
156 |
157 | ######################################################################
158 | # The corresponding order matrix is obtained by computing the variance of the Hadamard coefficients of the images belonging to the `ImageNet 2012 dataset `_.
159 |
160 | ######################################################################
161 | # First, we download the *covariance* matrix from our warehouse. The covariance was computed from the ImageNet 2012 dataset and has a size of (64*64, 64*64).
162 |
163 | from spyrit.misc.load_data import download_girder
164 |
165 | # url of the warehouse
166 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
167 | dataId = "672207cbf03a54733161e95d" # for reconstruction (imageNet, 64)
168 | data_folder = "./stat/"
169 | cov_name = "Cov_64x64.pt"
170 |
171 | # download the covariance matrix and get the file path
172 | file_abs_path = download_girder(url, dataId, data_folder, cov_name)
173 |
174 | try:
175 | # Load covariance matrix for "variance subsampling"
176 | Cov = torch.load(file_abs_path, weights_only=True)
177 | print(f"Cov matrix {cov_name} loaded")
178 | except:
179 | # Set to the identity if not found for "naive subsampling"
180 | Cov = torch.eye(64 * 64)
181 | print(f"Cov matrix {cov_name} not found! Set to the identity")
182 |
183 | ######################################################################
184 | # Then, we extract the variance from the covariance matrix. The variance matrix has a size
185 | # of (64, 64).
186 | from spyrit.core.torch import Cov2Var
187 |
188 | Ord_variance = Cov2Var(Cov)
189 |
190 | ######################################################################
191 | # .. note::
192 | # In this tutorial, the covariance matrix is used to define the subsampling strategy. As explained in another tutorial, the covariance matrix can also be used to reconstruct the image from the measurements.
193 |
194 | # %%
195 | # Comparison of the two subsampling strategies
196 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
197 |
198 | ######################################################################
199 | # We plot the masks corresponding to the two order matrices for a subsampling factor of 4, which corresponds to :math:`M = 64 \times 64 / 4 = 1024` measurements.
200 |
201 | # sphinx_gallery_thumbnail_number = 2
202 |
203 | ######################################################################
204 | # We build the masks using the function :func:`spyrit.core.torch.sort_by_significance` and reshape them to the image size.
205 | from spyrit.core.torch import sort_by_significance
206 |
207 | M = 64 * 64 // 4
208 | mask_basis = torch.zeros(64 * 64)
209 | mask_basis[:M] = 1
210 |
211 | # Mask for the naive subsampling
212 | mask_nai = sort_by_significance(mask_basis, Ord_naive, axis="cols")
213 | mask_nai = mask_nai.reshape(64, 64)
214 |
215 | # Mask for the variance subsampling
216 | mask_var = sort_by_significance(mask_basis, Ord_variance, axis="cols")
217 | mask_var = mask_var.reshape(64, 64)
218 |
219 | ######################################################################
220 | # We finally plot the masks.
221 | f, ax = plt.subplots(1, 2, figsize=(10, 5))
222 | im = ax[0].imshow(mask_nai, vmin=0, vmax=1)
223 | ax[0].set_title("Mask \n'naive subsampling'", fontsize=20)
224 | add_colorbar(im, "bottom", size="20%")
225 |
226 | im = ax[1].imshow(mask_var, vmin=0, vmax=1)
227 | ax[1].set_title("Mask \n'variance subsampling'", fontsize=20)
228 | add_colorbar(im, "bottom", size="20%")
229 |
230 | noaxis(ax)
231 |
232 | # %%
233 | # Measurements for accelerated acquisitions
234 | # --------------------------------------------------------------------
235 |
236 | ######################################################################
237 | # We instantiate two HadamSplit2d objects corresponding to the two subsampling strategies. By default, the HadamSplit2d object uses the "naive" subsampling strategy.
238 | meas_nai = HadamSplit2d(64, M=M)
239 |
240 | ######################################################################
241 | # For the variance subsampling strategy, we specify the order matrix using the :attr:`order` attribute.
242 | meas_var = HadamSplit2d(64, M=M, order=Ord_variance)
243 |
244 | ###############################################################################
245 | # We now simulate the measurements from both subsampling strategies. Here, we simulate measurements using the matrix :math:`H`, i.e., we compute :math:`m = HXH^\top`. This produces vectors of size :math:`M = 64 \times 64 / 4 = 1024`.
246 |
247 | m_nai = meas_nai.measure_H(x)
248 | m_var = meas_var.measure_H(x)
249 |
250 | print(f"Shape of measurement vectors: {m_nai.shape}")
251 |
252 | ###############################################################################
253 | # We transform the two measurement vectors into images in the Hadamard domain using the function :meth:`spyrit.core.torch.meas2img`.
254 |
255 | from spyrit.core.torch import meas2img
256 |
257 | m_nai_plot = meas2img(m_nai, Ord_naive)
258 | m_var_plot = meas2img(m_var, Ord_variance)
259 |
260 | print(f"Shape of measurements: {m_nai_plot.shape}")
261 |
262 | ###############################################################################
263 | # We finally plot the measurements corresponding to one image in the batch.
264 | f, ax = plt.subplots(1, 2, figsize=(10, 5))
265 | im = ax[0].imshow(m_nai_plot[i_plot, 0, :, :], cmap="gray")
266 | ax[0].set_title("Measurements \n 'Naive' subsampling", fontsize=20)
267 | add_colorbar(im, "bottom")
268 |
269 | im = ax[1].imshow(m_var_plot[i_plot, 0, :, :], cmap="gray")
270 | ax[1].set_title("Measurements \n Variance subsampling", fontsize=20)
271 | add_colorbar(im, "bottom")
272 |
273 | noaxis(ax)
274 |
275 | ###############################################################################
276 | # We can also simulate the split measurements, i.e., the measurement obtained from the positive and negative components of the Hadamard transform. This produces vectors of size :math:`2 M = 2 \times 64 \times 64 / 4 = 2048`.
277 |
278 | y_var = meas_var(x)
279 | print(f"Shape of split measurements: {y_var.shape}")
280 |
281 |
282 | ###############################################################################
283 | # We separate the positive and negative components of the split measurements.
284 | y_var_pos = y_var[..., ::2] # Even rows
285 | y_var_neg = y_var[..., 1::2] # Odd rows
286 |
287 | print(f"Shape of the positive component: {y_var_pos.shape}")
288 | print(f"Shape of the negative component: {y_var_neg.shape}")
289 |
290 | ###############################################################################
291 | # We now send the measurement vectors to the Hadamard domain to plot them as images.
292 | m_plot_1 = meas2img(y_var_pos, Ord_variance)
293 | m_plot_2 = meas2img(y_var_neg, Ord_variance)
294 |
295 | print(f"Shape of the positive component: {m_plot_1.shape}")
296 | print(f"Shape of the negative component: {m_plot_2.shape}")
297 |
298 | ###############################################################################
299 | # We finally plot the measurements corresponding to one image in the batch
300 | f, ax = plt.subplots(1, 2, figsize=(10, 5))
301 | im = ax[0].imshow(m_plot_1[i_plot, 0, :, :], cmap="gray")
302 | ax[0].set_title(r"$\mathcal{S}\left(H_+XH_+^\top\right)$", fontsize=20)
303 | add_colorbar(im, "bottom")
304 |
305 | im = ax[1].imshow(m_plot_2[i_plot, 0, :, :], cmap="gray")
306 | ax[1].set_title(r"$\mathcal{S}\left(H_-XH_-^\top\right)$", fontsize=20)
307 | add_colorbar(im, "bottom")
308 |
309 | noaxis(ax)
310 |
--------------------------------------------------------------------------------
/tutorial/tuto_02_noise.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 02. Noise operators
3 | ===================================================
4 | .. _tuto_noise:
5 |
6 | This tutorial shows how to use noise operators using the :mod:`spyrit.core.noise` submodule.
7 |
8 | .. image:: ../fig/tuto2.png
9 | :width: 600
10 | :align: center
11 | :alt: Reconstruction architecture sketch
12 |
13 | |
14 | """
15 |
16 | # %%
17 | # Load a batch of images
18 | # -----------------------------------------------------------------------------
19 |
20 | ###############################################################################
21 | # We load a batch of images from the `/images/` folder. Using the
22 | # :func:`transform_gray_norm` function with the :attr:`normalize=False`
23 | # argument returns images with values in (0,1).
24 | import os
25 |
26 | import torch
27 | import torchvision
28 | import matplotlib.pyplot as plt
29 |
30 | from spyrit.misc.disp import imagesc
31 | from spyrit.misc.statistics import transform_gray_norm
32 |
33 | spyritPath = os.getcwd()
34 | imgs_path = os.path.join(spyritPath, "images/")
35 |
36 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1)
37 | transform = transform_gray_norm(img_size=64, normalize=False)
38 |
39 | # Create dataset and loader (expects class folder 'images/test/')
40 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
42 |
43 | x, _ = next(iter(dataloader))
44 | print(f"Shape of input images: {x.shape}")
45 |
46 |
47 | ###############################################################################
48 | # We select the second image in the batch and plot it.
49 |
50 | i_plot = 1
51 | imagesc(x[i_plot, 0, :, :], r"$x$ in (0, 1)")
52 |
53 |
54 | # %%
55 | # Gaussian noise
56 | # -----------------------------------------------------------------------------
57 |
58 | ###############################################################################
59 | # We consider additive Gaussian noise,
60 | #
61 | # .. math::
62 | # y \sim z + \mathcal{N}(0,\sigma^2),
63 | #
64 | # where :math:`\mathcal{N}(\mu, \sigma^2)` is a Gaussian distribution with mean :math:`\mu` and variance :math:`\sigma^2`, and :math:`z` is the noiseless image. The larger :math:`\sigma`, the lower the signal-to-noise ratio.
65 |
66 | ###############################################################################
67 | # To add 10% Gaussian noise, we instantiate a :class:`spyrit.core.noise`
68 | # operator with :attr:`sigma=0.1`.
69 |
70 | from spyrit.core.noise import Gaussian
71 |
72 | noise_op = Gaussian(sigma=0.1)
73 | x_noisy = noise_op(x)
74 |
75 | imagesc(x_noisy[1, 0, :, :], r"10% Gaussian noise")
76 | # sphinx_gallery_thumbnail_number = 2
77 |
78 | ###############################################################################
79 | # To add 2% Gaussian noise, we update the class attribute :attr:`sigma`.
80 |
81 | noise_op.sigma = 0.02
82 | x_noisy = noise_op(x)
83 |
84 | imagesc(x_noisy[1, 0, :, :], r"2% Gaussian noise")
85 |
86 | # %%
87 | # Poisson noise
88 | # -----------------------------------------------------------------------------
89 |
90 | ###############################################################################
91 | # We now consider Poisson noise,
92 | #
93 | # .. math::
94 | # y \sim \mathcal{P}(\alpha z), \quad z \ge 0,
95 | #
96 | # where :math:`\alpha \ge 0` is a scalar value that represents the maximum
97 | # image intensity (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio.
98 |
99 | ###############################################################################
100 | # We consider the :class:`spyrit.core.noise.Poisson` class and set :math:`\alpha`
101 | # to 100 photons (which corresponds to the Poisson parameter).
102 |
103 | from spyrit.core.noise import Poisson
104 | from spyrit.misc.disp import add_colorbar, noaxis
105 |
106 | alpha = 100 # number of photons
107 | noise_op = Poisson(alpha)
108 |
109 | ###############################################################################
110 | # We simulate two noisy versions of the same images
111 |
112 | y1 = noise_op(x) # first sample
113 | y2 = noise_op(x) # another sample
114 |
115 | ###############################################################################
116 | # We now consider the case :math:`\alpha = 1000` photons.
117 |
118 | noise_op.alpha = 1000
119 | y3 = noise_op(x) # noisy measurement vector
120 |
121 | ###############################################################################
122 | # We finally plot the noisy images
123 |
124 | # plot
125 | f, axs = plt.subplots(1, 3, figsize=(10, 5))
126 | axs[0].set_title("100 photons")
127 | im = axs[0].imshow(y1[1, 0].reshape(64, 64), cmap="gray")
128 | add_colorbar(im, "bottom")
129 |
130 | axs[1].set_title("100 photons")
131 | im = axs[1].imshow(y2[1, 0].reshape(64, 64), cmap="gray")
132 | add_colorbar(im, "bottom")
133 |
134 | axs[2].set_title("1000 photons")
135 | im = axs[2].imshow(
136 | y3[
137 | 1,
138 | 0,
139 | ].reshape(64, 64),
140 | cmap="gray",
141 | )
142 | add_colorbar(im, "bottom")
143 |
144 | noaxis(axs)
145 |
146 | ###############################################################################
147 | # As expected the signal-to-noise ratio of the measurement vector is higher for
148 | # 1,000 photons than for 100 photons
149 | #
150 | # .. note::
151 | # Not only the signal-to-noise, but also the scale of the measurements
152 | # depends on :math:`\alpha`, which motivates the introduction of the
153 | # preprocessing operator.
154 |
--------------------------------------------------------------------------------
/tutorial/tuto_03_pseudoinverse_linear.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 03. Pseudoinverse solution from linear measurements
3 | ===================================================
4 | .. _tuto_pseudoinverse_linear:
5 |
6 | This tutorial shows how to simulate measurements and perform image reconstruction using the :class:`spyrit.core.inverse.PseudoInverse` class of the :mod:`spyrit.core.inverse` submodule.
7 |
8 | .. image:: ../fig/tuto3_pinv.png
9 | :width: 600
10 | :align: center
11 | :alt: Reconstruction architecture sketch
12 |
13 | |
14 | """
15 |
16 | # %%
17 | # Loads images
18 | # -----------------------------------------------------------------------------
19 |
20 | ###############################################################################
21 | # We load a batch of images from the :attr:`/images/` folder. Using the
22 | # :func:`spyrit.misc.statistics.transform_gray_norm` function with the :attr:`normalize=False`
23 | # argument returns images with values in (0,1).
24 | import os
25 | import torchvision
26 | import torch.nn
27 | from spyrit.misc.statistics import transform_gray_norm
28 |
29 | spyritPath = os.getcwd()
30 | imgs_path = os.path.join(spyritPath, "images/")
31 |
 32 | # Grayscale images of size 64 x 64, no normalization to keep values in (0,1)
33 | transform = transform_gray_norm(img_size=64, normalize=False)
34 |
35 | # Create dataset and loader (expects class folder 'images/test/')
36 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
38 |
39 | x, _ = next(iter(dataloader))
40 | print(f"Ground-truth images: {x.shape}")
41 |
42 |
43 | # %%
44 | # Linear measurements without noise
45 | # -----------------------------------------------------------------------------
46 |
47 | ###############################################################################
 48 | # We consider a Hadamard matrix in "2D". The matrix has a shape of (64*64, 64*64) and values in {-1, 1}.
49 | from spyrit.core.torch import walsh_matrix_2d
50 |
51 | H = walsh_matrix_2d(64)
52 |
53 | print(f"Acquisition matrix: {H.shape}", end=" ")
54 | print(rf"with values in {{{H.min()}, {H.max()}}}")
55 |
56 | ###############################################################################
57 | # We instantiate a :class:`spyrit.core.meas.Linear` operator. To indicate that the operator works in 2D, on images with shape (64, 64), we use the :attr:`meas_shape` argument.
58 | from spyrit.core.meas import Linear
59 |
60 | meas_op = Linear(H, (64, 64))
61 |
62 | ###############################################################################
63 | # We simulate the measurement vectors, which have a shape of (7, 1, 4096).
64 | y = meas_op(x)
65 |
66 | print(f"Measurement vectors: {y.shape}")
67 |
68 | ###############################################################################
69 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64).
70 | from spyrit.core.inverse import PseudoInverse
71 |
72 | pinv = PseudoInverse(meas_op)
73 | x_rec = pinv(y)
74 |
75 | print(f"Reconstructed images: {x_rec.shape}")
76 |
77 | ###############################################################################
78 | # We plot the reconstruction of the second image in the batch
79 | from spyrit.misc.disp import imagesc, add_colorbar
80 |
81 | imagesc(x_rec[1, 0])
82 | # sphinx_gallery_thumbnail_number = 1
83 |
84 | ###############################################################################
85 | # .. note::
 86 | #     The measurement operator is chosen as a Hadamard matrix with values in {-1, 1}, but this matrix can be replaced by any other matrix.
87 |
88 | # %%
89 | # LinearSplit measurements with Gaussian noise
90 | # -----------------------------------------------------------------------------
91 |
92 | ###############################################################################
93 | # We consider a linear operator where the positive and negative components are split, i.e. acquired separately. To do so, we instantiate a :class:`spyrit.core.meas.LinearSplit` operator.
94 | from spyrit.core.meas import LinearSplit
95 |
96 | meas_op = LinearSplit(H, (64, 64))
97 |
98 | ###############################################################################
99 | # We consider additive Gaussian noise with standard deviation 2.
100 | from spyrit.core.noise import Gaussian
101 |
102 | meas_op.noise_model = Gaussian(2)
103 |
104 | ###############################################################################
105 | # We simulate the measurement vectors, which have shape (7, 1, 8192).
106 | y = meas_op(x)
107 |
108 | print(f"Measurement vectors: {y.shape}")
109 |
110 | ###############################################################################
111 | # We preprocess measurement vectors by computing the difference of the positive and negative components of the measurement vectors. To do so, we use the :class:`spyrit.core.prep.Unsplit` class. The preprocess measurements have a shape of (7, 1, 4096).
112 |
113 | from spyrit.core.prep import Unsplit
114 |
115 | prep = Unsplit()
116 | m = prep(y)
117 |
118 | print(f"Preprocessed measurement vectors: {m.shape}")
119 |
120 | ###############################################################################
121 | # We now compute the pseudo inverse solutions, which have a shape of (7, 1, 64, 64).
122 | from spyrit.core.inverse import PseudoInverse
123 |
124 | pinv = PseudoInverse(meas_op)
125 | x_rec = pinv(m)
126 |
127 | print(f"Reconstructed images: {x_rec.shape}")
128 |
129 | ###############################################################################
130 | # We plot the reconstruction
131 | from spyrit.misc.disp import imagesc, add_colorbar
132 |
133 | imagesc(x_rec[1, 0])
134 |
135 | # %%
136 | # HadamSplit2d with x4 subsampling with Poisson noise
137 | # -----------------------------------------------------------------------------
138 |
139 | ###############################################################################
 140 | # We consider the acquisition of the 2D Hadamard transform of an image, where the positive and negative components of acquisition matrix are acquired separately. To do so, we use the dedicated :class:`spyrit.core.meas.HadamSplit2d` operator. It also allows for subsampling the rows of the Hadamard matrix, using a sampling map.
141 |
142 | from spyrit.core.meas import HadamSplit2d
143 |
144 | # Sampling map with ones in the top left corner and zeros elsewhere (low-frequency subsampling)
145 | sampling_map = torch.ones((64, 64))
146 | sampling_map[:, 64 // 2 :] = 0
147 | sampling_map[64 // 2 :, :] = 0
148 |
149 | # Linear operator with HadamSplit2d
150 | meas_op = HadamSplit2d(64, 64**2 // 4, order=sampling_map, reshape_output=True)
151 |
152 | ###############################################################################
153 | # We consider additive Poisson noise with an intensity of 100 photons.
154 | from spyrit.core.noise import Poisson
155 |
156 | meas_op.noise_model = Poisson(100)
157 |
158 |
159 | ###############################################################################
160 | # We simulate the measurement vectors, which have a shape of (7, 1, 2048)
161 |
162 | ###############################################################################
163 | # .. note::
164 | # The :class:`spyrit.core.noise.Poisson` class noise assumes that the images are in the range [0, 1]
165 | y = meas_op(x)
166 |
167 | print(rf"Reference images with values in {{{x.min()}, {x.max()}}}")
168 | print(f"Measurement vectors: {y.shape}")
169 |
170 | ###############################################################################
171 | # We preprocess measurement vectors by i) computing the difference of the positive and negative components, and ii) normalizing the intensity. To do so, we use the :class:`spyrit.core.prep.UnsplitRescale` class. The preprocessed measurements have a shape of (7, 1, 1024).
172 |
173 | from spyrit.core.prep import UnsplitRescale
174 |
175 | prep = UnsplitRescale(100)
176 |
177 | m = prep(y) # (y+ - y-)/alpha
178 | print(f"Preprocessed measurement vectors: {m.shape}")
179 |
180 | ###############################################################################
181 | # We compute the pseudo inverse solution, which has a shape of (7, 1, 64, 64).
182 |
183 | x_rec = meas_op.fast_pinv(m)
184 |
185 | print(f"Reconstructed images: {x_rec.shape}")
186 |
187 | ###############################################################################
188 | # .. note::
189 | # There is no need to use the :class:`spyrit.core.inverse.PseudoInverse` class here, as the :class:`spyrit.core.meas.HadamSplit2d` class includes a method that returns the pseudo inverse solution.
190 |
191 | ###############################################################################
192 | # We plot the reconstruction
193 | from spyrit.misc.disp import imagesc, add_colorbar
194 |
195 | imagesc(x_rec[1, 0])
196 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_03_pseudoinverse_cnn_linear.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 03. Pseudoinverse solution + CNN denoising
3 | ==========================================
4 | .. _tuto_pseudoinverse_cnn_linear:
5 |
6 | This tutorial shows how to simulate measurements and perform image reconstruction
7 | using PinvNet (pseudoinverse linear network) with CNN denoising as a last layer.
 8 | This tutorial is a continuation of the :ref:`Pseudoinverse solution tutorial <tuto_pseudoinverse_linear>`
9 | but uses a CNN denoiser instead of the identity operator in order to remove artefacts.
10 |
11 | The measurement operator is chosen as a Hadamard matrix with positive coefficients,
12 | which can be replaced by any matrix.
13 |
14 | .. image:: ../fig/tuto3.png
15 | :width: 600
16 | :align: center
17 | :alt: Reconstruction and neural network denoising architecture sketch
18 |
19 | These tutorials load image samples from `/images/`.
20 | """
21 |
22 | # %%
23 | # Load a batch of images
24 | # -----------------------------------------------------------------------------
25 |
26 | ###############################################################################
27 | # Images :math:`x` for training expect values in [-1,1]. The images are normalized
28 | # using the :func:`transform_gray_norm` function.
29 | if False:
30 |
31 | import os
32 |
33 | import torch
34 | import torchvision
35 | import matplotlib.pyplot as plt
36 |
37 | import spyrit.core.torch as spytorch
38 | from spyrit.misc.disp import imagesc
39 | from spyrit.misc.statistics import transform_gray_norm
40 |
41 | # sphinx_gallery_thumbnail_path = 'fig/tuto3.png'
42 |
43 | h = 64 # image size hxh
44 | i = 1 # Image index (modify to change the image)
45 | spyritPath = os.getcwd()
46 | imgs_path = os.path.join(spyritPath, "images/")
47 |
48 | # Create a transform for natural images to normalized grayscale image tensors
49 | transform = transform_gray_norm(img_size=h)
50 |
51 | # Create dataset and loader (expects class folder 'images/test/')
52 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
53 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
54 |
55 | x, _ = next(iter(dataloader))
56 | print(f"Shape of input images: {x.shape}")
57 |
58 | # Select image
59 | x = x[i : i + 1, :, :, :]
60 | x = x.detach().clone()
61 | print(f"Shape of selected image: {x.shape}")
62 | b, c, h, w = x.shape
63 |
64 | # plot
65 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
66 |
67 | # %%
68 | # Define a measurement operator
69 | # -----------------------------------------------------------------------------
70 |
71 | ###############################################################################
72 | # We consider the case where the measurement matrix is the positive
73 | # component of a Hadamard matrix and the sampling operator preserves only
74 | # the first :attr:`M` low-frequency coefficients
 75 | # (see :ref:`Positive Hadamard matrix ` for full explanation).
76 |
77 | import math
78 |
79 | F = spytorch.walsh_matrix_2d(h)
80 | F = torch.max(F, torch.zeros_like(F))
81 |
82 | und = 4 # undersampling factor
83 | M = h**2 // und # number of measurements (undersampling factor = 4)
84 |
85 | Sampling_map = torch.zeros(h, h)
86 | M_xy = math.ceil(M**0.5)
87 | Sampling_map[:M_xy, :M_xy] = 1
88 |
89 | imagesc(Sampling_map, "low-frequency sampling map")
90 |
91 | ###############################################################################
92 | # After permutation of the full Hadamard matrix, we keep only its first
93 | # :attr:`M` rows
94 |
95 | F = spytorch.sort_by_significance(F, Sampling_map, "rows", False)
96 | H = F[:M, :]
97 |
98 | print(f"Shape of the measurement matrix: {H.shape}")
99 |
100 | ###############################################################################
101 | # Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator
102 |
103 | from spyrit.core.meas import Linear
104 |
105 | meas_op = Linear(H, pinv=True)
106 |
107 | # %%
108 | # Noiseless case
109 | # -----------------------------------------------------------------------------
110 |
111 | ###############################################################################
112 | # In the noiseless case, we consider the :class:`spyrit.core.noise.NoNoise` noise
113 | # operator
114 |
115 | from spyrit.core.noise import NoNoise
116 |
117 | N0 = 1.0 # Noise level (noiseless)
118 | noise = NoNoise(meas_op)
119 |
120 | # Simulate measurements
121 | y = noise(x)
122 | print(f"Shape of raw measurements: {y.shape}")
123 |
124 | ###############################################################################
125 | # We now compute and plot the preprocessed measurements corresponding to an
126 | # image in [-1,1]
127 |
128 | from spyrit.core.prep import DirectPoisson
129 |
130 | prep = DirectPoisson(N0, meas_op) # "Undo" the NoNoise operator
131 |
132 | m = prep(y)
133 | print(f"Shape of the preprocessed measurements: {m.shape}")
134 |
135 | ###############################################################################
136 | # To display the subsampled measurement vector as an image in the transformed
137 | # domain, we use the :func:`spyrit.core.torch.meas2img` function
138 |
139 | # plot
140 | m_plot = spytorch.meas2img(m, Sampling_map)
141 | print(f"Shape of the preprocessed measurement image: {m_plot.shape}")
142 |
143 | imagesc(m_plot[0, 0, :, :], "Preprocessed measurements (no noise)")
144 |
145 | # %%
146 | # PinvNet Network
147 | # -----------------------------------------------------------------------------
148 |
149 | ###############################################################################
150 | # We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an
151 | # image by computing the pseudoinverse solution, which is fed to a neural
152 | # network denoiser. To compute the pseudoinverse solution only, the denoiser
153 | # can be set to the identity operator
154 |
155 | from spyrit.core.recon import PinvNet
156 |
157 | pinv_net = PinvNet(noise, prep, denoi=torch.nn.Identity())
158 |
159 | ###############################################################################
160 | # or equivalently
161 | pinv_net = PinvNet(noise, prep)
162 |
163 | ###############################################################################
164 | # Then, we reconstruct the image from the measurement vector :attr:`y` using the
165 | # :func:`~spyrit.core.recon.PinvNet.reconstruct` method
166 |
167 | x_rec = pinv_net.reconstruct(y)
168 |
169 | # %%
170 | # Removing artefacts with a CNN
171 | # -----------------------------------------------------------------------------
172 |
173 | ###############################################################################
174 | # Artefacts can be removed by selecting a neural network denoiser
175 | # (last layer of PinvNet). We select a simple CNN using the
176 | # :class:`spyrit.core.nnet.ConvNet` class, but this can be replaced by any
177 | # neural network (eg. UNet from :class:`spyrit.core.nnet.Unet`).
178 |
179 | ###############################################################################
180 | # .. image:: ../fig/pinvnet_cnn.png
181 | # :width: 400
182 | # :align: center
183 | # :alt: Sketch of the PinvNet with CNN architecture
184 |
185 | from spyrit.core.nnet import ConvNet
186 | from spyrit.core.train import load_net
187 |
188 | # Define PInvNet with ConvNet denoising layer
189 | denoi = ConvNet()
190 | pinv_net_cnn = PinvNet(noise, prep, denoi)
191 |
192 | # Send to GPU if available
193 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
194 | print("Using device:", device)
195 | pinv_net_cnn = pinv_net_cnn.to(device)
196 |
197 | ###############################################################################
198 | # As an example, we use a simple ConvNet that has been pretrained using STL-10 dataset.
199 | # We download the pretrained weights and load them into the network.
200 |
201 | from spyrit.misc.load_data import download_girder
202 |
203 | # Load pretrained model
204 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
205 | dataID = "67221889f03a54733161e963" # unique ID of the file
206 | local_folder = "./model/"
207 | data_name = "tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth"
208 | # download the model and save it in the local folder
209 | model_cnn_path = download_girder(url, dataID, local_folder, data_name)
210 |
211 | # Load model weights
212 | load_net(model_cnn_path, pinv_net_cnn, device, False)
213 |
214 | ###############################################################################
215 | # We now reconstruct the image using PinvNet with pretrained CNN denoising
216 | # and plot results side by side with the PinvNet without denoising
217 |
218 | from spyrit.misc.disp import add_colorbar, noaxis
219 |
220 | with torch.no_grad():
221 | x_rec_cnn = pinv_net_cnn.reconstruct(y.to(device))
222 | x_rec_cnn = pinv_net_cnn(x.to(device))
223 |
224 | # plot
225 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
226 |
227 | im1 = ax1.imshow(x[0, 0, :, :], cmap="gray")
228 | ax1.set_title("Ground-truth image", fontsize=20)
229 | noaxis(ax1)
230 | add_colorbar(im1, "bottom", size="20%")
231 |
232 | im2 = ax2.imshow(x_rec[0, 0, :, :], cmap="gray")
233 | ax2.set_title("Pinv reconstruction", fontsize=20)
234 | noaxis(ax2)
235 | add_colorbar(im2, "bottom", size="20%")
236 |
237 | im3 = ax3.imshow(x_rec_cnn.cpu()[0, 0, :, :], cmap="gray")
238 | ax3.set_title(f"Pinv + CNN (trained 30 epochs", fontsize=20)
239 | noaxis(ax3)
240 | add_colorbar(im3, "bottom", size="20%")
241 |
242 | plt.show()
243 |
244 | ###############################################################################
245 | # We show the best result again (tutorial thumbnail purpose)
246 |
247 | # Plot
248 | imagesc(
249 | x_rec_cnn.cpu()[0, 0, :, :], f"Pinv + CNN (trained 30 epochs", title_fontsize=20
250 | )
251 |
252 | ###############################################################################
253 | # In the next tutorial, we will show how to train PinvNet + CNN denoiser.
254 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_04_train_pseudoinverse_cnn_linear.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 04. Train pseudoinverse solution + CNN denoising
3 | ================================================
4 | .. _tuto_train_pseudoinverse_cnn_linear:
5 |
6 | This tutorial shows how to train PinvNet with a CNN denoiser for
7 | reconstruction of linear measurements (results shown in the
 8 | :ref:`previous tutorial <tuto_pseudoinverse_cnn_linear>`).
9 | As an example, we use a small CNN, which can be replaced by any other network,
10 | for example Unet. Training is performed on the STL-10 dataset.
11 |
12 | You can use Tensorboard for Pytorch for experiment tracking and
13 | for visualizing the training process: losses, network weights,
14 | and intermediate results (reconstructed images at different epochs).
15 |
16 | The linear measurement operator is chosen as the positive part of a Hadamard matrix,
17 | but this matrix can be replaced by any desired matrix.
18 |
19 | These tutorials load image samples from `/images/`.
20 | """
21 |
22 | # %%
23 | # Load a batch of images
24 | # -----------------------------------------------------------------------------
25 |
26 | ###############################################################################
 27 | # First, we load an image :math:`x` and normalize it to [-1,1], as in previous examples.
28 | if False:
29 |
30 | import os
31 |
32 | import torch
33 | import torchvision
34 | import matplotlib.pyplot as plt
35 |
36 | import spyrit.core.torch as spytorch
37 | from spyrit.misc.disp import imagesc
38 | from spyrit.misc.statistics import transform_gray_norm
39 |
40 | h = 64 # image size hxh
41 | i = 1 # Image index (modify to change the image)
42 | spyritPath = os.getcwd()
43 | imgs_path = os.path.join(spyritPath, "images/")
44 |
45 | # Create a transform for natural images to normalized grayscale image tensors
46 | transform = transform_gray_norm(img_size=h)
47 |
48 | # Create dataset and loader (expects class folder 'images/test/')
49 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
50 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
51 |
52 | x, _ = next(iter(dataloader))
53 | print(f"Shape of input images: {x.shape}")
54 |
55 | # Select image
56 | x = x[i : i + 1, :, :, :]
57 | x = x.detach().clone()
58 | print(f"Shape of selected image: {x.shape}")
59 | b, c, h, w = x.shape
60 |
61 | # plot
62 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
63 |
64 | # %%
65 | # Define a dataloader
66 | # -----------------------------------------------------------------------------
67 | # We define a dataloader for STL-10 dataset using :func:`spyrit.misc.statistics.data_loaders_stl10`.
68 | # This will download the dataset to the provided path if it is not already downloaded.
69 | # It is based on pytorch pre-loaded dataset :class:`torchvision.datasets.STL10` and
70 | # :class:`torch.utils.data.DataLoader`, which creates a generator that iterates
71 | # through the dataset, returning a batch of images and labels at each iteration.
72 | #
73 | # Set :attr:`mode_run` to True in the script below to download the dataset and for training;
 74 | # otherwise, pretrained weights and results will be downloaded for display.
75 |
76 | from spyrit.misc.statistics import data_loaders_stl10
77 | from pathlib import Path
78 |
79 | # Parameters
80 | h = 64 # image size hxh
81 | data_root = Path("./data") # path to data folder (where the dataset is stored)
82 | batch_size = 512
83 |
84 | # Dataloader for STL-10 dataset
85 | mode_run = False
86 | if mode_run:
87 | dataloaders = data_loaders_stl10(
88 | data_root,
89 | img_size=h,
90 | batch_size=batch_size,
91 | seed=7,
92 | shuffle=True,
93 | download=True,
94 | )
95 |
96 | # %%
97 | # Define a measurement operator
98 | # -----------------------------------------------------------------------------
99 |
100 | ###############################################################################
101 | # We consider the case where the measurement matrix is the positive
102 | # component of a Hadamard matrix, which is often used in single-pixel imaging
103 | # (see :ref:`Hadamard matrix `).
104 | # Then, we simulate an accelerated acquisition by keeping only the first
105 | # :attr:`M` low-frequency coefficients (see :ref:`low frequency sampling `).
106 |
107 | import math
108 |
109 | und = 4 # undersampling factor
110 | M = h**2 // und # number of measurements (undersampling factor = 4)
111 |
112 | F = spytorch.walsh_matrix_2d(h)
113 | F = torch.max(F, torch.zeros_like(F))
114 |
115 | Sampling_map = torch.zeros(h, h)
116 | M_xy = math.ceil(M**0.5)
117 | Sampling_map[:M_xy, :M_xy] = 1
118 |
119 | # imagesc(Sampling_map, 'low-frequency sampling map')
120 |
121 | F = spytorch.sort_by_significance(F, Sampling_map, "rows", False)
122 | H = F[:M, :]
123 |
124 | print(f"Shape of the measurement matrix: {H.shape}")
125 |
126 | ###############################################################################
127 | # Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator,
128 | # a :class:`spyrit.core.noise.NoNoise` noise operator for noiseless case,
129 | # and a preprocessing measurements operator :class:`spyrit.core.prep.DirectPoisson`.
130 |
131 | from spyrit.core.meas import Linear
132 | from spyrit.core.noise import NoNoise
133 | from spyrit.core.prep import DirectPoisson
134 |
135 | meas_op = Linear(H, pinv=True)
136 | noise = NoNoise(meas_op)
137 | N0 = 1.0 # Mean maximum total number of photons
138 | prep = DirectPoisson(N0, meas_op) # "Undo" the NoNoise operator
139 |
140 | # %%
141 | # PinvNet Network
142 | # -----------------------------------------------------------------------------
143 |
144 | ###############################################################################
145 | # We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an
146 | # image by computing the pseudoinverse solution and applies a nonlinear
147 | # network denoiser. First, we must define the denoiser. As an example,
148 | # we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class.
149 | # Then, we define the PinvNet network by passing the noise and preprocessing operators
150 | # and the denoiser.
151 |
152 | from spyrit.core.nnet import ConvNet
153 | from spyrit.core.recon import PinvNet
154 |
155 | denoiser = ConvNet()
156 | model = PinvNet(noise, prep, denoi=denoiser)
157 |
158 | # Send to GPU if available
159 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
160 |
161 | # Use multiple GPUs if available
162 | if torch.cuda.device_count() > 1:
163 | print("Let's use", torch.cuda.device_count(), "GPUs!")
164 | model = nn.DataParallel(model)
165 |
166 | print("Using device:", device)
167 | model = model.to(device)
168 |
169 | ###############################################################################
170 | # .. note::
171 | #
172 | # In the example provided, we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class.
173 | # This can be replaced by any denoiser, for example the :class:`spyrit.core.nnet.Unet` class
174 | # or a custom denoiser.
175 |
176 | # %%
177 | # Define a Loss function optimizer and scheduler
178 | # -----------------------------------------------------------------------------
179 |
180 | ###############################################################################
181 | # In order to train the network, we need to define a loss function, an optimizer
 182 | # and a scheduler. We use the Mean Square Error (MSE) loss function, weight decay
183 | # loss and the Adam optimizer. The scheduler decreases the learning rate
184 | # by a factor of :attr:`gamma` every :attr:`step_size` epochs.
185 |
186 | import torch.nn as nn
187 | import torch.optim as optim
188 | from torch.optim import lr_scheduler
189 | from spyrit.core.train import save_net, Weight_Decay_Loss
190 |
191 | # Parameters
192 | lr = 1e-3
193 | step_size = 10
194 | gamma = 0.5
195 |
196 | loss = nn.MSELoss()
197 | criterion = Weight_Decay_Loss(loss)
198 | optimizer = optim.Adam(model.parameters(), lr=lr)
199 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
200 |
201 | # %%
202 | # Train the network
203 | # -----------------------------------------------------------------------------
204 |
205 | ###############################################################################
206 | # To train the network, we use the :func:`~spyrit.core.train.train_model` function,
207 | # which handles the training process. It iterates through the dataloader, feeds the inputs to the
208 | # network and optimizes the solution (by computing the loss and its gradients and
209 | # updating the network weights at each iteration). In addition, it computes
210 | # the loss and desired metrics on the training and validation sets at each iteration.
211 | # The training process can be monitored using Tensorboard.
212 |
213 | ###############################################################################
214 | # .. note::
215 | #
216 | # To launch Tensorboard type in a new console:
217 | #
218 | # tensorboard --logdir runs
219 | #
220 | # and open the provided link in a browser. The training process can be monitored
221 | # in real time in the "Scalars" tab. The "Images" tab allows to visualize the
222 | # reconstructed images at different iterations :attr:`tb_freq`.
223 |
224 | ###############################################################################
225 | # In order to train, you must set :attr:`mode_run` to True for training. It is set to False
226 | # by default to download the pretrained weights and results for display,
227 | # as training takes around 40 min for 30 epochs.
228 |
229 | # We train for one epoch only to check that everything works fine.
230 |
231 | from spyrit.core.train import train_model
232 | from datetime import datetime
233 |
234 | # Parameters
235 | model_root = Path("./model") # path to model saving files
236 | num_epochs = 5 # number of training epochs (num_epochs = 30)
237 | checkpoint_interval = 2 # interval between saving model checkpoints
238 | tb_freq = 50 # interval between logging to Tensorboard (iterations through the dataloader)
239 |
240 | # Path for Tensorboard experiment tracking logs
241 | name_run = "stdl10_hadampos"
242 | now = datetime.now().strftime("%Y-%m-%d_%H-%M")
243 | tb_path = f"runs/runs_{name_run}_n{int(N0)}_m{M}/{now}"
244 |
245 | # Train the network
246 | if mode_run:
247 | model, train_info = train_model(
248 | model,
249 | criterion,
250 | optimizer,
251 | scheduler,
252 | dataloaders,
253 | device,
254 | model_root,
255 | num_epochs=num_epochs,
256 | disp=True,
257 | do_checkpoint=checkpoint_interval,
258 | tb_path=tb_path,
259 | tb_freq=tb_freq,
260 | )
261 | else:
262 | train_info = {}
263 |
264 | # %%
265 | # Save the network and training history
266 | # -----------------------------------------------------------------------------
267 |
268 | ###############################################################################
269 | # We save the model so that it can later be utilized. We save the network's
270 | # architecture, the training parameters and the training history.
271 |
272 | from spyrit.core.train import save_net
273 |
274 | # Training parameters
275 | train_type = "N0_{:g}".format(N0)
276 | arch = "pinv-net"
277 | denoi = "cnn"
278 | data = "stl10"
279 | reg = 1e-7 # Default value
280 | suffix = "N_{}_M_{}_epo_{}_lr_{}_sss_{}_sdr_{}_bs_{}".format(
281 | h, M, num_epochs, lr, step_size, gamma, batch_size
282 | )
283 | title = model_root / f"{arch}_{denoi}_{data}_{train_type}_{suffix}"
284 | print(title)
285 |
286 | Path(model_root).mkdir(parents=True, exist_ok=True)
287 |
288 | if checkpoint_interval:
289 | Path(title).mkdir(parents=True, exist_ok=True)
290 |
291 | save_net(str(title) + ".pth", model)
292 |
293 | # Save training history
294 | import pickle
295 |
296 | if mode_run:
297 | from spyrit.core.train import Train_par
298 |
299 | params = Train_par(batch_size, lr, h, reg=reg)
300 | params.set_loss(train_info)
301 |
302 | train_path = (
303 | model_root / f"TRAIN_{arch}_{denoi}_{data}_{train_type}_{suffix}.pkl"
304 | )
305 |
306 | with open(train_path, "wb") as param_file:
307 | pickle.dump(params, param_file)
308 | torch.cuda.empty_cache()
309 |
310 | else:
311 | from spyrit.misc.load_data import download_girder
312 |
313 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
314 | dataID = "667ebfe4baa5a90007058964" # unique ID of the file
315 | data_name = "tuto4_TRAIN_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pkl"
316 | train_path = os.path.join(model_root, data_name)
317 | # download girder file
318 | download_girder(url, dataID, model_root, data_name)
319 |
320 | with open(train_path, "rb") as param_file:
321 | params = pickle.load(param_file)
322 | train_info["train"] = params.train_loss
323 | train_info["val"] = params.val_loss
324 |
325 | ###############################################################################
326 | # We plot the training loss and validation loss
327 |
328 | # Plot
329 | # sphinx_gallery_thumbnail_number = 2
330 |
331 | fig = plt.figure()
332 | plt.plot(train_info["train"], label="train")
333 | plt.plot(train_info["val"], label="val")
334 | plt.xlabel("Epochs", fontsize=20)
335 | plt.ylabel("Loss", fontsize=20)
336 | plt.legend(fontsize=20)
337 | plt.show()
338 |
339 | ###############################################################################
340 | # .. note::
341 | #
342 | #     See the Google Colab notebook ``spyrit-examples/tutorial/tuto_train_lin_meas_colab.ipynb``
343 | # for training a reconstruction network on GPU. It shows how to train
344 | # using different architectures, denoisers and other hyperparameters from
345 | # :func:`~spyrit.core.train.train_model` function.
346 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_05_recon_hadamSplit.py:
--------------------------------------------------------------------------------
1 | # %%
2 |
3 | if False:
4 |
5 | # %%
6 | # Split measurement operator and no noise
7 | # -----------------------------------------------------------------------------
8 | # .. _split_measurements:
9 |
10 | ###############################################################################
11 | # .. math::
12 | # y = P\tilde{x}= \begin{bmatrix} H_{+} \\ H_{-} \end{bmatrix} \tilde{x}.
13 |
14 | ###############################################################################
15 | # Hadamard split measurement operator is defined in the :class:`spyrit.core.meas.HadamSplit` class.
16 | # It computes linear measurements from incoming images, where :math:`P` is a
17 | # linear operator (matrix) with positive entries and :math:`\tilde{x}` is an image.
18 | # The class relies on a matrix :math:`H` with
19 | # shape :math:`(M,N)` where :math:`N` represents the number of pixels in the
20 | # image and :math:`M \le N` the number of measurements. The matrix :math:`P`
21 | # is obtained by splitting the matrix :math:`H` as :math:`H = H_{+}-H_{-}` where
22 | # :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
23 |
24 | # %%
25 | # Measurement and noise operators
26 | # -----------------------------------------------------------------------------
27 |
28 | ###############################################################################
29 | # We compute the measurement and noise operators and then
30 | # simulate the measurement vector :math:`y`.
31 |
32 | ###############################################################################
33 | # We consider Poisson noise, i.e., a noisy measurement vector given by
34 | #
35 | # .. math::
36 | # y \sim \mathcal{P}(\alpha P \tilde{x}),
37 | #
38 | # where :math:`\alpha` is a scalar value that represents the maximum image intensity
39 | # (in photons). The larger :math:`\alpha`, the higher the signal-to-noise ratio.
40 |
41 | ###############################################################################
42 | # We use the :class:`spyrit.core.noise.Poisson` class, set :math:`\alpha`
43 | # to 100 photons, and simulate a noisy measurement vector for the two sampling
44 | # strategies. Subsampling is handled internally by the :class:`~spyrit.core.meas.HadamSplit` class.
45 |
46 | from spyrit.core.noise import Poisson
47 | from spyrit.core.meas import HadamSplit
48 |
49 | alpha = 100.0 # number of photons
50 |
51 | # "Naive subsampling"
52 | # Measurement and noise operators
53 | meas_nai_op = HadamSplit(M, h, Ord_naive)
54 | noise_nai_op = Poisson(meas_nai_op, alpha)
55 |
56 | # Measurement operator
57 | y_nai = noise_nai_op(x) # a noisy measurement vector
58 |
59 | # "Variance subsampling"
60 | meas_var_op = HadamSplit(M, h, Ord_variance)
61 | noise_var_op = Poisson(meas_var_op, alpha)
62 | y_var = noise_var_op(x) # a noisy measurement vector
63 |
64 | print(f"Shape of image: {x.shape}")
65 | print(f"Shape of simulated measurements y: {y_var.shape}")
66 |
67 | # %%
68 | # The preprocessing operator for split measurements
69 | # -----------------------------------------------------------------------------
70 |
71 | ###############################################################################
72 | # We compute the preprocessing operators for the three cases considered above,
73 | # using the :mod:`spyrit.core.prep` module. As previously introduced,
74 | # a preprocessing operator applies to the noisy measurements in order to
75 | # compensate for the scaling factors that appear in the measurement or noise operators:
76 | #
77 | # .. math::
78 | # m = \texttt{Prep}(y),
79 |
80 | ###############################################################################
81 | # We consider the :class:`spyrit.core.prep.SplitPoisson` class that intends
82 | # to "undo" the :class:`spyrit.core.noise.Poisson` class, for split measurements, by compensating for
83 | #
84 | # * the scaling that appears when computing Poisson-corrupted measurements
85 | #
86 | # * the affine transformation to get images in [0,1] from images in [-1,1]
87 | #
88 | # For this, it computes
89 | #
90 | # .. math::
91 | # m = \frac{2(y_+-y_-)}{\alpha} - P\mathbb{1},
92 | #
93 | # where :math:`y_+=H_+\tilde{x}` and :math:`y_-=H_-\tilde{x}`.
94 | # This is handled internally by the :class:`spyrit.core.prep.SplitPoisson` class.
95 |
96 | ###############################################################################
97 | # We compute the preprocessing operator and the measurements vectors for
98 | # the two sampling strategies.
99 |
100 | from spyrit.core.prep import SplitPoisson
101 |
102 | # "Naive subsampling"
103 | #
104 | # Preprocessing operator
105 | prep_nai_op = SplitPoisson(alpha, meas_nai_op)
106 |
107 | # Preprocessed measurements
108 | m_nai = prep_nai_op(y_nai)
109 |
110 | # "Variance subsampling"
111 | prep_var_op = SplitPoisson(alpha, meas_var_op)
112 | m_var = prep_var_op(y_var)
113 |
114 | # %%
115 | # Noiseless measurements
116 | # -----------------------------------------------------------------------------
117 |
118 | ###############################################################################
119 | # We consider now noiseless measurements for the "naive subsampling" strategy.
120 | # We compute the required operators and the noiseless measurement vector.
121 | # For this we use the :class:`spyrit.core.noise.NoNoise` class, which normalizes
122 | # the input image to get an image in [0,1], as explained in
123 | # :ref:`acquisition operators tutorial `.
124 | # For the preprocessing operator, we assign the number of photons equal to one.
125 |
126 | from spyrit.core.noise import NoNoise
127 |
128 | nonoise_nai_op = NoNoise(meas_nai_op)
129 | y_nai_nonoise = nonoise_nai_op(x) # a noiseless measurement vector
130 |
131 | prep_nonoise_op = SplitPoisson(1.0, meas_nai_op)
132 | m_nai_nonoise = prep_nonoise_op(y_nai_nonoise)
133 |
134 | ###############################################################################
135 | # We can now plot the three measurement vectors
136 |
137 | # Plot the three measurement vectors
138 | m_plot = meas2img(m_nai_nonoise, Ord_naive)
139 | m_plot2 = meas2img(m_nai, Ord_naive)
140 | m_plot3 = spytorch.meas2img(m_var, Ord_variance)
141 |
142 | m_plot_max = m_plot[0, 0, :, :].max()
143 | m_plot_min = m_plot[0, 0, :, :].min()
144 |
145 | f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))
146 | im1 = ax1.imshow(m_plot[0, 0, :, :], cmap="gray")
147 | ax1.set_title("Noiseless measurements $m$ \n 'Naive' subsampling", fontsize=20)
148 | noaxis(ax1)
149 | add_colorbar(im1, "bottom", size="20%")
150 |
151 | im2 = ax2.imshow(m_plot2[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max)
152 | ax2.set_title("Measurements $m$ \n 'Naive' subsampling", fontsize=20)
153 | noaxis(ax2)
154 | add_colorbar(im2, "bottom", size="20%")
155 |
156 | im3 = ax3.imshow(m_plot3[0, 0, :, :], cmap="gray", vmin=m_plot_min, vmax=m_plot_max)
157 | ax3.set_title("Measurements $m$ \n 'Variance' subsampling", fontsize=20)
158 | noaxis(ax3)
159 | add_colorbar(im3, "bottom", size="20%")
160 |
161 | plt.show()
162 |
163 | # %%
164 | # PinvNet network
165 | # -----------------------------------------------------------------------------
166 |
167 | ###############################################################################
168 | # We use the :class:`spyrit.core.recon.PinvNet` class where
169 | # the pseudo inverse reconstruction is performed by a neural network
170 |
171 | from spyrit.core.recon import PinvNet
172 |
173 | pinvnet_nai_nonoise = PinvNet(nonoise_nai_op, prep_nonoise_op)
174 | pinvnet_nai = PinvNet(noise_nai_op, prep_nai_op)
175 | pinvnet_var = PinvNet(noise_var_op, prep_var_op)
176 |
177 | # Reconstruction
178 | z_nai_nonoise = pinvnet_nai_nonoise.reconstruct(y_nai_nonoise)
179 | z_nai = pinvnet_nai.reconstruct(y_nai)
180 | z_var = pinvnet_var.reconstruct(y_var)
181 |
182 | ###############################################################################
183 | # We can now plot the three reconstructed images
184 | from spyrit.misc.disp import add_colorbar, noaxis
185 |
186 | # Plot
187 | f, axs = plt.subplots(2, 2, figsize=(10, 10))
188 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray")
189 | axs[0, 0].set_title("Ground-truth image")
190 | noaxis(axs[0, 0])
191 | add_colorbar(im1, "bottom")
192 |
193 | im2 = axs[0, 1].imshow(z_nai_nonoise[0, 0, :, :], cmap="gray")
194 | axs[0, 1].set_title("Reconstruction noiseless")
195 | noaxis(axs[0, 1])
196 | add_colorbar(im2, "bottom")
197 |
198 | im3 = axs[1, 0].imshow(z_nai[0, 0, :, :], cmap="gray")
199 | axs[1, 0].set_title("Reconstruction \n 'Naive' subsampling")
200 | noaxis(axs[1, 0])
201 | add_colorbar(im3, "bottom")
202 |
203 | im4 = axs[1, 1].imshow(z_var[0, 0, :, :], cmap="gray")
204 | axs[1, 1].set_title("Reconstruction \n 'Variance' subsampling")
205 | noaxis(axs[1, 1])
206 | add_colorbar(im4, "bottom")
207 |
208 | plt.show()
209 |
210 | ###############################################################################
211 | # .. note::
212 | #
213 | # Note that reconstructed images are pixelized when using the "naive subsampling",
214 | # while they are smoother and more similar to the ground-truth image when using the
215 | # "variance subsampling".
216 | #
217 | # Another way to further improve results is to include a nonlinear post-processing step,
218 | # which we will consider in a future tutorial.
219 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_06_dcnet_split_measurements.py:
--------------------------------------------------------------------------------
1 | r"""
2 | =========================================
3 | 06. Denoised Completion Network (DCNet)
4 | =========================================
5 | .. _tuto_dcnet_split_measurements:
6 | This tutorial shows how to perform image reconstruction using the denoised
7 | completion network (DCNet) with a trainable image denoiser. In the next
8 | tutorial, we will plug a denoiser into a DCNet, which requires no training.
9 |
10 | .. figure:: ../fig/tuto3.png
11 | :width: 600
12 | :align: center
13 | :alt: Reconstruction and neural network denoising architecture sketch using split measurements
14 | """
15 |
16 | ######################################################################
17 | # .. note::
18 | # As in the previous tutorials, we consider a split Hadamard operator and
19 | # measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `).
20 |
21 | # %%
22 | # Load a batch of images
23 | # =========================================
24 |
25 | ######################################################################
26 | # Update search path
27 |
28 | # sphinx_gallery_thumbnail_path = 'fig/tuto6.png'
29 | if False:
30 |
31 | import os
32 |
33 | import torch
34 | import torchvision
35 | import matplotlib.pyplot as plt
36 |
37 | import spyrit.core.torch as spytorch
38 | from spyrit.misc.disp import imagesc
39 | from spyrit.misc.statistics import transform_gray_norm
40 |
41 | spyritPath = os.getcwd()
42 | imgs_path = os.path.join(spyritPath, "images/")
43 |
44 | ######################################################################
45 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function.
46 |
47 | h = 64 # image is resized to h x h
48 | transform = transform_gray_norm(img_size=h)
49 |
50 | ######################################################################
51 | # Create a data loader from some dataset (images must be in the folder `images/test/`)
52 |
53 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
54 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
55 |
56 | x, _ = next(iter(dataloader))
57 | print(f"Shape of input images: {x.shape}")
58 |
59 | ######################################################################
60 | # Select the `i`-th image in the batch
61 | i = 1 # Image index (modify to change the image)
62 | x = x[i : i + 1, :, :, :]
63 | x = x.detach().clone()
64 | print(f"Shape of selected image: {x.shape}")
65 | b, c, h, w = x.shape
66 |
67 | ######################################################################
68 | # Plot the selected image
69 |
70 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
71 |
72 | # %%
73 | # Forward operators for split measurements
74 | # =========================================
75 |
76 | ######################################################################
77 | # We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `).
78 |
79 | ######################################################################
80 | # First, we download the covariance matrix from our warehouse.
81 |
82 | import girder_client
83 | from spyrit.misc.load_data import download_girder
84 |
85 | # Get covariance matrix
86 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
87 | dataId = "672207cbf03a54733161e95d"
88 | data_folder = "./stat/"
89 | cov_name = "Cov_64x64.pt"
90 | # download
91 | file_abs_path = download_girder(url, dataId, data_folder, cov_name)
92 |
93 | try:
94 | Cov = torch.load(file_abs_path, weights_only=True)
95 | print(f"Cov matrix {cov_name} loaded")
96 | except:
97 | Cov = torch.eye(h * h)
98 | print(f"Cov matrix {cov_name} not found! Set to the identity")
99 |
100 | ######################################################################
101 | # We define the measurement, noise and preprocessing operators and then simulate
102 | # a measurement vector corrupted by Poisson noise. As in the previous tutorials,
103 | # we simulate an accelerated acquisition by subsampling the measurement matrix
104 | # by retaining only the first rows of a Hadamard matrix that is permuted looking
105 | # at the diagonal of the covariance matrix.
106 |
107 | from spyrit.core.meas import HadamSplit
108 | from spyrit.core.noise import Poisson
109 | from spyrit.core.prep import SplitPoisson
110 |
111 | # Measurement parameters
112 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels)
113 | alpha = 100.0 # number of photons
114 |
115 | # Measurement and noise operators
116 | Ord = spytorch.Cov2Var(Cov)
117 | meas_op = HadamSplit(M, h, Ord)
118 | noise_op = Poisson(meas_op, alpha)
119 | prep_op = SplitPoisson(alpha, meas_op)
120 |
121 | print(f"Shape of image: {x.shape}")
122 |
123 | # Measurements
124 | y = noise_op(x) # a noisy measurement vector
125 | m = prep_op(y) # preprocessed measurement vector
126 |
127 | m_plot = spytorch.meas2img(m, Ord)
128 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$")
129 |
130 | # %%
131 | # Pseudo inverse solution
132 | # =========================================
133 |
134 | ######################################################################
135 | # We compute the pseudo inverse solution using :class:`spyrit.core.recon.PinvNet` class as in the previous tutorial.
136 |
137 | # Instantiate a PinvNet (with no denoising by default)
138 | from spyrit.core.recon import PinvNet
139 |
140 | pinvnet = PinvNet(noise_op, prep_op)
141 |
142 | # Use GPU, if available
143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
144 | print("Using device: ", device)
145 | pinvnet = pinvnet.to(device)
146 | y = y.to(device)
147 |
148 | # Reconstruction
149 | with torch.no_grad():
150 | z_invnet = pinvnet.reconstruct(y)
151 |
152 | # %%
153 | # Denoised completion network (DCNet)
154 | # =========================================
155 |
156 | ######################################################################
157 | # .. image:: ../fig/dcnet.png
158 | # :width: 400
159 | # :align: center
160 | # :alt: Sketch of the DCNet architecture
161 |
162 | ######################################################################
163 | # The DCNet is based on four sequential steps:
164 | #
165 | # i) Denoising in the measurement domain.
166 | #
167 | # ii) Estimation of the missing measurements from the denoised ones.
168 | #
169 | # iii) Image-domain mapping.
170 | #
171 | # iv) (Learned) Denoising in the image domain.
172 | #
173 | # Typically, only the last step involves learnable parameters.
174 |
175 | # %%
176 | # Denoised completion
177 | # =========================================
178 |
179 | ######################################################################
180 | # The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing
181 | #
182 | # .. math::
183 | # \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}},
184 | #
185 | # where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details).
186 |
187 | ######################################################################
188 | # In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior.
189 |
190 | from spyrit.core.recon import DCNet
191 |
192 | dcnet = DCNet(noise_op, prep_op, Cov)
193 |
194 | # Use GPU, if available
195 | dcnet = dcnet.to(device)
196 | y = y.to(device)
197 |
198 | with torch.no_grad():
199 | z_dcnet = dcnet.reconstruct(y)
200 |
201 | ######################################################################
202 | # .. note::
203 | # In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction.
204 |
205 | # %%
206 | # (Learned) Denoising in the image domain
207 | # =========================================
208 |
209 | ######################################################################
210 | # To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`.
211 |
212 | from spyrit.core.nnet import Unet
213 |
214 | denoi = Unet()
215 | dcnet_unet = DCNet(noise_op, prep_op, Cov, denoi)
216 | dcnet_unet = dcnet_unet.to(device) # Use GPU, if available
217 |
218 | ########################################################################
219 | # We load pretrained weights for the UNet
220 |
221 | from spyrit.core.train import load_net
222 |
223 | local_folder = "./model/"
224 | # Create model folder
225 | if os.path.exists(local_folder):
226 | print(f"{local_folder} found")
227 | else:
228 | os.mkdir(local_folder)
229 | print(f"Created {local_folder}")
230 |
231 | # Load pretrained model
232 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
233 | dataID = "67221559f03a54733161e960" # unique ID of the file
234 | data_name = "tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth"
235 | model_unet_path = os.path.join(local_folder, data_name)
236 |
237 | if os.path.exists(model_unet_path):
238 | print(f"Model found : {data_name}")
239 |
240 | else:
241 | print(f"Model not found : {data_name}")
242 | print(f"Downloading model... ", end="")
243 | try:
244 | gc = girder_client.GirderClient(apiUrl=url)
245 | gc.downloadFile(dataID, model_unet_path)
246 | print("Done")
247 | except Exception as e:
248 | print("Failed with error: ", e)
249 |
250 | # Load pretrained model
251 | load_net(model_unet_path, dcnet_unet, device, False)
252 |
253 | ######################################################################
254 | # We reconstruct the image
255 | with torch.no_grad():
256 | z_dcnet_unet = dcnet_unet.reconstruct(y)
257 |
258 | # %%
259 | # Results
260 | # =========================================
261 |
262 | from spyrit.misc.disp import add_colorbar, noaxis
263 |
264 | f, axs = plt.subplots(2, 2, figsize=(10, 10))
265 |
266 | # Plot the ground-truth image
267 | im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray")
268 | axs[0, 0].set_title("Ground-truth image", fontsize=16)
269 | noaxis(axs[0, 0])
270 | add_colorbar(im1, "bottom")
271 |
272 | # Plot the pseudo inverse solution
273 | im2 = axs[0, 1].imshow(z_invnet.cpu()[0, 0, :, :], cmap="gray")
274 | axs[0, 1].set_title("Pseudo inverse", fontsize=16)
275 | noaxis(axs[0, 1])
276 | add_colorbar(im2, "bottom")
277 |
278 | # Plot the solution obtained from denoised completion
279 | im3 = axs[1, 0].imshow(z_dcnet.cpu()[0, 0, :, :], cmap="gray")
280 | axs[1, 0].set_title(f"Denoised completion", fontsize=16)
281 | noaxis(axs[1, 0])
282 | add_colorbar(im3, "bottom")
283 |
284 | # Plot the solution obtained from denoised completion with UNet denoising
285 | im4 = axs[1, 1].imshow(z_dcnet_unet.cpu()[0, 0, :, :], cmap="gray")
286 | axs[1, 1].set_title(f"Denoised completion with UNet denoising", fontsize=16)
287 | noaxis(axs[1, 1])
288 | add_colorbar(im4, "bottom")
289 |
290 | plt.show()
291 |
292 | ######################################################################
293 | # .. note::
294 |  #    While the pseudo inverse reconstruction is pixelized, the solution obtained by denoised completion is smoother. DCNet with UNet denoising in the image domain provides the best reconstruction.
295 |
296 | ######################################################################
297 | # .. note::
298 |  #    We refer to the ``spyrit-examples`` tutorials for a comparison of different solutions (pinvNet, DCNet and DRUNet) that can be run in colab.
299 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_08_lpgd_split_measurements.py:
--------------------------------------------------------------------------------
1 | r"""
2 | ======================================================================
3 | 08. Learned proximal gradient descent (LPGD) for split measurements
4 | ======================================================================
5 | .. _tuto_lpgd_split_measurements:
6 |
7 | This tutorial shows how to perform image reconstruction with unrolled Learned Proximal Gradient
8 | Descent (LPGD) for split measurements.
9 |
10 | Unfortunately, it has a large memory consumption so it cannot be run interactively.
11 | If you want to run it yourself, please remove all the "if False:" statements at
12 | the beginning of each code block. The figures displayed are the ones that would
13 | be generated if the code was run.
14 |
15 | .. figure:: ../fig/lpgd.png
16 | :width: 600
17 | :align: center
18 | :alt: Sketch of the unrolled Learned Proximal Gradient Descent
19 |
20 | """
21 |
22 | ###############################################################################
23 | # LPGD is an unrolled method, which can be explained as a recurrent network where
24 | # each block corresponds to an unrolled iteration of the proximal gradient descent.
25 | # At each iteration, the network performs a gradient step and a denoising step.
26 | #
27 | # The update rule for the LPGD network is given by:
28 | #
29 | # .. math::
30 | #    x^{(k+1)} = \mathcal{G}_{\theta}\left(x^{(k)} - \gamma H^T(Hx^{(k)} - m)\right).
31 | #
32 | # where :math:`x^{(k)}` is the image estimate at iteration :math:`k`,
33 | # :math:`H` is the forward operator, :math:`\gamma` is the step size,
34 | # and :math:`\mathcal{G}_{\theta}` is a denoising network with
35 | # learnable parameters :math:`\theta`.
36 |
37 | # %%
38 | # Load a batch of images
39 | # -----------------------------------------------------------------------------
40 | #
41 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized
42 | # using the :func:`transform_gray_norm` function.
43 |
44 | # sphinx_gallery_thumbnail_path = 'fig/lpgd.png'
45 |
46 | if False:
47 | import os
48 |
49 | import torch
50 | import torchvision
51 | import matplotlib.pyplot as plt
52 |
53 | import spyrit.core.torch as spytorch
54 | from spyrit.misc.disp import imagesc
55 | from spyrit.misc.statistics import transform_gray_norm
56 |
57 | spyritPath = os.getcwd()
58 | imgs_path = os.path.join(spyritPath, "images/")
59 |
60 | ######################################################################
61 | # Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function.
62 |
63 | h = 128 # image is resized to h x h
64 | transform = transform_gray_norm(img_size=h)
65 |
66 | ######################################################################
67 | # Create a data loader from some dataset (images must be in the folder `images/test/`)
68 |
69 | dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
70 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
71 |
72 | x, _ = next(iter(dataloader))
73 | print(f"Shape of input images: {x.shape}")
74 |
75 | ######################################################################
76 | # Select the `i`-th image in the batch
77 | i = 1 # Image index (modify to change the image)
78 | x = x[i : i + 1, :, :, :]
79 | x = x.detach().clone()
80 | print(f"Shape of selected image: {x.shape}")
81 | b, c, h, w = x.shape
82 |
83 | ######################################################################
84 | # Plot the selected image
85 |
86 | imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
87 |
88 | ###############################################################################
89 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972abaa5a90007058950/download
90 | # :width: 600
91 | # :align: center
92 | # :alt: Ground-truth image x in [-1, 1]
93 | # %%
94 | # Forward operators for split measurements
95 | # -----------------------------------------------------------------------------
96 | #
97 | # We consider noisy split measurements for a Hadamard operator and a simple
98 | # "rectangular subsampling" strategy
99 | # (for more details, refer to :ref:`Acquisition - split measurements `).
100 | #
101 | #
102 | # We define the measurement, noise and preprocessing operators and then
103 | # simulate a measurement vector :math:`y` corrupted by Poisson noise. As in the previous tutorial,
104 | # we simulate an accelerated acquisition by subsampling the measurement matrix
105 | # by retaining only the first rows of a Hadamard matrix.
106 |
107 | if False:
108 | import math
109 |
110 | from spyrit.core.meas import HadamSplit
111 | from spyrit.core.noise import Poisson
112 | from spyrit.core.prep import SplitPoisson
113 |
114 | # Measurement parameters
115 | M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels)
116 | alpha = 10.0 # number of photons
117 |
118 | # Sampling: rectangular matrix
119 | Ord_rec = torch.zeros(h, h)
120 | n_sub = math.ceil(M**0.5)
121 | Ord_rec[:n_sub, :n_sub] = 1
122 |
123 | # Measurement and noise operators
124 | meas_op = HadamSplit(M, h, Ord_rec)
125 | noise_op = Poisson(meas_op, alpha)
126 | prep_op = SplitPoisson(alpha, meas_op)
127 |
128 | print(f"Shape of image: {x.shape}")
129 |
130 | # Measurements
131 | y = noise_op(x) # a noisy measurement vector
132 | m = prep_op(y) # preprocessed measurement vector
133 |
134 | m_plot = spytorch.meas2img(m, Ord_rec)
135 | imagesc(m_plot[0, 0, :, :], r"Measurements $m$")
136 |
137 | ###############################################################################
138 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972bbaa5a90007058953/download
139 | # :width: 600
140 | # :align: center
141 | # :alt: Measurements m
142 | #
143 | # We define the LearnedPGD network by providing the measurement, noise and preprocessing operators,
144 | # the denoiser and other optional parameters to the class :class:`spyrit.core.recon.LearnedPGD`.
145 | # The optional parameters include the number of unrolled iterations (`iter_stop`)
146 | # and the step size decay factor (`step_decay`).
147 | # We choose Unet as the denoiser, as in previous tutorials.
148 | # For the optional parameters, we use three iterations and a step size decay
149 | # factor of 0.9, which worked well on this data (this should match the parameters
150 | # used during training).
151 | #
152 | # .. image:: ../fig/lpgd.png
153 | # :width: 600
154 | # :align: center
155 | # :alt: Sketch of the network architecture for LearnedPGD
156 |
157 | if False:
158 | from spyrit.core.nnet import Unet
159 | from spyrit.core.recon import LearnedPGD
160 |
161 | # use GPU, if available
162 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
163 | print("Using device:", device)
164 | # Define UNet denoiser
165 | denoi = Unet()
166 | # Define the LearnedPGD model
167 | lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9)
168 |
169 | ###############################################################################
170 | # Now, we download the pretrained weights and load them into the LPGD network.
171 | # Unfortunately, the pretrained weights are too heavy (2GB) to be downloaded
172 | # here. The last figure is nonetheless displayed to show the results.
173 |
174 | if False:
175 | from spyrit.core.train import load_net
176 | from spyrit.misc.load_data import download_girder
177 |
178 | # Download parameters
179 | url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
180 | dataID = "67221f60f03a54733161e96c" # unique ID of the file
181 | local_folder = "./model/"
182 | data_name = "tuto8_model_lpgd_light.pth"
183 | # Download from Girder
184 | model_abs_path = download_girder(url, dataID, local_folder, data_name)
185 |
186 | # Load pretrained weights to the model
187 | load_net(model_abs_path, lpgd_net, device, strict=False)
188 |
189 | lpgd_net.eval()
190 | lpgd_net.to(device)
191 |
192 | ###############################################################################
193 | # We reconstruct by calling the reconstruct method as in previous tutorials
194 | # and display the results.
195 |
196 | if False:
197 | from spyrit.misc.disp import add_colorbar, noaxis
198 |
199 | with torch.no_grad():
200 | z_lpgd = lpgd_net.reconstruct(y.to(device))
201 |
202 | # Plot results
203 | f, axs = plt.subplots(2, 1, figsize=(10, 10))
204 |
205 | im1 = axs[0].imshow(x.cpu()[0, 0, :, :], cmap="gray")
206 | axs[0].set_title("Ground-truth image", fontsize=16)
207 | noaxis(axs[0])
208 | add_colorbar(im1, "bottom")
209 |
210 | im2 = axs[1].imshow(z_lpgd.cpu()[0, 0, :, :], cmap="gray")
211 | axs[1].set_title("LPGD", fontsize=16)
212 | noaxis(axs[1])
213 | add_colorbar(im2, "bottom")
214 |
215 | plt.show()
216 |
217 | ###############################################################################
218 | # .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679853fbaa5a9000705894b/download
219 | # :width: 400
220 | # :align: center
221 | # :alt: Comparison of ground-truth image and LPGD reconstruction
222 |
--------------------------------------------------------------------------------
/tutorial/wip/_tuto_bonus_advanced_methods_colab.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | r"""
3 | Bonus. Advanced methods - Colab
4 | ===============================
5 | .. _tuto_advanced_methods_colab:
6 |
7 | We refer to `spyrit-examples/tutorial <https://github.com/openspyrit/spyrit-examples/tree/master/tutorial>`_
8 | for a list of tutorials that can be run directly in colab and present more advanced cases than the main spyrit tutorials,
9 | such as comparison of methods for split measurements, or comparison of different denoising networks.
10 |
11 | """
12 |
13 | ###############################################################################
14 | # The spyrit-examples repository also includes research contributions based on the SPYRIT toolbox.
15 |
--------------------------------------------------------------------------------