├── .github
└── workflows
│ └── test_and_deploy.yml
├── .gitignore
├── .pylintrc
├── LICENSE
├── README.md
├── build.rst
├── docs
├── .gitignore
├── Makefile
├── make.bat
└── source
│ ├── _static
│ └── .gitkeep
│ ├── conf.py
│ ├── index.rst
│ ├── install.rst
│ ├── interfaces.rst
│ ├── intro.rst
│ ├── modules.rst
│ ├── quickstart.rst
│ └── references.rst
├── examples
├── README.rst
├── api.py
├── gaussian_psf_2d.py
├── gaussian_psf_3d.py
├── gibsonlanni.py
├── richardson_lucy.py
├── spitfire.py
└── wiener.py
├── pyproject.toml
├── requirements.txt
├── sdeconv
├── __init__.py
├── api
│ ├── __init__.py
│ ├── api.py
│ └── factory.py
├── cli
│ ├── __init__.py
│ ├── sdeconv.py
│ └── spsf.py
├── core
│ ├── __init__.py
│ ├── _observers.py
│ ├── _progress_logger.py
│ ├── _settings.py
│ └── _timing.py
├── data
│ ├── __init__.py
│ ├── celegans.tif
│ ├── pollen.tif
│ ├── pollen_poisson_noise_blurred.tif
│ └── pollen_psf.tif
├── deconv
│ ├── __init__.py
│ ├── _datasets.py
│ ├── _transforms.py
│ ├── _unet_2d.py
│ ├── _utils.py
│ ├── interface.py
│ ├── interface_nn.py
│ ├── nn_deconv.py
│ ├── noise2void.py
│ ├── richardson_lucy.py
│ ├── self_supervised_nn.py
│ ├── spitfire.py
│ └── wiener.py
└── psfs
│ ├── __init__.py
│ ├── gaussian.py
│ ├── gibson_lanni.py
│ ├── interface.py
│ └── lorentz.py
├── setup.cfg
├── setup.py
├── tests
├── __init__.py
├── deconv
│ ├── __init__.py
│ ├── celegans_richardson_lucy.tif
│ ├── celegans_spitfire.tif
│ ├── celegans_wiener.tif
│ ├── pollen_richardson_lucy.tif
│ ├── pollen_spitfire.tif
│ ├── pollen_wiener.tif
│ ├── test_richardson_lucy.py
│ ├── test_spitfire.py
│ └── test_wiener.py
└── psfs
│ ├── __init__.py
│ ├── gaussian2d.tif
│ ├── gaussian3d.tif
│ ├── gibsonlanni.tif
│ ├── lorentz2d.tif
│ ├── lorentz3d.tif
│ ├── test_gausian.py
│ ├── test_gibsonlanni.py
│ └── test_lorentz.py
└── tox.ini
/.github/workflows/test_and_deploy.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: tests
5 |
6 | on:
7 | push:
8 | branches:
9 | - master
10 | - main
11 | tags:
12 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
13 | pull_request:
14 | branches:
15 | - master
16 | - main
17 | workflow_dispatch:
18 |
19 | jobs:
20 | test:
21 | name: ${{ matrix.platform }} py${{ matrix.python-version }}
22 | runs-on: ${{ matrix.platform }}
23 | strategy:
24 | matrix:
25 | platform: [ubuntu-latest, windows-latest, macos-latest]
26 | python-version: ['3.10', '3.11']
27 |
28 | steps:
29 |       - uses: actions/checkout@v4
30 |
31 | - name: Set up Python ${{ matrix.python-version }}
32 |         uses: actions/setup-python@v5
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 |
36 | # these libraries, along with pytest-xvfb (added in the `deps` in tox.ini),
37 | # enable testing on Qt on linux
38 | - name: Install Linux libraries
39 | if: runner.os == 'Linux'
40 | run: |
41 | sudo apt-get install -y libdbus-1-3 libxkbcommon-x11-0 libxcb-icccm4 \
42 | libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 \
43 | libxcb-xinerama0 libxcb-xinput0 libxcb-xfixes0
44 | # note: if you need dependencies from conda, considering using
45 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda
46 | # and
47 | # tox-conda: https://github.com/tox-dev/tox-conda
48 | - name: Install dependencies
49 | run: |
50 | python -m pip install --upgrade pip
51 | pip install setuptools tox numpy scipy tox-gh-actions
52 | # this runs the platform-specific tests declared in tox.ini
53 | - name: Test with tox
54 | run: tox
55 | env:
56 | PLATFORM: ${{ matrix.platform }}
57 |
58 | - name: Coverage
59 |         uses: codecov/codecov-action@v3
60 |
61 | deploy:
62 | # this will run when you have tagged a commit, starting with "v*"
63 | # and requires that you have put your twine API key in your
64 | # github secrets (see readme for details)
65 | needs: [test]
66 | runs-on: ubuntu-latest
67 |     if: startsWith(github.ref, 'refs/tags/')
68 | steps:
69 |       - uses: actions/checkout@v4
70 | - name: Set up Python
71 |         uses: actions/setup-python@v5
72 | with:
73 | python-version: "3.x"
74 | - name: Install dependencies
75 | run: |
76 | python -m pip install --upgrade pip
77 | pip install numpy scipy torch torchvision scikit-image
78 | pip install -U setuptools setuptools_scm wheel twine
79 | - name: Build and publish
80 | env:
81 | TWINE_USERNAME: __token__
82 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
83 | run: |
84 | git tag
85 | python setup.py sdist bdist_wheel
86 | twine upload dist/*
87 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .vscode/
3 | *.pyc
4 | *.pyo
5 | *~
6 | \.#*
7 | /build
8 | /dist
9 | *.egg-info
10 | *.so
11 | __pycache__
12 | .DS_Store
13 | _build
14 | docs/source/generated/
15 | runs/
16 | demo_*.*
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loading into the active Python interpreter and may
5 | # run arbitrary code.
6 | extension-pkg-allow-list=
7 |
8 | # A comma-separated list of package or module names from where C extensions may
9 | # be loaded. Extensions are loading into the active Python interpreter and may
10 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list
11 | # for backward compatibility.)
12 | extension-pkg-whitelist=
13 |
14 | # Return non-zero exit code if any of these messages/categories are detected,
15 | # even if score is above --fail-under value. Syntax same as enable. Messages
16 | # specified are enabled, while categories only check already-enabled messages.
17 | fail-on=
18 |
19 | # Specify a score threshold to be exceeded before program exits with error.
20 | fail-under=10.0
21 |
22 | # Files or directories to be skipped. They should be base names, not paths.
23 | ignore=CVS
24 |
25 | # Add files or directories matching the regex patterns to the ignore-list. The
26 | # regex matches against paths.
27 | ignore-paths=
28 |
29 | # Files or directories matching the regex patterns are skipped. The regex
30 | # matches against base names, not paths.
31 | ignore-patterns=
32 |
33 | # Python code to execute, usually for sys.path manipulation such as
34 | # pygtk.require().
35 | #init-hook=
36 |
37 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
38 | # number of processors available to use.
39 | jobs=1
40 |
41 | # Control the amount of potential inferred values when inferring a single
42 | # object. This can help the performance when dealing with large functions or
43 | # complex, nested conditions.
44 | limit-inference-results=100
45 |
46 | # List of plugins (as comma separated values of python module names) to load,
47 | # usually to register additional checkers.
48 | load-plugins=
49 |
50 | # Pickle collected data for later comparisons.
51 | persistent=yes
52 |
53 | # When enabled, pylint would attempt to guess common misconfiguration and emit
54 | # user-friendly hints instead of false-positive error messages.
55 | suggestion-mode=yes
56 |
57 | # Allow loading of arbitrary C extensions. Extensions are imported into the
58 | # active Python interpreter and may run arbitrary code.
59 | unsafe-load-any-extension=no
60 |
61 |
62 | [MESSAGES CONTROL]
63 |
64 | # Only show warnings with the listed confidence levels. Leave empty to show
65 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
66 | confidence=
67 |
68 | # Disable the message, report, category or checker with the given id(s). You
69 | # can either give multiple identifiers separated by comma (,) or put this
70 | # option multiple times (only on the command line, not in the configuration
71 | # file where it should appear only once). You can also use "--disable=all" to
72 | # disable everything first and then reenable specific checks. For example, if
73 | # you want to run only the similarities checker, you can use "--disable=all
74 | # --enable=similarities". If you want to run only the classes checker, but have
75 | # no Warning level messages displayed, use "--disable=all --enable=classes
76 | # --disable=W".
77 | disable=too-many-arguments, duplicate-code, too-few-public-methods, too-many-instance-attributes, consider-using-with
78 |
79 | # Enable the message, report, category or checker with the given id(s). You can
80 | # either give multiple identifier separated by comma (,) or put this option
81 | # multiple time (only on the command line, not in the configuration file where
82 | # it should appear only once). See also the "--disable" option for examples.
83 | enable=c-extension-no-member
84 |
85 |
86 | [REPORTS]
87 |
88 | # Python expression which should return a score less than or equal to 10. You
89 | # have access to the variables 'error', 'warning', 'refactor', and 'convention'
90 | # which contain the number of messages in each category, as well as 'statement'
91 | # which is the total number of statements analyzed. This score is used by the
92 | # global evaluation report (RP0004).
93 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
94 |
95 | # Template used to display messages. This is a python new-style format string
96 | # used to format the message information. See doc for all details.
97 | #msg-template=
98 |
99 | # Set the output format. Available formats are text, parseable, colorized, json
100 | # and msvs (visual studio). You can also give a reporter class, e.g.
101 | # mypackage.mymodule.MyReporterClass.
102 | output-format=text
103 |
104 | # Tells whether to display a full report or only the messages.
105 | reports=no
106 |
107 | # Activate the evaluation score.
108 | score=yes
109 |
110 |
111 | [REFACTORING]
112 |
113 | # Maximum number of nested blocks for function / method body
114 | max-nested-blocks=5
115 |
116 | # Complete name of functions that never returns. When checking for
117 | # inconsistent-return-statements if a never returning function is called then
118 | # it will be considered as an explicit return statement and no message will be
119 | # printed.
120 | never-returning-functions=sys.exit,argparse.parse_error
121 |
122 |
123 | [LOGGING]
124 |
125 | # The type of string formatting that logging methods do. `old` means using %
126 | # formatting, `new` is for `{}` formatting.
127 | logging-format-style=old
128 |
129 | # Logging modules to check that the string format arguments are in logging
130 | # function parameter format.
131 | logging-modules=logging
132 |
133 |
134 | [SPELLING]
135 |
136 | # Limits count of emitted suggestions for spelling mistakes.
137 | max-spelling-suggestions=4
138 |
139 | # Spelling dictionary name. Available dictionaries: none. To make it work,
140 | # install the 'python-enchant' package.
141 | spelling-dict=
142 |
143 | # List of comma separated words that should be considered directives if they
144 | # appear and the beginning of a comment and should not be checked.
145 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
146 |
147 | # List of comma separated words that should not be checked.
148 | spelling-ignore-words=
149 |
150 | # A path to a file that contains the private dictionary; one word per line.
151 | spelling-private-dict-file=
152 |
153 | # Tells whether to store unknown words to the private dictionary (see the
154 | # --spelling-private-dict-file option) instead of raising a message.
155 | spelling-store-unknown-words=no
156 |
157 |
158 | [MISCELLANEOUS]
159 |
160 | # List of note tags to take in consideration, separated by a comma.
161 | notes=FIXME,
162 | XXX,
163 | TODO
164 |
165 | # Regular expression of note tags to take in consideration.
166 | #notes-rgx=
167 |
168 |
169 | [TYPECHECK]
170 |
171 | # List of decorators that produce context managers, such as
172 | # contextlib.contextmanager. Add to this list to register other decorators that
173 | # produce valid context managers.
174 | contextmanager-decorators=contextlib.contextmanager
175 |
176 | # List of members which are set dynamically and missed by pylint inference
177 | # system, and so shouldn't trigger E1101 when accessed. Python regular
178 | # expressions are accepted.
179 | generated-members=
180 |
181 | # Tells whether missing members accessed in mixin class should be ignored. A
182 | # mixin class is detected if its name ends with "mixin" (case insensitive).
183 | ignore-mixin-members=yes
184 |
185 | # Tells whether to warn about missing members when the owner of the attribute
186 | # is inferred to be None.
187 | ignore-none=yes
188 |
189 | # This flag controls whether pylint should warn about no-member and similar
190 | # checks whenever an opaque object is returned when inferring. The inference
191 | # can return multiple potential results while evaluating a Python object, but
192 | # some branches might not be evaluated, which results in partial inference. In
193 | # that case, it might be useful to still emit no-member and other checks for
194 | # the rest of the inferred objects.
195 | ignore-on-opaque-inference=yes
196 |
197 | # List of class names for which member attributes should not be checked (useful
198 | # for classes with dynamically set attributes). This supports the use of
199 | # qualified names.
200 | ignored-classes=optparse.Values,thread._local,_thread._local,numpy,torch
201 |
202 | # List of module names for which member attributes should not be checked
203 | # (useful for modules/projects where namespaces are manipulated during runtime
204 | # and thus existing member attributes cannot be deduced by static analysis). It
205 | # supports qualified module names, as well as Unix pattern matching.
206 | ignored-modules=numpy,torch
207 |
208 | # Show a hint with possible names when a member name was not found. The aspect
209 | # of finding the hint is based on edit distance.
210 | missing-member-hint=yes
211 |
212 | # The minimum edit distance a name should have in order to be considered a
213 | # similar match for a missing member name.
214 | missing-member-hint-distance=1
215 |
216 | # The total number of similar names that should be taken in consideration when
217 | # showing a hint for a missing member.
218 | missing-member-max-choices=1
219 |
220 | # List of decorators that change the signature of a decorated function.
221 | signature-mutators=
222 |
223 |
224 | [VARIABLES]
225 |
226 | # List of additional names supposed to be defined in builtins. Remember that
227 | # you should avoid defining new builtins when possible.
228 | additional-builtins=
229 |
230 | # Tells whether unused global variables should be treated as a violation.
231 | allow-global-unused-variables=yes
232 |
233 | # List of names allowed to shadow builtins
234 | allowed-redefined-builtins=
235 |
236 | # List of strings which can identify a callback function by name. A callback
237 | # name must start or end with one of those strings.
238 | callbacks=cb_,
239 | _cb
240 |
241 | # A regular expression matching the name of dummy variables (i.e. expected to
242 | # not be used).
243 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
244 |
245 | # Argument names that match this expression will be ignored. Default to name
246 | # with leading underscore.
247 | ignored-argument-names=_.*|^ignored_|^unused_
248 |
249 | # Tells whether we should check for unused import in __init__ files.
250 | init-import=no
251 |
252 | # List of qualified module names which can have objects that can redefine
253 | # builtins.
254 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
255 |
256 |
257 | [FORMAT]
258 |
259 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
260 | expected-line-ending-format=
261 |
262 | # Regexp for a line that is allowed to be longer than the limit.
263 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
264 |
265 | # Number of spaces of indent required inside a hanging or continued line.
266 | indent-after-paren=4
267 |
268 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
269 | # tab).
270 | indent-string=' '
271 |
272 | # Maximum number of characters on a single line.
273 | max-line-length=100
274 |
275 | # Maximum number of lines in a module.
276 | max-module-lines=1000
277 |
278 | # Allow the body of a class to be on the same line as the declaration if body
279 | # contains single statement.
280 | single-line-class-stmt=no
281 |
282 | # Allow the body of an if to be on the same line as the test if there is no
283 | # else.
284 | single-line-if-stmt=no
285 |
286 |
287 | [SIMILARITIES]
288 |
289 | # Comments are removed from the similarity computation
290 | ignore-comments=yes
291 |
292 | # Docstrings are removed from the similarity computation
293 | ignore-docstrings=yes
294 |
295 | # Imports are removed from the similarity computation
296 | ignore-imports=no
297 |
298 | # Signatures are removed from the similarity computation
299 | ignore-signatures=no
300 |
301 | # Minimum lines number of a similarity.
302 | min-similarity-lines=4
303 |
304 |
305 | [BASIC]
306 |
307 | # Naming style matching correct argument names.
308 | argument-naming-style=snake_case
309 |
310 | # Regular expression matching correct argument names. Overrides argument-
311 | # naming-style.
312 | #argument-rgx=
313 |
314 | # Naming style matching correct attribute names.
315 | attr-naming-style=snake_case
316 |
317 | # Regular expression matching correct attribute names. Overrides attr-naming-
318 | # style.
319 | #attr-rgx=
320 |
321 | # Bad variable names which should always be refused, separated by a comma.
322 | bad-names=foo,
323 | bar,
324 | baz,
325 | toto,
326 | tutu,
327 | tata
328 |
329 | # Bad variable names regexes, separated by a comma. If names match any regex,
330 | # they will always be refused
331 | bad-names-rgxs=
332 |
333 | # Naming style matching correct class attribute names.
334 | class-attribute-naming-style=any
335 |
336 | # Regular expression matching correct class attribute names. Overrides class-
337 | # attribute-naming-style.
338 | #class-attribute-rgx=
339 |
340 | # Naming style matching correct class constant names.
341 | class-const-naming-style=UPPER_CASE
342 |
343 | # Regular expression matching correct class constant names. Overrides class-
344 | # const-naming-style.
345 | #class-const-rgx=
346 |
347 | # Naming style matching correct class names.
348 | class-naming-style=PascalCase
349 |
350 | # Regular expression matching correct class names. Overrides class-naming-
351 | # style.
352 | #class-rgx=
353 |
354 | # Naming style matching correct constant names.
355 | const-naming-style=UPPER_CASE
356 |
357 | # Regular expression matching correct constant names. Overrides const-naming-
358 | # style.
359 | #const-rgx=
360 |
361 | # Minimum line length for functions/classes that require docstrings, shorter
362 | # ones are exempt.
363 | docstring-min-length=-1
364 |
365 | # Naming style matching correct function names.
366 | function-naming-style=snake_case
367 |
368 | # Regular expression matching correct function names. Overrides function-
369 | # naming-style.
370 | #function-rgx=
371 |
372 | # Good variable names which should always be accepted, separated by a comma.
373 | good-names=i,
374 | j,
375 | k,
376 | m,
377 | n,
378 | x,
379 | y,
380 | z,
381 | ex,
382 | Run,
383 | _
384 |
385 | # Good variable names regexes, separated by a comma. If names match any regex,
386 | # they will always be accepted
387 | good-names-rgxs=
388 |
389 | # Include a hint for the correct naming format with invalid-name.
390 | include-naming-hint=no
391 |
392 | # Naming style matching correct inline iteration names.
393 | inlinevar-naming-style=any
394 |
395 | # Regular expression matching correct inline iteration names. Overrides
396 | # inlinevar-naming-style.
397 | #inlinevar-rgx=
398 |
399 | # Naming style matching correct method names.
400 | method-naming-style=snake_case
401 |
402 | # Regular expression matching correct method names. Overrides method-naming-
403 | # style.
404 | #method-rgx=
405 |
406 | # Naming style matching correct module names.
407 | module-naming-style=snake_case
408 |
409 | # Regular expression matching correct module names. Overrides module-naming-
410 | # style.
411 | #module-rgx=
412 |
413 | # Colon-delimited sets of names that determine each other's naming style when
414 | # the name regexes allow several styles.
415 | name-group=
416 |
417 | # Regular expression which should only match function or class names that do
418 | # not require a docstring.
419 | no-docstring-rgx=^_
420 |
421 | # List of decorators that produce properties, such as abc.abstractproperty. Add
422 | # to this list to register other decorators that produce valid properties.
423 | # These decorators are taken in consideration only for invalid-name.
424 | property-classes=abc.abstractproperty
425 |
426 | # Naming style matching correct variable names.
427 | variable-naming-style=snake_case
428 |
429 | # Regular expression matching correct variable names. Overrides variable-
430 | # naming-style.
431 | #variable-rgx=
432 |
433 |
434 | [STRING]
435 |
436 | # This flag controls whether inconsistent-quotes generates a warning when the
437 | # character used as a quote delimiter is used inconsistently within a module.
438 | check-quote-consistency=no
439 |
440 | # This flag controls whether the implicit-str-concat should generate a warning
441 | # on implicit string concatenation in sequences defined over several lines.
442 | check-str-concat-over-line-jumps=no
443 |
444 |
445 | [IMPORTS]
446 |
447 | # List of modules that can be imported at any level, not just the top level
448 | # one.
449 | allow-any-import-level=
450 |
451 | # Allow wildcard imports from modules that define __all__.
452 | allow-wildcard-with-all=no
453 |
454 | # Analyse import fallback blocks. This can be used to support both Python 2 and
455 | # 3 compatible code, which means that the block might have code that exists
456 | # only in one or another interpreter, leading to false positives when analysed.
457 | analyse-fallback-blocks=no
458 |
459 | # Deprecated modules which should not be used, separated by a comma.
460 | deprecated-modules=
461 |
462 | # Output a graph (.gv or any supported image format) of external dependencies
463 | # to the given file (report RP0402 must not be disabled).
464 | ext-import-graph=
465 |
466 | # Output a graph (.gv or any supported image format) of all (i.e. internal and
467 | # external) dependencies to the given file (report RP0402 must not be
468 | # disabled).
469 | import-graph=
470 |
471 | # Output a graph (.gv or any supported image format) of internal dependencies
472 | # to the given file (report RP0402 must not be disabled).
473 | int-import-graph=
474 |
475 | # Force import order to recognize a module as part of the standard
476 | # compatibility libraries.
477 | known-standard-library=
478 |
479 | # Force import order to recognize a module as part of a third party library.
480 | known-third-party=enchant
481 |
482 | # Couples of modules and preferred modules, separated by a comma.
483 | preferred-modules=
484 |
485 |
486 | [CLASSES]
487 |
488 | # Warn about protected attribute access inside special methods
489 | check-protected-access-in-special-methods=no
490 |
491 | # List of method names used to declare (i.e. assign) instance attributes.
492 | defining-attr-methods=__init__,
493 | __new__,
494 | setUp,
495 | __post_init__
496 |
497 | # List of member names, which should be excluded from the protected access
498 | # warning.
499 | exclude-protected=_asdict,
500 | _fields,
501 | _replace,
502 | _source,
503 | _make
504 |
505 | # List of valid names for the first argument in a class method.
506 | valid-classmethod-first-arg=cls
507 |
508 | # List of valid names for the first argument in a metaclass class method.
509 | valid-metaclass-classmethod-first-arg=cls
510 |
511 |
512 | [DESIGN]
513 |
514 | # List of qualified class names to ignore when counting class parents (see
515 | # R0901)
516 | ignored-parents=
517 |
518 | # Maximum number of arguments for function / method.
519 | max-args=5
520 |
521 | # Maximum number of attributes for a class (see R0902).
522 | max-attributes=7
523 |
524 | # Maximum number of boolean expressions in an if statement (see R0916).
525 | max-bool-expr=5
526 |
527 | # Maximum number of branch for function / method body.
528 | max-branches=12
529 |
530 | # Maximum number of locals for function / method body.
531 | max-locals=15
532 |
533 | # Maximum number of parents for a class (see R0901).
534 | max-parents=7
535 |
536 | # Maximum number of public methods for a class (see R0904).
537 | max-public-methods=20
538 |
539 | # Maximum number of return / yield for function / method body.
540 | max-returns=6
541 |
542 | # Maximum number of statements in function / method body.
543 | max-statements=50
544 |
545 | # Minimum number of public methods for a class (see R0903).
546 | min-public-methods=2
547 |
548 |
549 | [EXCEPTIONS]
550 |
551 | # Exceptions that will emit a warning when being caught. Defaults to
552 | # "BaseException, Exception".
553 | overgeneral-exceptions=builtins.BaseException,
554 | builtins.Exception
555 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, STracking
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SDeconv
2 |
3 | **SDeconv** is a python framework to develop scientific image deconvolution algorithms. This
4 | library has been developed for microscopy 2D and 3D images, but can be used for any image
5 | deconvolution application.
6 |
7 | # System Requirements
8 |
9 | ## Software Requirements
10 |
11 | ### OS Requirements
12 |
13 | The `SDeconv` development version is tested on *Windows 10*, *MacOS* and *Linux* operating systems.
14 | The developmental version of the package has been tested on the following systems:
15 |
16 | - Linux: 20.04.4
17 | - Mac OSX: Mac OS Catalina 10.15.7
18 | - Windows: 10
19 |
20 | # Install
21 |
22 | ## Library installation from PyPI
23 |
24 | 1. Install an [Anaconda](https://www.anaconda.com/download/) distribution of Python -- Choose **Python 3.9** and your operating system. Note you might need to use an anaconda prompt if you did not add anaconda to the path.
25 | 2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path
26 | 3. Create a new environment with `conda create --name sdeconv python=3.9`.
27 | 4. To activate this new environment, run `conda activate sdeconv`
28 | 5. To install the `SDeconv` library, run `python -m pip install sdeconv`.
29 |
30 | if you need to update to a new release, use:
31 | ~~~sh
32 | python -m pip install sdeconv --upgrade
33 | ~~~
34 |
35 | ## Library installation from source
36 |
37 | This installation is for developers or people who want the latest features in the ``main`` branch.
38 |
39 | 1. Install an [Anaconda](https://www.anaconda.com/download/) distribution of Python -- Choose **Python 3.9** and your operating system. Note you might need to use an anaconda prompt if you did not add anaconda to the path.
40 | 2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path
41 | 3. Create a new environment with `conda create --name sdeconv python=3.9`.
42 | 4. To activate this new environment, run `conda activate sdeconv`
43 | 5. Clone the source code from git with `git clone https://github.com/sylvainprigent/sdeconv.git`
44 | 6. Then install the `SDeconv` library from your local directory with: `python -m pip install -e ./sdeconv`.
45 |
46 | ## Use SDeconv with napari
47 |
48 | The SDeconv library is embedded in a napari plugin that allows using ``SDeconv`` with a graphical interface.
49 | Please refer to the [`SDeconv` napari plugin](https://www.napari-hub.org/plugins/napari-sdeconv) documentation to install and use it.
50 |
51 | # SDeconv documentation
52 |
53 | The full documentation with tutorial and docstring is available [here](https://sylvainprigent.github.io/sdeconv/)
54 |
--------------------------------------------------------------------------------
/build.rst:
--------------------------------------------------------------------------------
1 | Build locally
2 | =============
3 |
4 | do::
5 | python3 setup.py build_ext --inplace
6 |
7 |
8 | Building from source
9 | ====================
10 |
11 | Building from source is required to work on a contribution (bug fix, new
12 | feature, code or documentation improvement).
13 |
14 | .. _git_repo:
15 |
16 | #. Use `Git <https://git-scm.com/>`_ to check out the latest source from the
17 |    `sdeconv repository <https://github.com/sylvainprigent/sdeconv>`_ on
18 |    GitHub.::
19 |
20 | git clone git://github.com/sylvainprigent/sdeconv.git # add --depth 1 if your connection is slow
21 |    cd sdeconv
22 |
23 | If you plan on submitting a pull-request, you should clone from your fork
24 | instead.
25 |
26 | #. Install a compiler with OpenMP_ support for your platform.
27 |
28 | #. Optional (but recommended): create and activate a dedicated virtualenv_
29 | or `conda environment`_.
30 |
31 | #. Install Cython_ and build the project with pip in :ref:`editable_mode`::
32 |
33 | pip install cython
34 | pip install --verbose --no-build-isolation --editable .
35 |
36 | #. Check that the installed sdeconv has a version number ending with
37 | `.dev0`::
38 |
39 |    python -c "import sdeconv; print(sdeconv.__version__)"
40 |
41 |
42 | .. note::
43 |
44 | You will have to run the ``pip install --no-build-isolation --editable .``
45 | command every time the source code of a Cython file is updated
46 | (ending in `.pyx` or `.pxd`). Use the ``--no-build-isolation`` flag to
47 | avoid compiling the whole project each time, only the files you have
48 | modified.
49 |
50 | Create a wheel
51 | ==============
52 |
53 | do::
54 |
55 | python3 setup.py bdist_wheel
56 |
57 | Testing
58 | =======
59 |
60 | run tests by running::
61 |
62 | pytest sdeconv
63 |
64 | or
65 |
66 | python3 -m pytest sdeconv
67 |
68 |
69 | Profiling
70 | =========
71 |
72 | python -m cProfile -o out_profile script_name.py
73 | cprofilev -f out_profile
74 |
75 | Build documentation
76 | ===================
77 |
78 | without example gallery::
79 |
80 | cd docs
81 | make
82 |
83 | with the example gallery (may take a while)::
84 |
85 | cd docs
86 | make html
87 |
88 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | build/
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/docs/source/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../../'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'SDeconv'
21 | copyright = '2022-2024, SDeconv'
22 | author = 'Sylvain Prigent'
23 |
24 | # The full version, including alpha/beta/rc tags
25 | release = '1.0.4'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = ['sphinx.ext.autodoc',
34 | 'sphinx.ext.autosummary']
35 |
36 | # Add any paths that contain templates here, relative to this directory.
37 | templates_path = ['_templates']
38 |
39 | # List of patterns, relative to source directory, that match files and
40 | # directories to ignore when looking for source files.
41 | # This pattern also affects html_static_path and html_extra_path.
42 | exclude_patterns = []
43 |
44 |
45 | # -- Options for HTML output -------------------------------------------------
46 |
47 | # The theme to use for HTML and HTML Help pages. See the documentation for
48 | # a list of builtin themes.
49 | #
50 | html_theme = 'sphinx_book_theme'
51 |
52 | # Add any paths that contain custom static files (such as style sheets) here,
53 | # relative to this directory. They are copied after the builtin static files,
54 | # so a file named "default.css" will overwrite the builtin "default.css".
55 | html_static_path = ['_static']
56 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. SDeconv documentation master file.
2 |
3 | SDeconv's documentation
4 | =======================
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 | :caption: Contents:
9 |
10 | intro
11 | install
12 | quickstart
13 | modules
14 | references
15 |
16 |
17 | Indices and tables
18 | ==================
19 |
20 | * :ref:`genindex`
21 | * :ref:`modindex`
22 | * :ref:`search`
--------------------------------------------------------------------------------
/docs/source/install.rst:
--------------------------------------------------------------------------------
1 | Install
2 | =======
3 |
4 | This section contains the instructions to install ``SDeconv``
5 |
6 | Using PyPI
7 | ----------
8 |
9 | Releases are available in the PyPI repository. We recommend using a virtual environment
10 |
11 | .. code-block:: shell
12 |
13 | python -m venv .sdeconv-env
14 | source .sdeconv-env/bin/activate
15 | pip install sdeconv
16 |
17 |
18 | From source
19 | -----------
20 |
21 | If you plan to develop ``SDeconv`` we recommend installing locally
22 |
23 | .. code-block:: shell
24 |
25 | python -m venv .sdeconv-env
26 | source .sdeconv-env/bin/activate
27 | git clone https://github.com/sylvainprigent/sdeconv.git
28 | cd sdeconv
29 | pip install -e .
30 |
--------------------------------------------------------------------------------
/docs/source/interfaces.rst:
--------------------------------------------------------------------------------
1 | Interfaces
2 | ==========
3 |
4 |
5 | .. autoclass:: sdeconv.psfs.interface.SPSFGenerator
6 | :members:
7 |
8 | .. autoclass:: sdeconv.deconv.interface.SDeconvFilter
9 | :members:
10 |
11 | .. autoclass:: sdeconv.deconv.interface_nn.NNModule
12 | :members:
13 |
--------------------------------------------------------------------------------
/docs/source/intro.rst:
--------------------------------------------------------------------------------
1 | Introduction
2 | ============
3 |
4 | SDeconv is a library for 2D and 3D deconvolution of scientific images
5 |
6 | Context
7 | -------
8 | SDeconv has been developed in the `Serpico `_ research team. The goal is to provide a
9 | modular library to perform deconvolution of microscopy images. A classical application of our team is to apply deconvolution in 3D+t
10 | images depicting endosomes with Lattice LightSheet microscopy, and then ease the analysis.
11 |
12 | Library components
13 | ------------------
14 | SDeconv is written in python3 with pytorch. SDeconv library provides a module for each component of deconvolution algorithms:
15 |
16 | * **psfs**: this module defines the interface to implement Point Spread Function generators.
17 | * **deconv**: this module defines the interfaces to implement a deconvolution algorithm with or without neural networks
18 |
19 | Furthermore, the library provides sample data, a command line interface, and an application
20 | programming interface to ease the integration of the sdeconv deconvolution algorithms into software.
21 |
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | =======
3 |
4 | Point Spread Functions
5 | ----------------------
6 |
7 | .. currentmodule:: sdeconv.psfs
8 |
9 | .. autosummary::
10 | :toctree: generated
11 | :nosignatures:
12 |
13 | SPSFGaussian
14 | SPSFGibsonLanni
15 | SPSFLorentz
16 |
17 |
18 | Deconvolution algorithms
19 | ------------------------
20 |
21 | .. currentmodule:: sdeconv.deconv
22 |
23 | .. autosummary::
24 | :toctree: generated
25 | :nosignatures:
26 |
27 | SWiener
28 | SRichardsonLucy
29 | Spitfire
30 | Noise2VoidDeconv
31 | SelfSupervisedNNDeconv
32 | NNDeconv
33 |
34 |
35 | Interfaces
36 | ----------
37 |
38 | Available interfaces to create a new PSF generator or a new deconvolution algorithm are:
39 |
40 | .. list-table:: Interfaces
41 | :widths: 25 75
42 |
43 | * - :class:`SPSFGenerator `
44 | - Interface for creating a new PSF generator
45 | * - :class:`SDeconvFilter `
46 | - Interface for creating a deconvolution filter that does not need neural network
47 | * - :class:`NNModule `
48 | - Interface for creating a deconvolution filter using a neural network
49 |
--------------------------------------------------------------------------------
/docs/source/quickstart.rst:
--------------------------------------------------------------------------------
1 | Quick start
2 | ===========
3 |
4 | This is a quick start example of how to use the **SDeconv** library. This section assumes you know the principles
5 | of deconvolution. If it is not the case, please refer to the
6 | :doc:`References `.
7 |
8 | Input images
9 | ------------
10 | Input images are 2D or 3D gray scaled images. 2D images are represented as torch tensors with the following
11 | columns ordering ``[Y, X]`` and 3D images are represented with torch tensors with ``[Z, Y, X]`` columns ordering
12 |
13 | Sample images can be loaded using the data module:
14 |
15 | .. code-block:: python3
16 |
17 | from sdeconv import data
18 |
19 | image = data.celegans()
20 |
21 |
22 | Deconvolution using the API
23 | ---------------------------
24 |
25 | Below is an example of how to write a deconvolution script with the API. In this example, we run the Wiener deconvolution algorithm:
26 |
27 | .. code-block:: python3
28 |
29 | from sdeconv import data
30 | from sdeconv.api import SDeconvAPI
31 | import matplotlib.pyplot as plt
32 |
33 | # Instantiate the API
34 | api = SDeconvAPI()
35 |
36 | # Load image
37 | image = data.celegans()
38 |
39 | # Generate a PSF
40 | psf = api.generate_psf('SPSFGaussian', sigma=[1.5, 1.5], shape=[13, 13])
41 |
42 | # Deconvolution with API
43 | image_decon = api.deconvolve(image, "SWiener", plane_by_plane=False, psf=psf, beta=0.005, pad=13)
44 |
45 | # Plot the results
46 | plt.figure()
47 | plt.subplot(131)
48 | plt.title('Original')
49 | plt.imshow(image.detach().cpu().numpy(), cmap='gray')
50 | plt.axis('off')
51 |
52 | plt.subplot(132)
53 | plt.title('PSF')
54 | plt.imshow(psf.detach().cpu().numpy(), cmap='gray')
55 | plt.axis('off')
56 |
57 | plt.subplot(133)
58 | plt.title('Wiener deconvolution')
59 | plt.imshow(image_decon.detach().cpu().numpy(), cmap='gray')
60 | plt.axis('off')
61 |
62 | plt.show()
63 |
64 |
65 | The advantage of using the API is that it implements several strategies to deconvolve a 3D or 3D+t image either plane
66 | by plane then frame by frame, or frame by frame in 3D.
67 |
68 |
69 | Deconvolution using the library classes
70 | ---------------------------------------
71 |
72 | When we need only one method, the easiest way may be to call directly the class that implements the deconvolution
73 | algorithm:
74 |
75 | .. code-block:: python3
76 |
77 | import matplotlib.pyplot as plt
78 | from sdeconv.data import celegans
79 | from sdeconv.psfs import SPSFGaussian
80 | from sdeconv.deconv import SWiener
81 |
82 | # Load a 2D sample
83 | image = celegans()
84 |
85 | # Generate a 2D PSF
86 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
87 | psf = psf_generator()
88 |
89 | # Apply Wiener filter
90 | wiener = SWiener(psf, beta=0.005, pad=13)
91 | out_image = wiener(image)
92 |
93 | # Display results
94 | plt.figure()
95 | plt.title('PSF')
96 | plt.imshow(psf.detach().numpy(), cmap='gray')
97 |
98 | plt.figure()
99 | plt.title('C. elegans original')
100 | plt.imshow(image.detach().numpy(), cmap='gray')
101 |
102 | plt.figure()
103 | plt.title('C. elegans Wiener')
104 | plt.imshow(out_image.detach().numpy(), cmap='gray')
105 |
106 | plt.show()
107 |
108 | Please refer to :doc:`Modules ` for more details on the interfaces and the list of available PSFs and deconvolution methods.
109 |
--------------------------------------------------------------------------------
/docs/source/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 |
4 | .. rubric:: References
5 |
6 | .. [Siba05] Sibarita, J.-B. Deconvolution Microscopy 201-243 (Springer Berlin Heidelberg, Berlin Heidelberg, 2005). https://doi.org/10.1007/b102215
7 | .. [Rich72] Richardson, W. H. Bayesian-based iterative method of image restoration. J. Opt. Soc. Am. 62, 55-59 (1972).
8 | .. [Lucy74] Lucy, L. B. An iterative technique for the rectification of observed distributions. Astron. J. 79, 745-754 (1974).
9 | .. [VdSt95] Van Der Voort, H. T. M. & Strasters, K. C. Restoration of confocal images for quantitative image analysis. J. Microsc. 178, 165-181. https://doi.org/10.1111/j.1365-2818.1995.tb03593.x (1995).
10 | .. [VkVB96] van Kempen, G. M. P., van der Voort, H. T. M., Bauman, J. G. J. & Strasters, K. C. Comparing maximum likelihood estimation and constrained Tikhonov-Miller restoration. IEEE Engineering in Medicine and Biology Magazine 15, 76-83 (1996).
11 | .. [VkVV97] van Kempen, G. M. P., van Vliet, L. J., Verveer, P. J. & van der Voort, H. T. M. A quantitative comparison of image restoration methods for confocal microscopy. J. Microscopy 185, 354-365. https://doi.org/10.1046/j.1365-2818.1997.d01-629.x (1997).
12 | .. [PNLV22] Prigent, S., Nguyen, HN., Leconte, L. et al. SPITFIR(e): a supermaneuverable algorithm for fast denoising and deconvolution of 3D fluorescence microscopy images and videos. Sci Rep 13, 1489 (2023). https://doi.org/10.1038/s41598-022-26178-y
13 | .. [KrBF19] Krull, A., Buchholz, T-O, Jug, F. Noise2Void - Learning Denoising From Single Noisy Images. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). June, 2019
14 |
--------------------------------------------------------------------------------
/examples/README.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
--------------------------------------------------------------------------------
/examples/api.py:
--------------------------------------------------------------------------------
1 | """
2 | SDeconv API
3 |
4 | This example shows how to use the SDeconv application programming interface
5 | """
6 |
7 | from sdeconv import data
8 | from sdeconv.api import SDeconvAPI
9 | import matplotlib.pyplot as plt
10 |
11 | # instantiate the API
12 | api = SDeconvAPI()
13 |
14 | # load image
15 | image = data.celegans()
16 |
17 | # Generate a PSF
18 | psf = api.generate_psf('SPSFGaussian', sigma=[1.5, 1.5], shape=[13, 13])
19 |
20 | # deconvolution with API
21 | image_decon = api.deconvolve(image, "SWiener", plane_by_plane=False, psf=psf, beta=0.005, pad=13)
22 |
23 | # plot the result
24 | plt.figure()
25 | plt.subplot(131)
26 | plt.title('Original')
27 | plt.imshow(image.detach().cpu().numpy(), cmap='gray')
28 | plt.axis('off')
29 |
30 | plt.subplot(132)
31 | plt.title('PSF')
32 | plt.imshow(psf.detach().cpu().numpy(), cmap='gray')
33 | plt.axis('off')
34 |
35 | plt.subplot(133)
36 | plt.title('Wiener deconvolution')
37 | plt.imshow(image_decon.detach().cpu().numpy(), cmap='gray')
38 | plt.axis('off')
39 |
40 | plt.show()
41 |
--------------------------------------------------------------------------------
/examples/gaussian_psf_2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Gaussian PSF
3 |
4 | This example shows how to generate a Gaussian PSF
5 | """
6 |
7 | import matplotlib.pyplot as plt
8 | from sdeconv.psfs import SPSFGaussian
9 |
10 |
11 | psf_generator = SPSFGaussian(sigma=(1.5, 1.5), shape=(13, 13))
12 | psf = psf_generator()
13 |
14 | plt.figure()
15 | plt.title('Gaussian PSF')
16 | plt.imshow(psf.detach().numpy(), cmap='gray')
17 | plt.show()
--------------------------------------------------------------------------------
/examples/gaussian_psf_3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Gaussian PSF
3 |
4 | This example shows how to generate a Gaussian PSF
5 | """
6 |
7 | from sdeconv.psfs import SPSFGaussian
8 | import napari
9 | import time
10 |
11 | psf_generator = SPSFGaussian(sigma=(0.5, 1.5, 1.5), shape=(25, 128, 128))
12 | t = time.time()
13 | psf = psf_generator()
14 | elapsed = time.time() - t
15 | print('elapsed = ', elapsed)
16 |
17 | viewer = napari.view_image(psf.detach().numpy(), scale=[200, 100, 100])
18 | napari.run()
19 |
--------------------------------------------------------------------------------
/examples/gibsonlanni.py:
--------------------------------------------------------------------------------
1 | """
2 | Gibson Lanni PSF
3 |
4 | This example shows how to generate a Gibson Lanni PSF
5 | """
6 |
7 | import matplotlib.pyplot as plt
8 | from sdeconv.psfs import SPSFGibsonLanni
9 | import napari
10 | import time
11 |
12 | psf_generator = SPSFGibsonLanni((11, 128, 128), use_square=True)
13 | t = time.time()
14 |
15 | psf = psf_generator()
16 | elapsed = time.time() - t
17 | print('elapsed = ', elapsed)
18 |
19 | viewer = napari.view_image(psf.detach().numpy(), scale=[200, 100, 100])
20 | napari.run()
21 |
--------------------------------------------------------------------------------
/examples/richardson_lucy.py:
--------------------------------------------------------------------------------
1 | """
2 | Richardson-Lucy deconvolution
3 |
4 | This example shows how to use the Richardson-Lucy deconvolution on a 2D image
5 | """
6 |
7 | import matplotlib.pyplot as plt
8 | from sdeconv.data import celegans
9 | from sdeconv.psfs import SPSFGaussian
10 | from sdeconv.deconv import SRichardsonLucy
11 |
12 |
13 | # load a 2D sample
14 | image = celegans()
15 |
16 | # Generate a 2D PSF
17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
18 | psf = psf_generator()
19 |
20 | # apply the Richardson-Lucy filter
21 | filter_ = SRichardsonLucy(psf, niter=30, pad=13)
22 | out_image = filter_(image)
23 |
24 | # display results
25 | plt.figure()
26 | plt.title('PSF')
27 | plt.imshow(psf.detach().numpy(), cmap='gray')
28 |
29 | plt.figure()
30 | plt.title('C. elegans original')
31 | plt.imshow(image.detach().numpy(), cmap='gray')
32 |
33 | plt.figure()
34 | plt.title('C. elegans Richardson-Lucy')
35 | plt.imshow(out_image.detach().numpy(), cmap='gray')
36 |
37 | plt.show()
38 |
39 |
--------------------------------------------------------------------------------
/examples/spitfire.py:
--------------------------------------------------------------------------------
1 | """
2 | Spitfire deconvolution
3 |
4 | This example shows how to use the Spitfire deconvolution on a 2D image
5 | """
6 |
7 | import matplotlib.pyplot as plt
8 | from sdeconv.data import celegans
9 | from sdeconv.psfs import SPSFGaussian
10 | from sdeconv.deconv import Spitfire
11 |
12 |
13 | # load a 2D sample
14 | image = celegans()
15 |
16 | # Generate a 2D PSF
17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
18 | psf = psf_generator()
19 |
20 | # apply Spitfire filter
21 | filter_ = Spitfire(psf, weight=0.6, reg=0.995, gradient_step=0.01, precision=1e-7, pad=13)
22 | out_image = filter_(image)
23 |
24 | # display results
25 | plt.figure()
26 | plt.title('PSF')
27 | plt.imshow(psf.detach().numpy(), cmap='gray')
28 |
29 | plt.figure()
30 | plt.title('C. elegans original')
31 | plt.imshow(image.detach().numpy(), cmap='gray')
32 |
33 | plt.figure()
34 | plt.title('C. elegans Spitfire')
35 | plt.imshow(out_image.detach().numpy(), cmap='gray')
36 |
37 | plt.show()
38 |
39 |
--------------------------------------------------------------------------------
/examples/wiener.py:
--------------------------------------------------------------------------------
1 | """
2 | Wiener deconvolution
3 |
4 | This example shows how to use the Wiener deconvolution on a 2D image
5 | """
6 |
7 | import matplotlib.pyplot as plt
8 | from sdeconv.data import celegans
9 | from sdeconv.psfs import SPSFGaussian
10 | from sdeconv.deconv import SWiener
11 |
12 |
13 | # load a 2D sample
14 | image = celegans()
15 |
16 | # Generate a 2D PSF
17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
18 | psf = psf_generator()
19 |
20 | # apply Wiener filter
21 | filter_ = SWiener(psf, beta=0.005, pad=13)
22 | out_image = filter_(image)
23 |
24 | # display results
25 | plt.figure()
26 | plt.title('PSF')
27 | plt.imshow(psf.detach().numpy(), cmap='gray')
28 |
29 | plt.figure()
30 | plt.title('C. elegans original')
31 | plt.imshow(image.detach().numpy(), cmap='gray')
32 |
33 | plt.figure()
34 | plt.title('C. elegans Wiener')
35 | plt.imshow(out_image.detach().numpy(), cmap='gray')
36 |
37 | plt.show()
38 |
39 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 |
7 | [tool.black]
8 | line-length = 79
9 |
10 | [tool.isort]
11 | profile = "black"
12 | line_length = 79
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy>=1.14.1
2 | numpy==1.26.4
3 | torch==2.2.1
4 | torchvision>=0.17.1
5 | scikit-image>=0.24.0
6 | pylint>=3.2.6
7 | pytest>=8.3.2
8 | sphinx~=8.0.2
9 | sphinx_book_theme~=1.1.3
--------------------------------------------------------------------------------
/sdeconv/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Bio-image analysis module for Python
3 | ====================================
4 | sdeconv is a Python module for microscopy image deconvolution based on the scientific
5 | ecosystem in python (numpy, torch, torchvision, scipy, scikit-image).
6 | See http://github.com/sylvainprigent/sdeconv for complete documentation.
7 | """
8 |
9 | __version__ = '1.0.4'
10 |
11 | __all__ = []
12 |
--------------------------------------------------------------------------------
/sdeconv/api/__init__.py:
--------------------------------------------------------------------------------
1 | """Module defining the SDeconv Application Programming Interface (API)"""
2 | from .factory import SDeconvModuleFactory, SDeconvModuleBuilder
3 | from .api import SDeconvAPI
4 |
5 | __all__ = ['SDeconvAPI']
6 |
--------------------------------------------------------------------------------
/sdeconv/api/api.py:
--------------------------------------------------------------------------------
1 | """Application programing interface for SDeconv"""
2 |
3 | import os
4 | import importlib
5 | import torch
6 |
7 | from ..psfs import SPSFGenerator
8 | from ..deconv import SDeconvFilter
9 | from .factory import SDeconvModuleFactory, SDeconvFactoryError
10 |
11 |
12 | class SDeconvAPI:
13 | """Main API to call SDeconv methods.
14 |
15 | The API implements a factory that instantiate the deconvolution
16 | """
17 | def __init__(self):
18 | self.psfs = SDeconvModuleFactory()
19 | self.filters = SDeconvModuleFactory()
20 | for name in self._find_modules('deconv'):
21 | mod = importlib.import_module(name)
22 | self.filters.register(mod.metadata['name'], mod.metadata)
23 | for name in self._find_modules('psfs'):
24 | mod = importlib.import_module(name)
25 | self.psfs.register(mod.metadata['name'], mod.metadata)
26 |
27 | @staticmethod
28 | def _find_modules(directory: str) -> list[str]:
29 | """Search sub modules in a directory
30 |
31 | :param directory: Directory to search
32 | :return: The found module names
33 | """
34 | path = os.path.abspath(os.path.dirname(__file__))
35 | path = os.path.dirname(path)
36 | modules = []
37 | for parent in [directory]:
38 | path_ = os.path.join(path, parent)
39 | for module_path in os.listdir(path_):
40 | if module_path.endswith(
41 | ".py") and 'setup' not in module_path and \
42 | 'interface' not in module_path and \
43 | '__init__' not in module_path and not module_path.startswith(
44 | "_"):
45 | modules.append(f"sdeconv.{parent}.{module_path.split('.')[0]}")
46 | return modules
47 |
48 | def filter(self, name: str, **kwargs) -> SDeconvFilter | None:
49 | """Instantiate a deconvolution filter
50 |
51 | :param name: Unique name of the filter to instantiate
52 | :param kwargs: arguments of the filter
53 | :return: An instance of the filter
54 | """
55 | if name == 'None':
56 | return None
57 | return self.filters.get(name, **kwargs)
58 |
59 | def psf(self, method_name, **kwargs) -> SPSFGenerator:
60 | """Instantiate a psf generator
61 |
62 | :param method_name: name of the PSF generator
63 | :param kwargs: parameters of the PSF generator
64 | :return: An instance of the generator
65 | """
66 | if method_name == 'None':
67 | return None
68 | filter_ = self.psfs.get(method_name, **kwargs)
69 | if filter_.type == 'SPSFGenerator':
70 | return filter_
71 | raise SDeconvFactoryError(f'The method {method_name} is not a PSF generator')
72 |
73 | def generate_psf(self, method_name, **kwargs) -> torch.Tensor:
74 | """Generates a Point SPread Function
75 |
76 | :param method_name: Name of the PSF Generator method
77 | :param kwargs: Parameters of the PSF generator
78 | :return: The generated PSF
79 | """
80 | # print('generate psf args=', **kwargs)
81 | generator = self.psf(method_name, **kwargs)
82 | return generator()
83 |
84 | def deconvolve(self,
85 | image: torch.Tensor,
86 | method_name: str,
87 | plane_by_plane: bool,
88 | **kwargs
89 | ) -> torch.Tensor:
90 | """Run the deconvolution on an image
91 |
92 | :param image: Image to deconvolve. Can be 2D to 5D
93 | :param method_name: Name of the deconvolution method to use
94 | :param plane_by_plane: True to process the image plane by plane
95 | when dimension is more than 2
96 | :param kwargs: Parameters of the deconvolution method
97 | :return: The deblurred image
98 | """
99 | filter_ = self.filter(method_name, **kwargs)
100 | if filter_.type == 'SDeconvFilter':
101 | return self._deconv_dims(image, filter_, plane_by_plane=plane_by_plane)
102 | raise SDeconvFactoryError(f'The method {method_name} is not a deconvolution filter')
103 |
104 | @staticmethod
105 | def _deconv_3d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor:
106 | """Call the 3D deconvolution plane by plane
107 |
108 | :param image: 3D image tensor
109 | :param filter_: deconvolution class
110 | :return: The deblurred image
111 | """
112 | out_image = torch.zeros(image.shape)
113 | for i in range(image.shape[0]):
114 | out_image[i, ...] = filter_(image[i, ...])
115 | return out_image
116 |
117 | @staticmethod
118 | def _deconv_4d(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor:
119 | """Call the 3D+t deconvolution
120 |
121 | :param image: 3D+t image tensor
122 | :param filter_: deconvolution class
123 | :return: The deblurred stack
124 | """
125 | out_image = torch.zeros(image.shape)
126 | for i in range(image.shape[0]):
127 | out_image[i, ...] = filter_(image[i, ...])
128 | return out_image
129 |
130 | @staticmethod
131 | def _deconv_4d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor:
132 | """Call the 3D+t deconvolution plane by plane
133 |
134 | :param image: 3D+t image tensor
135 | :param filter_: deconvolution class
136 | :return: The deblurred stack
137 | """
138 | out_image = torch.zeros(image.shape)
139 | for i in range(image.shape[0]):
140 | for j in range(image.shape[1]):
141 | out_image[i, j, ...] = filter_(image[i, j, ...])
142 | return out_image
143 |
144 | @staticmethod
145 | def _deconv_5d(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor:
146 | """Call the 3D+t multi-channel deconvolution
147 |
148 | :param image: 3D+t image tensor
149 | :param filter_: deconvolution class
150 | :return: The deblurred hyper-stack
151 | """
152 | out_image = torch.zeros(image.shape)
153 | for i in range(image.shape[0]):
154 | for j in range(image.shape[1]):
155 | out_image[i, j, ...] = filter_(image[i, j, ...])
156 | return out_image
157 |
158 | @staticmethod
159 | def _deconv_5d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor:
160 | """Call the 3D+t multi-channel deconvolution plane by plane
161 |
162 | :param image: 3D+t image tensor
163 | :param filter_: deconvolution class
164 | :return: The deblurred hyper-stack
165 | """
166 | out_image = torch.zeros(image.shape)
167 | for batch in range(image.shape[0]):
168 | for channel in range(image.shape[1]):
169 | for plane in range(image.shape[2]):
170 | out_image[batch, channel, plane, ...] = \
171 | filter_(image[batch, channel, plane, ...])
172 | return out_image
173 |
174 | @staticmethod
175 | def _deconv_dims(image: torch.Tensor,
176 | filter_: SDeconvFilter,
177 | plane_by_plane: bool = False):
178 | """Call the deconvolution method depending on the image dimension
179 |
180 | :param image: 3D+t image tensor
181 | :param filter_: deconvolution class
182 | :param plane_by_plane: True to deblur the third dimension as independent planes
183 | :return: The deblurred image, stack or hyper-stack
184 | """
185 | out_image = None
186 | if image.ndim == 2:
187 | out_image = filter_(image)
188 | elif image.ndim == 3 and plane_by_plane:
189 | out_image = SDeconvAPI._deconv_3d_by_plane(image, filter_)
190 | elif image.ndim == 3 and not plane_by_plane:
191 | out_image = filter_(image)
192 | elif image.ndim == 4 and not plane_by_plane:
193 | out_image = SDeconvAPI._deconv_4d(image, filter_)
194 | elif image.ndim == 4 and plane_by_plane:
195 | out_image = SDeconvAPI._deconv_4d_by_plane(image, filter_)
196 | elif image.ndim == 5 and not plane_by_plane:
197 | out_image = SDeconvAPI._deconv_5d(image, filter_)
198 | elif image.ndim == 5 and plane_by_plane:
199 | out_image = SDeconvAPI._deconv_5d_by_plane(image, filter_)
200 | else:
201 | raise SDeconvFactoryError('SDeconv can process only images up to 5 dims')
202 | return out_image
203 |
--------------------------------------------------------------------------------
/sdeconv/api/factory.py:
--------------------------------------------------------------------------------
1 | """Implements factory for PSF and deconvolution modules"""
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from ..psfs import SPSFGenerator
7 | from ..deconv import SDeconvFilter
8 |
9 |
10 | class SDeconvFactoryError(Exception):
11 | """Raised when an error happen when a module is built in the factory"""
12 |
13 |
class SDeconvModuleFactory:
    """Factory for SDeconv modules

    Modules (PSF generators and deconvolution filters) are registered with a
    metadata dictionary and instantiated on demand through :meth:`get`.
    """
    def __init__(self):
        # Maps a module name to its metadata ({'class': ..., 'inputs': {...}})
        self._data = {}

    def register(self, key: str, metadata: dict[str, any]):
        """Register a new builder to the factory

        :param key: Name of the module to register
        :param metadata: Dictionary containing the filter metadata
                         (expects at least the 'class' and 'inputs' entries)
        """
        self._data[key] = metadata

    def get_parameters(self, key: str) -> dict[str, any]:
        """Parameters getter method

        :param key: Name of the module builder
        :return: The module parameters metadata
        :raises KeyError: if the module is not registered
        """
        return self._data[key]['inputs']

    def get_keys(self) -> list[str]:
        """Get the names of all the registered modules

        :return: The list of all the registered module names
        """
        # Materialize as a list so the return value matches the annotation
        # (the previous code leaked a live dict_keys view of the registry)
        return list(self._data.keys())

    def get(self, key: str, **kwargs) -> 'SPSFGenerator | SDeconvFilter':
        """Get an instance of the SDeconv module

        :param key: Name of the module to load
        :param kwargs: Dictionary of args for models parameters (ex: number of channels)
        :return: The instance of the module
        :raises ValueError: if the module is not registered
        """
        metadata = self._data.get(key)
        if not metadata:
            raise ValueError(key)
        builder = SDeconvModuleBuilder()
        return builder.get_instance(metadata, kwargs)
54 |
55 |
56 | class SDeconvModuleBuilder:
57 | """Interface for a SDeconv module builder
58 |
59 | The builder is used by the factory to instantiate a module
60 |
61 | """
62 | def __init__(self):
63 | self._instance = None
64 |
65 | def get_instance(self, metadata: dict[str, any], args: dict) -> SPSFGenerator | SDeconvFilter:
66 | """Get the instance of the module
67 |
68 | :param metadata: Metadata of the module
69 | :param args: Argument to pass for the module instantiation
70 | :return: Instance of the module
71 | """
72 | # check the args
73 | instance_args = {}
74 | for key, value in metadata['inputs'].items():
75 | val = self._get_arg(value, key, args)
76 | instance_args[key] = val
77 | return metadata['class'](**instance_args)
78 |
79 | def _get_arg(self, param_metadata: dict[str, any], key: str, args: dict[str, any]) -> any:
80 | """Retrieve the value of a parameter with a type check
81 |
82 | :param param_metadata: Metadata of the parameter,
83 | :param key: Name of the parameter,
84 | :param args: Value of the parameters
85 | :return: The value of the parameter if check is successful
86 | """
87 | type_ = param_metadata['type']
88 | range_ = None
89 | if 'range' in param_metadata:
90 | range_ = param_metadata['range']
91 | arg_value = None
92 | if type_ == 'float':
93 | arg_value = self.get_arg_float(args, key, param_metadata['default'],
94 | range_)
95 | elif type_ == 'int':
96 | arg_value = self.get_arg_int(args, key, param_metadata['default'],
97 | range_)
98 | elif type_ == 'bool':
99 | arg_value = self.get_arg_bool(args, key, param_metadata['default'],
100 | range_)
101 | elif type_ == 'str':
102 | arg_value = self.get_arg_str(args, key, param_metadata['default'])
103 | elif type_ is torch.Tensor:
104 | arg_value = self.get_arg_array(args, key, param_metadata['default'])
105 | elif type_ == 'select':
106 | arg_value = self.get_arg_select(args, key, param_metadata['values'])
107 | elif 'zyx' in type_:
108 | arg_value = self.get_arg_list(args, key, param_metadata['default'])
109 | return arg_value
110 |
111 | @staticmethod
112 | def _error_message(key: str, value_type: str, value_range: tuple | None):
113 | """Throw an exception if an input parameter is not correct
114 |
115 | :param key: Input parameter key
116 | :param value_type: String naming the input type (int, float...)
117 | :param value_range: Min and max values of the parameter
118 | """
119 | range_message = ''
120 | if value_range and len(value_range) == 2:
121 | range_message = f' in range [{str(value_range[0]), str(value_range[1])}]'
122 |
123 | message = f'Parameter {key} must be of type `{value_type}` {range_message}'
124 | return message
125 |
126 | def get_arg_int(self,
127 | args: dict[str, any],
128 | key: str,
129 | default_value: int,
130 | value_range: tuple = None
131 | ) -> int:
132 | """Get the value of a parameter from the args list
133 |
134 | The default value of the parameter is returned if the
135 | key is not in args
136 |
137 | :param args: Dictionary of the input args
138 | :param key: Name of the parameters
139 | :param default_value: Default value of the parameter
140 | :param value_range: Min and max value of the parameter
141 | :return: The arg value
142 | """
143 | value = default_value
144 | if isinstance(args, dict) and key in args:
145 | # cast
146 | try:
147 | value = int(args[key])
148 | except ValueError as exc:
149 | raise SDeconvFactoryError(self._error_message(key, 'int', value_range)) from exc
150 | # test range
151 | if value_range and len(value_range) == 2:
152 | if value > value_range[1] or value < value_range[0]:
153 | raise SDeconvFactoryError(self._error_message(key, 'int', value_range))
154 | return value
155 |
156 | def get_arg_float(self,
157 | args: dict[str, any],
158 | key: str,
159 | default_value: float,
160 | value_range: tuple = None
161 | ) -> str:
162 | """Get the value of a parameter from the args list
163 |
164 | The default value of the parameter is returned if the
165 | key is not in args
166 |
167 | :param args: Dictionary of the input args
168 | :param key: Name of the parameters
169 | :param default_value: Default value of the parameter
170 | :param value_range: Min and max value of the parameter
171 | :return: The arg value
172 | """
173 | value = default_value
174 | if isinstance(args, dict) and key in args:
175 | # cast
176 | try:
177 | value = float(args[key])
178 | except ValueError as exc:
179 | raise SDeconvFactoryError(self._error_message(key, 'float', value_range)) from exc
180 | # test range
181 | if value_range and len(value_range) == 2:
182 | if value > value_range[1] or value < value_range[0]:
183 | raise SDeconvFactoryError(self._error_message(key, 'float', value_range))
184 | return value
185 |
186 | def get_arg_str(self,
187 | args: dict[str, any],
188 | key: str,
189 | default_value: str,
190 | value_range: tuple = None
191 | ) -> str:
192 | """Get the value of a parameter from the args list
193 |
194 | The default value of the parameter is returned if the
195 | key is not in args
196 |
197 | :param args: Dictionary of the input args
198 | :param key: Name of the parameters
199 | :param default_value: Default value of the parameter
200 | :param value_range: Min and max value of the parameter
201 | :return: The arg value
202 | """
203 | value = default_value
204 | if isinstance(args, dict) and key in args:
205 | # cast
206 | try:
207 | value = str(args[key])
208 | except ValueError as exc:
209 | raise SDeconvFactoryError(self._error_message(key, 'str', value_range)) from exc
210 | # test range
211 | if value_range and len(value_range) == 2:
212 | if value > value_range[1] or value < value_range[0]:
213 | raise SDeconvFactoryError(self._error_message(key, 'str', value_range))
214 | return value
215 |
216 | @staticmethod
217 | def _str2bool(value: str) -> bool:
218 | """Convert a string to a boolean
219 |
220 | :param value: String to convert
221 | :return: The boolean conversion
222 | """
223 | return value.lower() in ("yes", "true", "t", "1")
224 |
225 | def get_arg_bool(self,
226 | args: dict[str, any],
227 | key: str,
228 | default_value: bool,
229 | value_range: tuple = None
230 | ) -> bool:
231 | """Get the value of a parameter from the args list
232 |
233 | The default value of the parameter is returned if the
234 | key is not in args
235 |
236 | :param args: Dictionary of the input args
237 | :param key: Name of the parameters
238 | :param default_value: Default value of the parameter
239 | :param value_range: Min and max value of the parameter
240 | :return: The arg value
241 | """
242 | value = default_value
243 | # cast
244 | if isinstance(args, dict) and key in args:
245 | if isinstance(args[key], str):
246 | value = SDeconvModuleBuilder._str2bool(args[key])
247 | elif isinstance(args[key], bool):
248 | value = args[key]
249 | else:
250 | raise SDeconvFactoryError(self._error_message(key, 'bool', value_range))
251 | # test range
252 | if value_range and len(value_range) == 2:
253 | if value > value_range[1] or value < value_range[0]:
254 | raise SDeconvFactoryError(self._error_message(key, 'bool', value_range))
255 | return value
256 |
257 | def get_arg_array(self,
258 | args: dict[str, any],
259 | key: str,
260 | default_value: torch.Tensor
261 | ) -> torch.Tensor:
262 | """Get the value of a parameter from the args list
263 |
264 | The default value of the parameter is returned if the
265 | key is not in args
266 |
267 | :param args: Dictionary of the input args
268 | :param key: Name of the parameters
269 | :param default_value: Default value of the parameter
270 | :return: The arg value
271 | """
272 | value = default_value
273 | if isinstance(args, dict) and key in args:
274 | if isinstance(args[key], torch.Tensor):
275 | value = args[key]
276 | elif isinstance(args[key], np.ndarray):
277 | value = torch.Tensor(args[key])
278 | else:
279 | raise SDeconvFactoryError(self._error_message(key, 'array', None))
280 | return value
281 |
282 | def get_arg_list(self,
283 | args: dict[str, any],
284 | key: str,
285 | default_value: list
286 | ) -> list:
287 | """Get the value of a parameter from the args list
288 |
289 | The default value of the parameter is returned if the
290 | key is not in args
291 |
292 | :param args: Dictionary of the input args
293 | :param key: Name of the parameters
294 | :param default_value: Default value of the parameter
295 | :return: The arg value
296 | """
297 | value = default_value
298 | if isinstance(args, dict) and key in args:
299 | if isinstance(args[key], (list, tuple)):
300 | value = args[key]
301 | else:
302 | raise SDeconvFactoryError(self._error_message(key, 'list', None))
303 | return value
304 |
305 | def get_arg_select(self,
306 | args: dict[str, any],
307 | key: str,
308 | values: list
309 | ) -> str:
310 | """Get the value of a parameter from the args list as a select input
311 |
312 | :param args: Dictionary of the input args
313 | :param key: Name of the parameters
314 | :param values: Possible values in select input
315 | :return: The arg value
316 | """
317 | if isinstance(args, dict) and key in args:
318 | value = str(args[key])
319 | for val in values:
320 | if str(val) == value:
321 | return val
322 | raise SDeconvFactoryError(self._error_message(key, 'select', None))
323 |
--------------------------------------------------------------------------------
/sdeconv/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/cli/__init__.py
--------------------------------------------------------------------------------
/sdeconv/cli/sdeconv.py:
--------------------------------------------------------------------------------
1 | """Command line interface module for 2D to 5D image deconvolution"""
2 |
3 | import argparse
4 | from skimage.io import imread, imsave
5 | import torch
6 | import numpy as np
7 |
8 | from sdeconv.api import SDeconvAPI
9 |
10 |
def add_args_to_parser(parser: argparse.ArgumentParser, api: SDeconvAPI):
    """Expose every deconvolution-filter parameter of the API as a CLI argument

    :param parser: Argument parser instance
    :param api: SDeconv Application Programming Interface instance
    """
    for filter_name in api.filters.get_keys():
        for key, value in api.filters.get_parameters(filter_name).items():
            parser.add_argument(f"--{key}",
                                help=value['help'],
                                default=value['default'])
22 |
23 |
def main():
    """Command line interface entrypoint function"""
    # conflict_handler='resolve' lets the filter parameters added below
    # override or coexist with base options sharing the same flag name
    parser = argparse.ArgumentParser(description='2D to 5D image deconvolution',
                                     conflict_handler='resolve')

    parser.add_argument('-i', '--input', help='Input image file', default='.tif')
    parser.add_argument('-m', '--method', help='Deconvolution method', default='wiener')
    parser.add_argument('-o', '--output', help='Output image file', default='.tif')
    # NOTE(review): a value given on the CLI arrives as a string, so any
    # non-empty value (including "False") is truthy downstream — confirm
    parser.add_argument('-p', '--plane', help='Plane by plane deconvolution', default=False)

    api = SDeconvAPI()
    # Register every parameter of every available deconvolution filter
    add_args_to_parser(parser, api)
    args = parser.parse_args()

    # Forward all parsed options; the API is expected to pick the ones
    # matching the selected method (includes input/method/output/plane keys)
    args_dict = vars(args)

    # Load the image as float32 and run the deconvolution
    image = torch.Tensor(np.float32(imread(args.input)))
    out_image = api.deconvolve(image, args.method, args.plane, **args_dict)
    imsave(args.output, out_image.detach().numpy())
43 |
--------------------------------------------------------------------------------
/sdeconv/cli/spsf.py:
--------------------------------------------------------------------------------
1 | """Command line interface module for Point Spread Function generator"""
2 |
3 | import argparse
4 | from skimage.io import imsave
5 |
6 | from sdeconv.api import SDeconvAPI
7 |
8 |
def add_args_to_parser(parser: argparse.ArgumentParser, api: 'SDeconvAPI'):
    """Expose every PSF-generator parameter of the API as a CLI argument

    :param parser: Argument parser instance
    :param api: SDeconv Application Programming Interface instance
    """
    # Fixed: iterate the PSF registry keys. The previous code iterated
    # ``api.filters.get_keys()`` while querying ``api.psfs.get_parameters``,
    # mixing up the two registries (copy-paste from the deconvolution CLI)
    for psf_name in api.psfs.get_keys():
        params = api.psfs.get_parameters(psf_name)
        for key, value in params.items():
            parser.add_argument(f"--{key}", help=value['help'], default=value['default'])
19 |
20 |
def main():
    """Command line interface entrypoint function"""
    # NOTE(review): the description and the 'wiener' default look copied
    # from the deconvolution CLI; this tool generates PSF images — confirm
    parser = argparse.ArgumentParser(description='2D to 5D image deconvolution',
                                     conflict_handler='resolve')

    parser.add_argument('-m', '--method', help='Deconvolution method', default='wiener')
    parser.add_argument('-o', '--output', help='Output image file', default='.tif')

    api = SDeconvAPI()
    # Register every parameter of the available PSF generators
    add_args_to_parser(parser, api)
    args = parser.parse_args()

    # Forward all parsed options; the API picks the ones matching the method
    args_dict = vars(args)
    out_image = api.psf(args.method, **args_dict)
    imsave(args.output, out_image.detach().numpy())
36 |
--------------------------------------------------------------------------------
/sdeconv/core/__init__.py:
--------------------------------------------------------------------------------
1 | """Core module for the SDeconv library. It implements settings and observer/observable design
2 | patterns"""
3 | from ._observers import SObservable, SObserver, SObserverConsole
4 | from ._settings import SSettings, SSettingsContainer
5 | from ._timing import seconds2str
6 | from ._progress_logger import SConsoleLogger
7 |
8 |
# Public names re-exported by ``sdeconv.core``
__all__ = ['SSettings',
           'SSettingsContainer',
           'SObservable',
           'SObserver',
           'SObserverConsole',
           'seconds2str',
           'SConsoleLogger']
16 |
--------------------------------------------------------------------------------
/sdeconv/core/_observers.py:
--------------------------------------------------------------------------------
1 | """Module that implements the observer/observable design pattern to display progress"""
2 | from abc import ABC, abstractmethod
3 |
4 |
class SObserver(ABC):
    """Interface of observer to notify progress

    An observer must implement the ``notify`` (message) and ``progress``
    (value) methods
    """
    @abstractmethod
    def notify(self, message: str):
        """Notify a progress message

        :param message: Progress message
        """
        raise NotImplementedError('SObserver is abstract')

    @abstractmethod
    def progress(self, value: int):
        """Notify a progress value

        :param value: Progress value in [0, 100]
        """
        raise NotImplementedError('SObserver is abstract')
25 |
26 |
class SObservable:
    """Base class for data processing classes that report progress

    The observable keeps a list of registered observers and broadcasts
    progress messages and values to all of them
    """
    def __init__(self):
        self._observers = []

    def add_observer(self, observer: SObserver):
        """Register an observer

        :param observer: Observer instance to add
        """
        self._observers.append(observer)

    def notify(self, message: str):
        """Broadcast a progress message to every registered observer

        :param message: Progress message
        """
        for observer in self._observers:
            observer.notify(message)

    def progress(self, value):
        """Broadcast a progress value to every registered observer

        :param value: Progress value in [0, 100]
        """
        for observer in self._observers:
            observer.progress(value)
57 |
58 |
class SObserverConsole(SObserver):
    """Observer that prints messages and progress to the console"""

    def notify(self, message: str):
        """Print the progress message

        :param message: Progress message
        """
        print(message)

    def progress(self, value: int):
        """Print the progress value

        :param value: Progress value in [0, 100]
        """
        print('progress:', value, '%')
75 |
--------------------------------------------------------------------------------
/sdeconv/core/_progress_logger.py:
--------------------------------------------------------------------------------
1 | """Set of classes to log a workflow run"""
# ANSI escape sequences used to colorize console output
COLOR_WARNING = '\033[93m'  # yellow
COLOR_ERROR = '\033[91m'    # red
COLOR_GREEN = '\033[92m'    # green
COLOR_ENDC = '\033[0m'      # reset attributes
6 |
7 |
class SProgressLogger:
    """Default logger

    A logger is used by a workflow to print the warnings, errors and progress.
    A logger can be used to print in the console or in a log file.
    Concrete loggers must implement every method below.
    """
    def __init__(self):
        # Text prepended to every line printed by the logger
        self.prefix = ''

    def new_line(self):
        """Print a new line in the log"""
        raise NotImplementedError()

    def message(self, message: str):
        """Log a default message

        :param message: Message to log
        """
        raise NotImplementedError()

    def error(self, message: str):
        """Log an error message

        :param message: Message to log
        """
        raise NotImplementedError()

    def warning(self, message: str):
        """Log a warning

        :param message: Message to log
        """
        raise NotImplementedError()

    def progress(self, iteration: int, total: int, prefix: str, suffix: str):
        """Log a progress

        :param iteration: Current iteration
        :param total: Total number of iterations
        :param prefix: Text to print before the progress
        :param suffix: Text to print after the message
        """
        raise NotImplementedError()

    def close(self):
        """Close the logger (release file handles or other resources)"""
        raise NotImplementedError()
56 |
57 |
class SProgressObservable:
    """Observable pattern for progress logging

    Allows attaching several progress loggers to a single workflow; every
    call is forwarded to each registered logger
    """
    def __init__(self):
        self._loggers = []

    def set_prefix(self, prefix: str):
        """Set the prefix for all loggers

        The prefix is a string printed at the beginning of each line
        of the logger

        :param prefix: Prefix content
        """
        for log in self._loggers:
            log.prefix = prefix

    def add_logger(self, logger: SProgressLogger):
        """Attach a logger to the observable

        :param logger: Logger to add to the observer
        """
        self._loggers.append(logger)

    def new_line(self):
        """Print a new line in every logger"""
        for log in self._loggers:
            log.new_line()

    def message(self, message: str):
        """Forward a default message to every logger

        :param message: Message to log
        """
        for log in self._loggers:
            log.message(message)

    def error(self, message: str):
        """Forward an error message to every logger

        :param message: Message to log
        """
        for log in self._loggers:
            log.error(message)

    def warning(self, message: str):
        """Forward a warning message to every logger

        :param message: Message to log
        """
        for log in self._loggers:
            log.warning(message)

    def progress(self, iteration: int, total: int, prefix: str, suffix: str):
        """Forward a progress update to every logger

        :param iteration: Current iteration
        :param total: Total number of iterations
        :param prefix: Text to print before the progress
        :param suffix: Text to print after the message
        """
        for log in self._loggers:
            log.progress(iteration, total, prefix, suffix)

    def close(self):
        """Close every attached logger"""
        for log in self._loggers:
            log.close()
130 |
131 |
class SConsoleLogger(SProgressLogger):
    """Console logger displaying a progress bar

    The progress bar displays the basic information of a batch loop (loss,
    batch id, time/remaining time)
    """
    def __init__(self):
        super().__init__()
        self.decimals = 1        # digits shown for the percentage
        self.print_end = "\r"    # carriage return keeps the bar on one line
        self.length = 100        # width of the bar in characters
        self.fill = '█'          # character drawing the filled part

    def new_line(self):
        """Print the prefix followed by a blank line"""
        print(f"{self.prefix}:\n")

    def message(self, message):
        """Print a prefixed message"""
        print(f'{self.prefix}: {message}')

    def error(self, message):
        """Print a prefixed error message in red"""
        print(f'{COLOR_ERROR}{self.prefix} ERROR: '
              f'{message}{COLOR_ENDC}')

    def warning(self, message):
        """Print a prefixed warning message in yellow"""
        print(f'{COLOR_WARNING}{self.prefix} WARNING: '
              f'{message}{COLOR_ENDC}')

    def progress(self, iteration, total, prefix, suffix):
        """Draw/update the progress bar on the current console line"""
        percent = f'{100 * iteration / float(total):.{self.decimals}f}'
        filled_length = self.length * iteration // total
        bar_ = self.fill * filled_length + ' ' * (self.length - filled_length)
        print(f'\r{prefix} {percent}% |{bar_}| {suffix}',
              end=self.print_end)

    def close(self):
        """Nothing to release for a console logger"""
170 |
--------------------------------------------------------------------------------
/sdeconv/core/_settings.py:
--------------------------------------------------------------------------------
1 | """Implements setting management"""
2 | import torch
3 |
4 |
class SSettingsContainer:
    """Container for the SDeconv library settings"""
    def __init__(self):
        # Use the first CUDA GPU when available, otherwise fall back to CPU
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_device(self) -> torch.device:
        """Returns the torch device used for computations

        (annotation fixed: the method returns a ``torch.device`` object,
        not a string)

        :return: The selected torch device
        """
        return self.device

    def print(self):
        """Display the settings in the console"""
        print(f'SDeconv settings: device={self.device}')
17 |
18 |
class SSettings:
    """Singleton to access the Settings container

    :raises: Exception: if multiple instantiation of the Settings container is tried
    """
    # Shared container instance (lazily created by ``instance``)
    __instance = None

    def __init__(self):
        """ Virtually private constructor. """
        # NOTE(review): the first direct SSettings() call initializes the
        # shared container and any later call raises — prefer using
        # SSettings.instance() everywhere
        if SSettings.__instance is not None:
            raise RuntimeError("Settings container can be initialized only once!")
        SSettings.__instance = SSettingsContainer()

    @staticmethod
    def instance():
        """ Static access method to the Config. """
        # Lazily create the container on first access
        if SSettings.__instance is None:
            SSettings.__instance = SSettingsContainer()
        return SSettings.__instance

    @staticmethod
    def print():
        """Print the settings to the console"""
        SSettings.instance().print()
43 |
--------------------------------------------------------------------------------
/sdeconv/core/_timing.py:
--------------------------------------------------------------------------------
"""Module implementing a useful function for time formatting"""
2 |
3 |
def seconds2str(sec: int) -> str:
    """Convert a duration in seconds to a printable hh:mm:ss string

    Durations shorter than one hour are formatted as mm:ss. Durations of
    24 hours or more keep their full hour count (the previous implementation
    wrapped modulo 24 h, so 25 hours displayed as 01:00:00).

    :param sec: Duration in seconds (non-negative integer)
    :return: Human-readable time string
    """
    hour_value, remainder = divmod(sec, 3600)
    min_value, sec_value = divmod(remainder, 60)
    if hour_value > 0:
        return f"{hour_value:02d}:{min_value:02d}:{sec_value:02d}"
    return f"{min_value:02d}:{sec_value:02d}"
18 |
--------------------------------------------------------------------------------
/sdeconv/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Module that provides sample images"""
2 |
3 | import os.path as osp
4 | import os
5 | import numpy as np
6 | import torch
7 | from skimage.io import imread
8 | from sdeconv.core import SSettings
9 |
10 |
# Public sample-data loader functions exposed by this package
__all__ = ['celegans',
           'pollen',
           'pollen_poison_noise_blurred',
           'pollen_psf']

# Absolute path of the directory holding the bundled sample TIFF files
legacy_data_dir = osp.abspath(osp.dirname(__file__))
17 |
18 |
19 | def _fetch(data_filename: str) -> str:
20 | """Fetch a given data file from the local data dir.
21 |
22 | This function provides the path location of the data file given
23 | its name in the scikit-image repository.
24 |
25 | :param data_filename: Name of the file in the scikit-bioimaging data dir,
26 | :return: Path of the local file as a python string
27 | """
28 |
29 | filepath = os.path.join(legacy_data_dir, data_filename)
30 |
31 | if os.path.isfile(filepath):
32 | return filepath
33 | raise FileExistsError("Cannot find the file:", filepath)
34 |
35 |
def _load(filename: str) -> torch.Tensor:
    """Load an image file located in the data directory.

    The image is read from disk, converted to float32 and moved to the
    device configured in the library settings.

    :param filename: Name of the file to load (relative to the data dir)
    :return: The image data as a float32 torch tensor on the settings device
    """
    return torch.tensor(np.float32(imread(_fetch(filename)))).to(SSettings.instance().device)
43 |
44 |
def celegans():
    """2D confocal (Airyscan) image of a c. elegans intestine.

    A 3-pixel border is cropped on each side of the stored image
    (originally documented as 310x310 uint16; the loader converts to float32).

    :return: 2D float32 torch tensor
    """
    return _load("celegans.tif")[3:-3, 3:-3]
51 |
52 |
def pollen():
    """3D Pollen image.

    :return: (32, 256, 256) float32 torch tensor (source TIFF is uint16)
    """
    return _load("pollen.tif")
59 |
60 |
def pollen_poison_noise_blurred():
    """3D Pollen image corrupted with Poisson noise and blurred.

    Note: the function name keeps the historical 'poison' spelling because
    it is part of the public API; the file is 'pollen_poisson_noise_blurred.tif'.

    :return: (32, 256, 256) float32 torch tensor (source TIFF is uint16)
    """
    return _load("pollen_poisson_noise_blurred.tif")
67 |
68 |
def pollen_psf():
    """3D PSF to deblur the pollen image.

    The PSF is normalized so that its values sum to 1.

    :return: (32, 256, 256) float32 torch tensor with unit sum
    """
    psf = _load("pollen_psf.tif")
    return psf / torch.sum(psf)
76 |
--------------------------------------------------------------------------------
/sdeconv/data/celegans.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/celegans.tif
--------------------------------------------------------------------------------
/sdeconv/data/pollen.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen.tif
--------------------------------------------------------------------------------
/sdeconv/data/pollen_poisson_noise_blurred.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen_poisson_noise_blurred.tif
--------------------------------------------------------------------------------
/sdeconv/data/pollen_psf.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen_psf.tif
--------------------------------------------------------------------------------
/sdeconv/deconv/__init__.py:
--------------------------------------------------------------------------------
1 | """Module that implements the image deconvolution algorithms"""
2 | from .interface import SDeconvFilter
3 | from .wiener import SWiener, swiener
4 | from .richardson_lucy import SRichardsonLucy, srichardsonlucy
5 | from .spitfire import Spitfire, spitfire
6 | from .noise2void import Noise2VoidDeconv
7 | from .self_supervised_nn import SelfSupervisedNNDeconv
8 | from .nn_deconv import NNDeconv
9 |
# Public names re-exported by ``sdeconv.deconv``
__all__ = ['SDeconvFilter',
           'SWiener',
           'swiener',
           'SRichardsonLucy',
           'srichardsonlucy',
           'Spitfire',
           'spitfire',
           'Noise2VoidDeconv',
           'SelfSupervisedNNDeconv',
           'NNDeconv'
           ]
21 |
--------------------------------------------------------------------------------
/sdeconv/deconv/_datasets.py:
--------------------------------------------------------------------------------
1 | """This module implements the datasets for deep learning training"""
2 | from typing import Callable
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | from skimage.io import imread
8 |
9 | from torch.utils.data import Dataset
10 |
11 |
class SelfSupervisedPatchDataset(Dataset):
    """Gray scaled image patch dataset for self-supervised learning

    Patches of size ``patch_size`` are extracted on a regular grid with the
    given ``stride`` from every image of the directory.

    :param images_dir: Directory containing the training images
    :param patch_size: Size of the squared training patches
    :param stride: Stride used to extract overlapping patches from images
    :param transform: Transformation applied to each patch before the model
    """
    def __init__(self,
                 images_dir: Path,
                 patch_size: int = 40,
                 stride: int = 10,
                 transform: Callable = None):
        super().__init__()
        self.images_dir = Path(images_dir)
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transform

        source_images = sorted(self.images_dir.glob('*.*'))
        self.nb_images = len(source_images)

        # The patch grid is computed from the first image; all images are
        # assumed to share the same shape — TODO confirm with callers
        first_image = imread(source_images[0])
        patches_h = (first_image.shape[0] - patch_size) // stride
        patches_w = (first_image.shape[1] - patch_size) // stride
        self.n_patches = self.nb_images * patches_h * patches_w
        print('num patches = ', self.n_patches)

        # Preload every image in memory as float32
        self.images_data = [np.float32(imread(source)) for source in source_images]

    def __len__(self):
        return self.n_patches

    def __getitem__(self, idx):
        patches_per_image = self.n_patches // self.nb_images
        image = self.images_data[idx // patches_per_image]

        patches_w = (image.shape[1] - self.patch_size) // self.stride
        local_idx = idx % patches_per_image
        row, col = divmod(local_idx, patches_w)
        top = row * self.stride
        left = col * self.stride
        patch = image[top:top + self.patch_size, left:left + self.patch_size]

        if self.transform:
            patch = self.transform(torch.Tensor(patch))
        else:
            patch = torch.Tensor(patch).float()

        # Return a (1, H, W) tensor and the image-local patch index as str
        return (
            patch.view(1, *patch.shape),
            str(local_idx)
        )
70 |
71 |
class SelfSupervisedDataset(Dataset):
    """Gray scaled image dataset for self-supervised learning

    Each item is one full image (no patch extraction).

    :param images_dir: Directory containing the training images
    :param transform: Transformation applied to each image before the model
    """
    def __init__(self,
                 images_dir: Path,
                 transform: Callable = None):
        super().__init__()
        self.images_dir = Path(images_dir)
        self.transform = transform

        self.source_images = sorted(self.images_dir.glob('*.*'))
        self.nb_images = len(self.source_images)

        # Preload every image in memory as float32
        self.images_data = [np.float32(imread(source)) for source in self.source_images]

    def __len__(self):
        return self.nb_images

    def __getitem__(self, idx):
        image = self.images_data[idx]
        if self.transform:
            image = self.transform(torch.Tensor(image))
        else:
            image = torch.Tensor(image).float()

        # Return a (1, H, W) tensor and the source file stem as identifier
        return (
            image.view(1, *image.shape),
            self.source_images[idx].stem
        )
109 |
110 |
111 | class RestorationDataset(Dataset):
112 | """Dataset for image restoration from full images
113 |
114 | All the training images must be saved as individual images in source and
115 | target folders.
116 |
117 | :param source_dir: Path of the noisy training images (or patches)
118 | :param target_dir: Path of the ground truth images (or patches)
119 | :param transform: Transformation to apply to the image before model call
120 | """
121 | def __init__(self,
122 | source_dir: str | Path,
123 | target_dir: str | Path,
124 | transform: Callable = None):
125 | super().__init__()
126 | self.device = None
127 | self.source_dir = Path(source_dir)
128 | self.target_dir = Path(target_dir)
129 | self.transform = transform
130 |
131 | self.source_images = sorted(self.source_dir.glob('*.*'))
132 | self.target_images = sorted(self.target_dir.glob('*.*'))
133 | if len(self.source_images) != len(self.target_images):
134 | raise ValueError("Source and target dirs are not the same length")
135 |
136 | self.nb_images = len(self.source_images)
137 |
138 | def __len__(self):
139 | return self.nb_images
140 |
141 | def __getitem__(self, idx):
142 | source_patch = np.float32(imread(self.source_images[idx]))
143 | target_patch = np.float32(imread(self.target_images[idx]))
144 |
145 | # numpy to tensor
146 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float()
147 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float()
148 |
149 | # data augmentation
150 | if self.transform:
151 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0)
152 | transformed_images = self.transform(both_images)
153 | source_patch = transformed_images[0, ...]
154 | target_patch = transformed_images[1, ...]
155 |
156 | return source_patch, target_patch, self.source_images[idx].stem
157 |
158 |
159 | class RestorationPatchDataset(Dataset):
160 | """Dataset for image restoration using patches
161 |
162 | All the training images must be saved as individual images in source and
163 | target folders.
164 |
165 | :param source_dir: Path of the noisy training images (or patches)
166 | :param target_dir: Path of the ground truth images (or patches)
167 | :param patch_size: Size of the patches (width=height)
168 | :param stride: Length of the patch overlapping
169 | :param transform: Transformation to apply to the image before model call
170 |
171 | """
172 | def __init__(self,
173 | source_dir: str | Path,
174 | target_dir: str | Path,
175 | patch_size: int = 40,
176 | stride: int = 10,
177 | transform: Callable = None):
178 | super().__init__()
179 | self.device = None
180 | self.source_dir = Path(source_dir)
181 | self.target_dir = Path(target_dir)
182 | self.patch_size = patch_size
183 | self.stride = stride
184 | self.transform = transform
185 |
186 | self.source_images = sorted(self.source_dir.glob('*.*'))
187 | self.target_images = sorted(self.target_dir.glob('*.*'))
188 | if len(self.source_images) != len(self.target_images):
189 | raise ValueError("Source and target dirs are not the same length")
190 |
191 | self.nb_images = len(self.source_images)
192 | image = imread(self.source_images[0])
193 | self.n_patches = self.nb_images * ((image.shape[0] - patch_size) // stride) * \
194 | ((image.shape[1] - patch_size) // stride)
195 |
196 | def __len__(self):
197 | return self.n_patches
198 |
199 | def __getitem__(self, idx):
200 | # Crop a patch from original image
201 | nb_patch_per_img = self.n_patches // self.nb_images
202 |
203 | elt = self.source_images[idx // nb_patch_per_img]
204 |
205 | img_source_np = \
206 | np.float32(imread(self.source_dir / elt))
207 | img_target_np = \
208 | np.float32(imread(self.target_dir / elt))
209 |
210 | nb_patch_w = (img_source_np.shape[1] - self.patch_size) // self.stride
211 | idx = idx % nb_patch_per_img
212 | i, j = idx // nb_patch_w, idx % nb_patch_w
213 | source_patch = \
214 | img_source_np[i * self.stride:i * self.stride + self.patch_size,
215 | j * self.stride:j * self.stride + self.patch_size]
216 | target_patch = \
217 | img_target_np[i * self.stride:i * self.stride + self.patch_size,
218 | j * self.stride:j * self.stride + self.patch_size]
219 |
220 | # numpy to tensor
221 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float()
222 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float()
223 |
224 | # data augmentation
225 | if self.transform:
226 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0)
227 | transformed_images = self.transform(both_images)
228 | source_patch = transformed_images[0, ...]
229 | target_patch = transformed_images[1, ...]
230 |
231 | # to tensor
232 | return (source_patch,
233 | target_patch,
234 | str(idx)
235 | )
236 |
237 |
238 | class RestorationPatchDatasetLoad(Dataset):
239 | """Dataset for image restoration using patches preloaded in memory
240 |
241 | All the training images must be saved as individual images in source and
242 | target folders.
243 | This version load all the dataset in the CPU
244 |
245 | :param source_dir: Path of the noisy training images (or patches)
246 | :param target_dir: Path of the ground truth images (or patches)
247 | :param patch_size: Size of the patches (width=height)
248 | :param stride: Length of the patch overlapping
249 | :param transform: Transformation to apply to the image before model call
250 |
251 | """
252 | # pylint: disable=too-many-instance-attributes
253 | # pylint: disable=too-many-arguments
254 | def __init__(self,
255 | source_dir: str | Path,
256 | target_dir: str | Path,
257 | patch_size: int = 40,
258 | stride: int = 10,
259 | transform: Callable = None):
260 | super().__init__()
261 | self.source_dir = Path(source_dir)
262 | self.target_dir = Path(target_dir)
263 | self.patch_size = patch_size
264 | self.stride = stride
265 | self.transform = transform
266 |
267 | self.source_images = sorted(self.source_dir.glob('*.*'))
268 | self.target_images = sorted(self.target_dir.glob('*.*'))
269 |
270 | if len(self.source_images) != len(self.target_images):
271 | raise ValueError("Source and target dirs are not the same length")
272 |
273 | self.nb_images = len(self.source_images)
274 | image = imread(self.source_images[0])
275 | self.n_patches = self.nb_images * ((image.shape[0] - patch_size) // stride) * \
276 | ((image.shape[1] - patch_size) // stride)
277 | print('num patches = ', self.n_patches)
278 |
279 | # Load all the images in a list
280 | self.source_data = []
281 | for source in self.source_images:
282 | self.source_data.append(np.float32(imread(source)))
283 | self.target_data = []
284 | for target in self.target_images:
285 | self.target_data.append(np.float32(imread(target)))
286 |
287 | def __len__(self):
288 | return self.n_patches
289 |
290 | def __getitem__(self, idx):
291 | # Crop a patch from original image
292 | nb_patch_per_img = self.n_patches // self.nb_images
293 |
294 | img_number = idx // nb_patch_per_img
295 |
296 | img_source_np = self.source_data[img_number]
297 | img_target_np = self.target_data[img_number]
298 |
299 | nb_patch_w = (img_source_np.shape[1] - self.patch_size) // self.stride
300 | idx = idx % nb_patch_per_img
301 | i, j = idx // nb_patch_w, idx % nb_patch_w
302 | source_patch = \
303 | img_source_np[i * self.stride:i * self.stride + self.patch_size,
304 | j * self.stride:j * self.stride + self.patch_size]
305 | target_patch = \
306 | img_target_np[i * self.stride:i * self.stride + self.patch_size,
307 | j * self.stride:j * self.stride + self.patch_size]
308 |
309 | # numpy to tensor
310 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float()
311 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float()
312 |
313 | # data augmentation
314 | if self.transform:
315 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0)
316 | transformed_images = self.transform(both_images)
317 | source_patch = transformed_images[0, ...]
318 | target_patch = transformed_images[1, ...]
319 |
320 | return (source_patch,
321 | target_patch,
322 | str(idx)
323 | )
324 |
--------------------------------------------------------------------------------
/sdeconv/deconv/_transforms.py:
--------------------------------------------------------------------------------
1 | """Data transformation for image restoration workflow"""
2 | import torch
3 | from torchvision.transforms import v2
4 |
5 |
class FlipAugmentation:
    """Data augmentation by flipping images

    Applies a random horizontal flip and a random vertical flip (each with
    probability 0.5) and casts the result to float32 without value scaling.
    """
    def __init__(self):
        # The pipeline is built once and reused for every call
        self.__pipeline = v2.Compose([
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=False)
        ])

    def __call__(self, image):
        return self.__pipeline(image)
17 |
18 |
class VisionScale:
    """Cast images to float32 with value scaling

    Uses ``v2.ToDtype(torch.float32, scale=True)``.
    NOTE(review): torchvision documents ``scale=True`` as rescaling integer
    inputs to [0, 1]; the original docstring claimed "[-1, 1]" -- confirm
    against the torchvision v2 documentation.
    """
    def __init__(self):
        self.__pipeline = v2.Compose([
            v2.ToDtype(torch.float32, scale=True)
        ])

    def __call__(self, image):
        return self.__pipeline(image)
29 |
--------------------------------------------------------------------------------
/sdeconv/deconv/_unet_2d.py:
--------------------------------------------------------------------------------
1 | """Module to implement a UNet in 2D with pytorch module"""
2 | import torch
3 | from torch import nn
4 |
5 |
class UNetConvBlock(nn.Module):
    """Convolution block for UNet architecture

    Two 3x3 convolutions, each followed by an optional batch normalization
    and a ReLU activation.

    :param n_channels_in: Number of input channels (or features)
    :param n_channels_out: Number of output channels (or features)
    :param use_batch_norm: True to use the batch norm layers
    """
    def __init__(self,
                 n_channels_in: int,
                 n_channels_out: int,
                 use_batch_norm: bool = True):
        super().__init__()

        self.use_batch_norm = use_batch_norm
        self.conv1 = nn.Conv2d(n_channels_in, n_channels_out,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channels_out)

        self.conv2 = nn.Conv2d(n_channels_out, n_channels_out,
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(n_channels_out)

        self.relu = nn.ReLU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the block

        :param inputs: Data to process
        :return: The processed data
        """
        out = inputs
        # Run both conv -> (batch norm) -> ReLU stages in sequence
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = conv(out)
            if self.use_batch_norm:
                out = norm(out)
            out = self.relu(out)
        return out
50 |
51 |
class UNetEncoderBlock(nn.Module):
    """Encoder block of the UNet architecture

    A convolution block followed by a 2x2 max pooling; returns both the
    full-resolution features (used as skip connection) and the pooled ones.

    :param n_channels_in: Number of input channels (or features)
    :param n_channels_out: Number of output channels (or features)
    :param use_batch_norm: True to use the batch norm layers
    """
    def __init__(self,
                 n_channels_in: int,
                 n_channels_out: int,
                 use_batch_norm: bool = True):
        super().__init__()
        self.conv = UNetConvBlock(n_channels_in, n_channels_out,
                                  use_batch_norm)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs: torch.Tensor):
        """torch module forward method

        :param inputs: tensor to process
        :return: (features, pooled features) pair
        """
        features = self.conv(inputs)
        return features, self.pool(features)
80 |
81 |
class UNetDecoderBlock(nn.Module):
    """Decoder block of a UNet architecture

    Up-samples the input, concatenates the skip features on the channel axis
    and applies a convolution block.

    :param n_channels_in: Number of input channels (or features)
    :param n_channels_out: Number of output channels (or features)
    :param use_batch_norm: True to use the batch norm layers
    """
    def __init__(self,
                 n_channels_in: int,
                 n_channels_out: int,
                 use_batch_norm: bool = True):
        super().__init__()
        self.up = nn.Upsample(scale_factor=(2, 2), mode='nearest')
        self.conv = UNetConvBlock(n_channels_in + n_channels_out,
                                  n_channels_out, use_batch_norm)

    def forward(self, inputs: torch.Tensor, skip: torch.Tensor):
        """Module torch forward

        :param inputs: input tensor
        :param skip: skip connection tensor
        :return: decoded features
        """
        upsampled = self.up(inputs)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv(merged)
112 |
113 |
class UNet2D(nn.Module):
    """Implementation of a UNet network

    :param n_channels_in: Number of input channels (or features)
    :param n_channels_out: Number of output channels (or features)
    :param n_channels_layers: Number of channels (or features) per resolution
        level; the last entry is used by the bottleneck block
        (doc fix: the previous docstring documented a non-existent
        ``n_feature_first`` parameter instead of this one)
    :param use_batch_norm: True to use the batch norm layers
    """
    def __init__(self,
                 n_channels_in: int = 1,
                 n_channels_out: int = 1,
                 n_channels_layers: list[int] = (32, 64, 128),
                 use_batch_norm: bool = False):
        super().__init__()

        # Encoder: one down-sampling block per level except the last
        self.encoder = nn.ModuleList()
        for idx, n_channels in enumerate(n_channels_layers[:-1]):
            n_in = n_channels_in if idx == 0 else n_channels_layers[idx-1]
            n_out = n_channels
            self.encoder.append(UNetEncoderBlock(n_in, n_out, use_batch_norm))

        # Bottleneck between encoder and decoder
        self.bottleneck = UNetConvBlock(n_channels_layers[-2],
                                        n_channels_layers[-1],
                                        use_batch_norm)

        # Decoder: mirror of the encoder, built from deepest to shallowest
        self.decoder = nn.ModuleList()
        for idx in reversed(range(len(n_channels_layers))):
            if idx > 0:
                n_in = n_channels_layers[idx]
                n_out = n_channels_layers[idx-1]
                self.decoder.append(UNetDecoderBlock(n_in, n_out, use_batch_norm))

        # 1x1 convolution mapping first-level features to the output channels
        self.outputs = nn.Conv2d(n_channels_layers[0], n_channels_out,
                                 kernel_size=1, padding=0)

        self.num_layers = len(n_channels_layers)-1

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Module torch forward

        :param inputs: input tensor of shape (batch, n_channels_in, H, W);
            H and W should presumably be divisible by 2 per encoder level so
            the skip concatenation aligns -- TODO confirm
        :return: the tensor processed by the network
        """
        # Encoder: keep skip features and pooled outputs per level
        skips = []
        p = [None] * (len(self.encoder)+1)
        p[0] = inputs
        for idx, layer in enumerate(self.encoder):
            s, p[idx+1] = layer(p[idx])
            skips.append(s)

        # Bottleneck on the most down-sampled features
        d = [None] * (len(self.decoder)+1)
        d[0] = self.bottleneck(p[-1])

        # Decoder: consume the skip connections in reverse order
        for idx, layer in enumerate(self.decoder):
            d[idx+1] = layer(d[idx], skips[self.num_layers-idx-1])

        # Classifier
        return self.outputs(d[-1])
179 |
--------------------------------------------------------------------------------
/sdeconv/deconv/_utils.py:
--------------------------------------------------------------------------------
1 | """Implementation of shared methods for all multiple deconvolution algorithms"""
2 | import numpy as np
3 | import torch
4 | from sdeconv.core import SSettings
5 |
6 |
def np_to_torch(image: np.ndarray | torch.Tensor) -> torch.Tensor:
    """Convert a numpy array into a torch tensor

    The new tensor is moved to the device selected in the global settings;
    tensors are returned unchanged.

    :param image: Image to convert,
    :return: The converted image
    """
    if not isinstance(image, np.ndarray):
        return image
    return torch.tensor(image).to(SSettings.instance().device)
20 |
21 |
def resize_psf_2d(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor:
    """Resize a 2D PSF image to the target image size

    :param image: Reference image tensor,
    :param psf: Point Spread Function tensor to resize,
    :return: the psf tensor padded to get the same shape as image
    """
    padded = torch.zeros(image.shape).to(SSettings.instance().device)
    # Centre the PSF inside the zero canvas (the +1 offset reproduces the
    # original centring convention)
    offset_x = int(image.shape[0] / 2 - psf.shape[0] / 2) + 1
    offset_y = int(image.shape[1] / 2 - psf.shape[1] / 2) + 1
    padded[offset_x:offset_x + psf.shape[0],
           offset_y:offset_y + psf.shape[1]] = psf
    return padded
34 |
35 |
def resize_psf_3d(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor:
    """Resize a 3D PSF image to the target image size

    :param image: Reference image tensor
    :param psf: Point Spread Function tensor to resize
    :returns: the psf tensor padded to get the same shape as image
    """
    padded = torch.zeros(image.shape).to(SSettings.instance().device)
    # Centre the PSF inside the zero canvas (the +1 offset reproduces the
    # original centring convention)
    offsets = tuple(int(image.shape[k] / 2 - psf.shape[k] / 2) + 1
                    for k in range(3))
    padded[offsets[0]:offsets[0] + psf.shape[0],
           offsets[1]:offsets[1] + psf.shape[1],
           offsets[2]:offsets[2] + psf.shape[2]] = psf
    return padded
50 |
51 |
def pad_2d(image: torch.Tensor,
           psf: torch.Tensor,
           pad: int | tuple[int, int]
           ) -> tuple[torch.Tensor, torch.Tensor, int | tuple[int, int]]:
    """Pad an image and its PSF for deconvolution

    :param image: 2D image tensor
    :param psf: 2D Point Spread Function.
    :param pad: Padding in each dimension (an int is applied to both axes).
    :return: padded versions of the image and the PSF, plus the padding tuple
    :raises ValueError: If a tuple padding does not match the image dimension
    """
    padding = pad
    if isinstance(pad, tuple) and len(pad) != image.ndim:
        raise ValueError("Padding must be the same dimension as image")
    if isinstance(pad, int):
        if pad == 0:
            # no padding requested: return the inputs unchanged
            return image, psf, (0, 0)
        padding = (pad, pad)

    if padding[0] > 0 and padding[1] > 0:
        # Bug fix: ReflectionPad2d takes (left, right, top, bottom) where
        # left/right pad the last axis. The original mapped padding[0] to the
        # width and reshaped with padding[0] on both axes, which broke the
        # view for anisotropic padding. padding[0] now pads axis 0 (height)
        # and padding[1] axis 1 (width), consistent with pad_3d.
        pad_fn = torch.nn.ReflectionPad2d((padding[1], padding[1], padding[0], padding[0]))
        image_pad = pad_fn(image.detach().clone().to(
            SSettings.instance().device).view(1, 1, image.shape[0], image.shape[1])).view(
            (image.shape[0] + 2 * padding[0], image.shape[1] + 2 * padding[1]))
    else:
        image_pad = image.detach().clone().to(SSettings.instance().device)
    psf_pad = resize_psf_2d(image_pad, psf)
    return image_pad, psf_pad, padding
81 |
82 |
def pad_3d(image: torch.Tensor,
           psf: torch.Tensor,
           pad: int | tuple[int, int, int]
           ) -> tuple[torch.Tensor, torch.Tensor, int | tuple[int, int, int]]:
    """Pad an image and its PSF for deconvolution

    :param image: 3D image tensor
    :param psf: 3D Point Spread Function
    :param pad: Padding in each dimension (an int is applied to all three axes)
    :return: image, psf, padding: padded versions of the image and the PSF,
        plus the padding tuple
    :raises ValueError: If a tuple padding does not match the image dimension
    """
    padding = pad
    if isinstance(pad, tuple) and len(pad) != image.ndim:
        raise ValueError("Padding must be the same dimension as image")
    if isinstance(pad, int):
        if pad == 0:
            # no padding requested: return the inputs unchanged
            return image, psf, (0, 0, 0)
        padding = (pad, pad, pad)

    if padding[0] > 0 and padding[1] > 0 and padding[2] > 0:
        # torch padding tuples are ordered from the last axis to the first
        p3d = (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
        pad_fn = torch.nn.ReflectionPad3d(p3d)
        # reflection-pad the image (reshaped to (1, 1, Z, Y, X) as required by
        # ReflectionPad3d) then view it back to a plain 3D tensor
        image_pad = pad_fn(
            image.detach().clone().to(SSettings.instance().device).view(1, 1, image.shape[0],
                                                                        image.shape[1],
                                                                        image.shape[2])).view(
            (image.shape[0] + 2 * padding[0], image.shape[1] + 2 * padding[1],
             image.shape[2] + 2 * padding[2]))
        # NOTE(review): unlike pad_2d (which calls resize_psf_2d), the PSF is
        # zero-padded here -- confirm this asymmetry is intended
        psf_pad = torch.nn.functional.pad(psf, p3d, "constant", 0)
    else:
        # any zero padding component disables padding entirely
        image_pad = image
        psf_pad = psf
    return image_pad, psf_pad, padding
116 |
117 |
def unpad_3d(image: torch.Tensor, padding: tuple[int, int, int]) -> torch.Tensor:
    """Remove the padding of an image

    :param image: 3D image to un-pad
    :param padding: Padding in each dimension
    :return: a torch.Tensor of the un-padded image
    """
    # Bug fix: a naive image[p:-p] slice with p == 0 evaluates to [0:-0] and
    # returns an empty tensor; axes with zero padding must be kept whole.
    slices = tuple(slice(p, -p) if p > 0 else slice(None) for p in padding)
    return image[slices]
128 |
129 |
# Shared descriptor for the PSF argument of the deconvolution filters:
# declares its type, display label, help text and default value.
psf_parameter = {
    'type': 'torch.Tensor',
    'label': 'psf',
    'help': 'Point Spread Function',
    'default': None
}
137 |
--------------------------------------------------------------------------------
/sdeconv/deconv/interface.py:
--------------------------------------------------------------------------------
1 | """Interface for a deconvolution filter"""
2 | import torch
3 | from sdeconv.core import SObservable
4 |
5 |
class SDeconvFilter(SObservable):
    """Interface for a deconvolution filter

    Concrete filters configure all their settings (PSF included) in
    ``__init__`` and implement ``__call__`` to run the actual computation.
    """
    def __init__(self):
        super().__init__()
        self.type = 'SDeconvFilter'

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Do the deconvolution

        :param image: Blurry image for a single channel time point [(Z) Y X]
        :return: deblurred image [(Z) Y X]
        """
        message = ('SDeconvFilter is an interface. Please implement the'
                   ' __call__ method')
        raise NotImplementedError(message)
24 |
--------------------------------------------------------------------------------
/sdeconv/deconv/interface_nn.py:
--------------------------------------------------------------------------------
1 | """This module implements interface for deconvolution based on neural network"""
2 | from abc import abstractmethod
3 | from pathlib import Path
4 |
5 | import torch
6 | from skimage.io import imsave
7 |
8 | from ..core import seconds2str
9 | from ..core import SConsoleLogger
10 |
11 | from ._unet_2d import UNet2D
12 | from ._transforms import VisionScale
13 |
14 |
class NNModule(torch.nn.Module):
    """Base torch module for neural network based deconvolution

    Holds the model, loss, optimizer and data loaders, and implements the
    shared training loop, checkpoint save/load and the inference entry point.
    Sub-classes implement :meth:`fit` and :meth:`_train_step`.
    (doc fix: the previous docstring described the noise2void subclass.)
    """
    def __init__(self):
        super().__init__()

        self._model_args = None        # kwargs used to rebuild the model in load()
        self._model = None             # UNet2D instance once _init_model was called
        self._loss_fn = None
        self._optimizer = None
        self._save_all = True          # save a checkpoint after each epoch
        self._device = None
        self._out_dir = None
        self._val_data_loader = None
        self._train_data_loader = None
        self._progress = SConsoleLogger()
        self._current_epoch = None
        self._current_loss = None

    @abstractmethod
    def fit(self,
            train_directory: Path,
            val_directory: Path,
            n_channel_in: int = 1,
            n_channels_layer: list[int] = (32, 64, 128),
            patch_size: int = 32,
            n_epoch: int = 25,
            learning_rate: float = 1e-3,
            out_dir: Path = None
            ):
        """Train a model on a dataset

        :param train_directory: Directory containing the images used for
                                training. One file per image,
        :param val_directory: Directory containing the images used for validation of
                              the training. One file per image,
        :param n_channel_in: Number of channels in the input images
        :param n_channels_layer: Number of channels for each hidden layers of the model,
        :param patch_size: Size of square patches used for training the model,
        :param n_epoch: Number of epochs,
        :param learning_rate: Adam optimizer learning rate
        :param out_dir: Directory where checkpoints and predictions are written
        """
        raise NotImplementedError('NNModule is abstract')

    def _train_loop(self, n_epoch: int):
        """Run the main train loop (should be called in fit)

        :param n_epoch: Number of epochs to run
        """
        for epoch in range(n_epoch):
            self._current_epoch = epoch
            train_data = self._train_step()
            self._after_train_step(train_data)
            # NOTE(review): _after_train runs inside the epoch loop, so the
            # validation predictions are (re)written every epoch -- confirm
            # whether it was meant to run once after the last epoch
            self._after_train()

    @abstractmethod
    def _train_step(self):
        """Runs one step of training"""

    def _after_train_batch(self, data: dict[str, any]):
        """Instructions runs after one batch

        Prints a progress line with the batch index, timings and current loss.

        :param data: Dictionary of metadata to log or process; expects the
                     keys 'loss', 'id_batch', 'total_batch', 'full_time'
                     and 'remain_time'
        """
        prefix = f"Epoch = {self._current_epoch+1:d}"
        loss_str = f"{data['loss']:.7f}"
        full_time_str = seconds2str(int(data['full_time']))
        remains_str = seconds2str(int(data['remain_time']))
        suffix = str(data['id_batch']) + '/' + str(data['total_batch']) + \
                 ' [' + full_time_str + '<' + remains_str + ', loss=' + \
                 loss_str + '] '
        self._progress.progress(data['id_batch'],
                                data['total_batch'],
                                prefix=prefix,
                                suffix=suffix)

    def _after_train_step(self, data: dict):
        """Instructions runs after one train step.

        This method can be used to log data or print console messages

        :param data: Dictionary of metadata to log or process
        """
        if self._save_all:
            self.save(Path(self._out_dir, f'model_{self._current_epoch}.ml'))

    def _after_train(self):
        """Instructions runs after the train.

        Runs the model on the whole validation set and saves one tif file
        per prediction in ``<out_dir>/predictions``.
        """
        # create the output dir
        predictions_dir = self._out_dir / 'predictions'
        predictions_dir.mkdir(parents=True, exist_ok=True)

        # predict on all the test set
        self._model.eval()
        for x, names in self._val_data_loader:
            x = x.to(self.device())

            with torch.no_grad():
                prediction = self._model(x)
            for i, name in enumerate(names):
                imsave(predictions_dir / f'{name}.tif',
                       prediction[i, ...].cpu().numpy())

    def _init_model(self,
                    n_channel_in: int = 1,
                    n_channel_out: int = 1,
                    n_channels_layer: list[int] = (32, 64, 128)
                    ):
        """Initialize the model

        :param n_channel_in: Number of channels for the input image
        :param n_channel_out: Number of channels for the output image
        :param n_channels_layer: Number of channels for each layers of the UNet
        """
        # keep the args so load() can rebuild an identical architecture
        self._model_args = {
            "n_channel_in": n_channel_in,
            "n_channel_out": n_channel_out,
            "n_channels_layer": n_channels_layer
        }
        self._model = UNet2D(n_channel_in, n_channel_out, n_channels_layer, True)
        self._model.to(self.device())

    def device(self) -> str:
        """Get the GPU if exists

        :return: the configured device; when none is set, a ``torch.device``
                 pointing to cuda:0 if available, cpu otherwise
        """
        if self._device is None:
            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return self._device

    def load(self, filename: Path):
        """Load pre-trained model from file

        :param filename: Path of the model file
        """
        params = torch.load(filename, map_location=torch.device(self.device()))
        self._init_model(**params["model_args"])
        self._model.load_state_dict(params['model_state_dict'])
        self._model.to(self.device())

    def save(self, filename: Path):
        """Save the model into file

        :param filename: Path of the model file
        """
        torch.save({
            'model_args': self._model_args,
            'model_state_dict': self._model.state_dict(),
        }, filename)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Apply the model on a image or batch of image

        :param image: Blurry image for a single channel or batch [(B) Y X]
        :return: deblurred image [(B) Y X]
        """
        if image.ndim == 2:
            # single image: add batch and channel axes
            image = image.view(1, 1, *image.shape)
        elif image.ndim == 3:
            # Bug fix: a [B, Y, X] batch previously raised although the
            # docstring promises batch support; add the channel axis instead
            image = image.view(image.shape[0], 1, image.shape[1], image.shape[2])
        else:
            raise ValueError("The current implementation of neural network "
                             "deconvolution works only with 2D images")

        self._model.eval()
        scaler = VisionScale()
        with torch.no_grad():
            return self._model(scaler(image))
182 |
--------------------------------------------------------------------------------
/sdeconv/deconv/nn_deconv.py:
--------------------------------------------------------------------------------
1 | """This module implements deconvolution with a neural network trained on ground truth pairs"""
2 | from pathlib import Path
3 | from timeit import default_timer as timer
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from skimage.io import imsave
8 |
9 | from .interface_nn import NNModule
10 | from ._datasets import RestorationDataset
11 | from ._datasets import RestorationPatchDatasetLoad
12 | from ._datasets import RestorationPatchDataset
13 | from ._transforms import FlipAugmentation, VisionScale
14 |
15 |
class NNDeconv(NNModule):
    """Deconvolution using a neural network trained using ground truth"""
    def fit(self,
            train_directory: Path,
            val_directory: Path,
            n_channel_in: int = 1,
            n_channels_layer: list[int] = (32, 64, 128),
            patch_size: int = 32,
            n_epoch: int = 25,
            learning_rate: float = 1e-3,
            out_dir: Path = None,
            preload: bool = True
            ):
        """Train a model on a dataset

        :param train_directory: Directory containing the images used for
                                training ("source" and "target" sub-folders,
                                one file per image),
        :param val_directory: Directory containing the images used for validation of the
                              training ("source" and "target" sub-folders),
        :param n_channel_in: Number of channels in the input images
        :param n_channels_layer: Number of channels for each hidden layers of the model,
        :param patch_size: Size of square patches used for training the model,
        :param n_epoch: Number of epochs,
        :param learning_rate: Adam optimizer learning rate
        :param out_dir: Directory where checkpoints and predictions are written
        :param preload: True to load the whole training set in memory
        """
        self._init_model(n_channel_in, n_channel_in, n_channels_layer)
        self._out_dir = out_dir
        self._loss_fn = torch.nn.MSELoss()
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate)
        if preload:
            train_dataset = RestorationPatchDatasetLoad(train_directory / "source",
                                                        train_directory / "target",
                                                        patch_size=patch_size,
                                                        stride=int(patch_size/2),
                                                        transform=FlipAugmentation())
        else:
            train_dataset = RestorationPatchDataset(train_directory / "source",
                                                    train_directory / "target",
                                                    patch_size=patch_size,
                                                    stride=int(patch_size/2),
                                                    transform=FlipAugmentation())
        val_dataset = RestorationDataset(val_directory / "source",
                                         val_directory / "target",
                                         transform=VisionScale())
        self._train_data_loader = DataLoader(train_dataset,
                                             batch_size=300,
                                             shuffle=True,
                                             drop_last=True,
                                             num_workers=0)
        self._val_data_loader = DataLoader(val_dataset,
                                           batch_size=1,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=0)
        self._train_loop(n_epoch)

    def _train_step(self):
        """Runs one step (epoch) of training

        :return: dictionary holding the mean training loss of the epoch
        """
        size = len(self._train_data_loader.dataset)
        self._model.train()
        step_loss = 0
        count_step = 0
        tic = timer()
        for batch, (x, y, _) in enumerate(self._train_data_loader):
            count_step += 1

            # Bug fix: the target must be on the same device as the
            # prediction; previously y stayed on CPU and the loss failed
            # (or silently ran on CPU) for GPU training
            x = x.to(self.device())
            y = y.to(self.device())

            # Compute prediction error
            prediction = self._model(x)
            loss = self._loss_fn(prediction, y)
            # accumulate a detached scalar so each batch graph can be freed
            # (accumulating the tensor kept every graph alive for the epoch)
            step_loss += loss.item()

            # Backpropagation
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            # count time
            toc = timer()
            full_time = toc - tic
            total_batch = int(size / len(x))
            remains = full_time * (total_batch - (batch+1)) / (batch+1)

            self._after_train_batch({'loss': loss,
                                     'id_batch': batch+1,
                                     'total_batch': total_batch,
                                     'remain_time': int(remains+0.5),
                                     'full_time': int(full_time+0.5)
                                     })
        if count_step > 0:
            step_loss /= count_step
        self._current_loss = step_loss
        return {'train_loss': step_loss}

    def _after_train(self):
        """Instructions runs after the train.

        Runs the model on the validation set and saves one tif file per
        prediction in ``<out_dir>/predictions``.
        """
        # create the output dir
        predictions_dir = self._out_dir / 'predictions'
        predictions_dir.mkdir(parents=True, exist_ok=True)

        # predict on all the test set
        self._model.eval()
        for x, _, names in self._val_data_loader:
            x = x.to(self.device())

            with torch.no_grad():
                prediction = self._model(x)
            for i, name in enumerate(names):
                imsave(predictions_dir / f'{name}.tif',
                       prediction[i, ...].cpu().numpy())
127 |
--------------------------------------------------------------------------------
/sdeconv/deconv/noise2void.py:
--------------------------------------------------------------------------------
1 | """Implementation of deconvolution using noise2void algorithm"""
2 | from pathlib import Path
3 | from timeit import default_timer as timer
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import DataLoader
8 |
9 | from .interface_nn import NNModule
10 | from ._datasets import SelfSupervisedPatchDataset
11 | from ._datasets import SelfSupervisedDataset
12 | from ._transforms import FlipAugmentation, VisionScale
13 |
14 |
def generate_2d_points(shape: tuple[int, int], n_point: int) -> tuple[np.ndarray, np.ndarray]:
    """Generate random 2D coordinates to mask

    :param shape: Shape of the image to mask
    :param n_point: Number of coordinates to mask
    :return: (y, x) coordinates to mask
    """
    # Draw on a half-resolution grid, then scale up so coordinates start even
    rows = 2 * np.random.randint(0, int(shape[0]/2), n_point)
    cols = 2 * np.random.randint(0, int(shape[1]/2), n_point)

    # Randomly shift each whole axis to odd positions half of the time
    if np.random.randint(2) == 1:
        rows += 1
    if np.random.randint(2) == 1:
        cols += 1

    return rows, cols
33 |
34 |
def generate_mask_n2v(image: torch.Tensor, ratio: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate a blind spots mask for the patch image by randomly switching pixel values

    Each selected blind-spot pixel is replaced by the value of a random
    neighbour inside a 5x5 window and flagged with 1 in the returned mask.

    :param image: Image patch to add blind spots (indexed as [B, C, Y, X] below —
                  assumes a 4D tensor, TODO confirm with callers)
    :param ratio: Ratio of blind spots for input patch masking
    :return: the transformed image and the mask image
    """
    img_shape = image.shape
    # neighbourhood window (height, width) used to pick replacement values
    size_window = (5, 5)
    # number of blind spots derived from the spatial size and the masking ratio
    num_sample = int(img_shape[-2] * img_shape[-1] * ratio)

    mask = torch.zeros((img_shape[-2], img_shape[-1]), dtype=torch.float32)
    output = image.clone()

    # random (y, x) blind-spot positions
    idy_msk, idx_msk = generate_2d_points((img_shape[-2], img_shape[-1]), num_sample)
    num_sample = len(idy_msk)

    # random offsets towards a neighbour inside the window (centre included)
    idy_neigh = np.random.randint(-size_window[0] // 2 + size_window[0] % 2,
                                  size_window[0] // 2 + size_window[0] % 2,
                                  num_sample)
    idx_neigh = np.random.randint(-size_window[1] // 2 + size_window[1] % 2,
                                  size_window[1] // 2 + size_window[1] % 2,
                                  num_sample)

    # absolute neighbour coordinates
    idy_msk_neigh = idy_msk + idy_neigh
    idx_msk_neigh = idx_msk + idx_neigh

    # fold out-of-bounds neighbours back inside by one window size
    idy_msk_neigh = (idy_msk_neigh + (idy_msk_neigh < 0) * size_window[0] -
                     (idy_msk_neigh >= img_shape[-2]) * size_window[0])
    idx_msk_neigh = (idx_msk_neigh + (idx_msk_neigh < 0) * size_window[1] -
                     (idx_msk_neigh >= img_shape[-1]) * size_window[1])

    id_msk = (idy_msk, idx_msk)
    #id_msk_neigh = (idy_msk_neigh, idx_msk_neigh)

    # replace blind-spot pixels by their neighbour value and flag them in the mask
    output[:, :, idy_msk, idx_msk] = image[:, :, idy_msk_neigh, idx_msk_neigh]
    mask[id_msk] = 1.0

    return output, mask
74 |
75 |
class N2VDeconLoss(torch.nn.Module):
    """MSE Loss with mask for Noise2Void deconvolution

    The prediction is convolved with the PSF so the loss compares the
    re-blurred estimate against the target, restricted to the masked
    (blind-spot) pixels only.

    :param psf_image: 2D gray-scale image of the Point Spread Function
    :raises ValueError: if the PSF has more than 2 dimensions
    """
    def __init__(self,
                 psf_image: torch.Tensor):
        super().__init__()

        self.__psf = psf_image
        if self.__psf.ndim > 2:
            raise ValueError('N2VDeconLoss PSF must be a gray scaled 2D image')

        # reshape the PSF into a (out_channels, in_channels, H, W) conv kernel
        self.__psf = self.__psf.view((1, 1, *self.__psf.shape))
        # 'same' convolution implementing the blur operator
        # (assumes an odd, square PSF — TODO confirm with callers)
        self.__conv_op = torch.nn.Conv2d(1, 1,
                                         kernel_size=self.__psf.shape[2],
                                         stride=1,
                                         padding=int((self.__psf.shape[2] - 1) / 2),
                                         bias=False)
        with torch.no_grad():
            # the PSF is a fixed operator: freeze the kernel so no optimizer
            # can update it (consistent with DeconSpitfireLoss)
            self.__conv_op.weight = torch.nn.Parameter(self.__psf, requires_grad=False)

    def forward(self, predict: torch.Tensor, target: torch.Tensor, mask: torch.Tensor):
        """Calculate forward loss

        :param predict: tensor predicted by the model
        :param target: Reference target tensor
        :param mask: Mask to select pixels of interest
        :return: masked mean squared error between blurred prediction and target
        """
        conv_img = self.__conv_op(predict)

        # squared error restricted to masked pixels, normalised by the mask size
        num = torch.sum((conv_img*mask - target*mask)**2)
        den = torch.sum(mask)
        return num/den
112 |
113 |
class Noise2VoidDeconv(NNModule):
    """Deconvolution using the noise to void algorithm

    Training is self-supervised: blind-spot pixels are masked in the input
    patches and the loss compares the PSF-blurred prediction with the
    original pixel values at the blind spots only.
    """
    def fit(self,
            train_directory: Path,
            val_directory: Path,
            n_channel_in: int = 1,
            n_channels_layer: list[int] = (32, 64, 128),
            patch_size: int = 32,
            n_epoch: int = 25,
            learning_rate: float = 1e-3,
            out_dir: Path = None,
            psf: torch.Tensor = None
            ):
        """Train a model on a dataset

        :param train_directory: Directory containing the images used for
                                training. One file per image,
        :param val_directory: Directory containing the images used for validation of
                              the training. One file per image,
        :param psf: Point spread function for deconvolution,
        :param n_channel_in: Number of channels in the input images
        :param n_channels_layer: Number of channels for each hidden layers of the model,
        :param patch_size: Size of square patches used for training the model,
        :param n_epoch: Number of epochs,
        :param learning_rate: Adam optimizer learning rate
        :param out_dir: Directory where the training outputs are written
        """
        self._init_model(n_channel_in, n_channel_in, n_channels_layer)
        self._out_dir = out_dir
        # the loss re-blurs the prediction with the PSF before the masked MSE
        self._loss_fn = N2VDeconLoss(psf.to(self.device()))
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate)
        train_dataset = SelfSupervisedPatchDataset(train_directory,
                                                   patch_size=patch_size,
                                                   stride=int(patch_size/2),
                                                   transform=FlipAugmentation())
        val_dataset = SelfSupervisedDataset(val_directory, transform=VisionScale())
        self._train_data_loader = DataLoader(train_dataset,
                                             batch_size=300,
                                             shuffle=True,
                                             drop_last=True,
                                             num_workers=0)
        self._val_data_loader = DataLoader(val_dataset,
                                           batch_size=1,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=0)

        self._train_loop(n_epoch)

    def _train_step(self):
        """Runs one training epoch over the train data loader

        :return: dictionary with the mean training loss of the epoch
        """
        size = len(self._train_data_loader.dataset)
        self._model.train()
        step_loss = 0
        count_step = 0
        tic = timer()
        for batch, (x, _) in enumerate(self._train_data_loader):
            count_step += 1

            # mask 10% of the pixels with blind spots
            masked_x, mask = generate_mask_n2v(x, 0.1)
            x, masked_x, mask = (x.to(self.device()),
                                 masked_x.to(self.device()),
                                 mask.to(self.device()))

            # Compute prediction error
            prediction = self._model(masked_x)
            loss = self._loss_fn(prediction, x, mask)
            step_loss += loss

            # Backpropagation
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            # count time
            toc = timer()
            full_time = toc - tic
            total_batch = int(size / len(x))
            remains = full_time * (total_batch - (batch+1)) / (batch+1)

            # notify observers (progress bars, loggers) with the batch metrics
            self._after_train_batch({'loss': loss,
                                     'id_batch': batch+1,
                                     'total_batch': total_batch,
                                     'remain_time': int(remains+0.5),
                                     'full_time': int(full_time+0.5)
                                     })

        if count_step > 0:
            step_loss /= count_step
        self._current_loss = step_loss
        return {'train_loss': step_loss}
204 |
--------------------------------------------------------------------------------
/sdeconv/deconv/richardson_lucy.py:
--------------------------------------------------------------------------------
1 | """Implementation of Richardson-Lucy deconvolution for 2D and 3D images"""
2 | import torch
3 | import numpy as np
4 | from .interface import SDeconvFilter
5 | from ._utils import pad_2d, pad_3d, np_to_torch
6 |
7 |
8 | class SRichardsonLucy(SDeconvFilter):
9 | """Implements the Richardson-Lucy deconvolution
10 |
11 | :param psf: Point spread function
12 | :param niter: Number of iterations
13 | :param pad: image padding size
14 | """
15 | def __init__(self,
16 | psf: torch.Tensor,
17 | niter: int = 30,
18 | pad: int | tuple[int, int] | tuple[int, int, int] = 0):
19 | super().__init__()
20 | self.psf = psf
21 | self.niter = niter
22 | self.pad = pad
23 |
24 | @staticmethod
25 | def _resize_psf(psf, width, height) -> torch.Tensor:
26 | """Resize the PSF to match the image size for Fourier transform
27 |
28 | :param psf: Point spread function
29 | :param width: Width of the resized PSF
30 | :param height: Height of the resized PSF
31 | :return: The resized PSF
32 | """
33 | kernel = torch.zeros((width, height))
34 | x_start = int(width / 2 - psf.shape[0] / 2) + 1
35 | y_start = int(height / 2 - psf.shape[1] / 2) + 1
36 | kernel[x_start:x_start+psf.shape[0], y_start:y_start+psf.shape[1]] = psf
37 | return kernel
38 |
39 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
40 | """Apply the Richardson-Lucy deconvolution
41 |
42 | :param image: Blurry image for a single channel time point [(Z) Y X]
43 | :return: deblurred image [(Z) Y X]
44 | """
45 | if image.ndim == 2:
46 | return self._deconv_2d(image)
47 | if image.ndim == 3:
48 | return self._deconv_3d(image)
49 | raise ValueError('Richardson-Lucy can only deblur 2D or 3D tensors')
50 |
51 | def _deconv_2d(self, image: torch.Tensor) -> torch.Tensor:
52 | """Implements Richardson-Lucy for 2D images
53 |
54 | :param image: 2D image tensor
55 | :return: 2D deblurred image
56 | """
57 | image_pad, psf_pad, padding = pad_2d(image, self.psf / torch.sum(self.psf), self.pad)
58 |
59 | psf_roll = torch.roll(psf_pad, [int(-psf_pad.shape[0] / 2),
60 | int(-psf_pad.shape[1] / 2)], dims=(0, 1))
61 | fft_psf = torch.fft.fft2(psf_roll)
62 | fft_psf_mirror = torch.fft.fft2(torch.flip(psf_roll, dims=[0, 1]))
63 |
64 | out_image = image_pad.detach().clone()
65 | for _ in range(self.niter):
66 | fft_out = torch.fft.fft2(out_image)
67 | fft_tmp = fft_out * fft_psf
68 | tmp = torch.real(torch.fft.ifft2(fft_tmp))
69 | tmp = image_pad / tmp
70 | fft_tmp = torch.fft.fft2(tmp)
71 | fft_tmp = fft_tmp * fft_psf_mirror
72 | tmp = torch.real(torch.fft.ifft2(fft_tmp))
73 | out_image = out_image * tmp
74 |
75 | if image_pad.shape != image.shape:
76 | return out_image[padding[0]:-padding[0], padding[1]:-padding[1]]
77 | return out_image
78 |
79 | def _deconv_3d(self, image: torch.Tensor) -> torch.Tensor:
80 | """Implements Richardson-Lucy for 3D images
81 |
82 | :param image: 3D image tensor
83 | :return: 3D deblurred image
84 | """
85 | image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad)
86 |
87 | psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 2), dims=0)
88 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1)
89 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[2] / 2), dims=2)
90 |
91 | fft_psf = torch.fft.fftn(psf_roll)
92 | fft_psf_mirror = torch.fft.fftn(torch.flip(psf_roll, dims=[0, 1]))
93 |
94 | out_image = image_pad.detach().clone()
95 | for _ in range(self.niter):
96 | fft_out = torch.fft.fftn(out_image)
97 | fft_tmp = fft_out * fft_psf
98 | tmp = torch.real(torch.fft.ifftn(fft_tmp))
99 | tmp = image_pad / tmp
100 | fft_tmp = torch.fft.fftn(tmp)
101 | fft_tmp = fft_tmp * fft_psf_mirror
102 | tmp = torch.real(torch.fft.ifftn(fft_tmp))
103 | out_image = out_image * tmp
104 |
105 | if image_pad.shape != image.shape:
106 | return out_image[padding[0]:-padding[0],
107 | padding[1]:-padding[1],
108 | padding[2]:-padding[2]]
109 | return out_image
110 |
111 |
def srichardsonlucy(image: torch.Tensor,
                    psf: torch.Tensor,
                    niter: int = 30,
                    pad: int | tuple[int, int] | tuple[int, int, int] = 0
                    ) -> torch.Tensor:
    """Convenient function to call the SRichardsonLucy using numpy array

    :param image: Image to deblur
    :param psf: Point spread function
    :param niter: Number of iterations
    :param pad: image padding size
    :return: the deblurred image
    """
    deconv_filter = SRichardsonLucy(np_to_torch(psf), niter, pad)
    # numpy inputs are converted to torch before filtering
    if isinstance(image, np.ndarray):
        return deconv_filter(np_to_torch(image))
    return deconv_filter(image)
131 |
132 |
# Plugin metadata describing the Richardson-Lucy filter: entry point,
# typed inputs with defaults/ranges, and outputs (presumably consumed by
# the sdeconv api/factory layer — TODO confirm consumer)
metadata = {
    'name': 'SRichardsonLucy',
    'label': 'Richardson-Lucy',
    'fnc': srichardsonlucy,
    'inputs': {
        'image': {
            'type': 'Image',
            'label': 'Image',
            'help': 'Input image'
        },
        'psf': {
            'type': 'Image',
            'label': 'PSF',
            'help': 'Point Spread Function'
        },
        'niter': {
            'type': 'int',
            'label': 'niter',
            'help': 'Number of iterations',
            'default': 30,
            'range': (0, 999999)
        },
        'pad': {
            'type': 'int',
            'label': 'Padding',
            'help': 'Padding to avoid spectrum artifacts',
            'default': 13,
            'range': (0, 999999)
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'Richardson-Lucy'
        },
    }
}
170 |
--------------------------------------------------------------------------------
/sdeconv/deconv/self_supervised_nn.py:
--------------------------------------------------------------------------------
1 | """This module implements self supervised deconvolution with Spitfire regularisation"""
2 | from pathlib import Path
3 | from timeit import default_timer as timer
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 | from .interface_nn import NNModule
9 | from ._datasets import SelfSupervisedPatchDataset
10 | from ._datasets import SelfSupervisedDataset
11 | from .spitfire import hv_loss
12 |
13 |
class DeconSpitfireLoss(torch.nn.Module):
    """MSE LOSS with a (de)convolution filter and Spitfire regularisation

    The network output is re-blurred with the PSF, compared with the
    original image (L2), and regularised with the sparse Hessian term.

    :param psf: 2D gray-scale Point Spread Function used for deconvolution
    :param regularization: Weight of the data term; (1 - regularization)
                           weights the Hessian/sparsity term
    :param weighting: Spitfire weighting between hessian and sparsity
    :raises ValueError: if the PSF has more than 2 dimensions
    """
    def __init__(self,
                 psf: torch.Tensor,
                 regularization: float = 1e-3,
                 weighting: float = 0.6
                 ):
        super().__init__()
        self.__psf = psf
        self.regularization = regularization
        self.weighting = weighting

        if self.__psf.ndim > 2:
            raise ValueError('DeconMSE PSF must be a gray scaled 2D image')

        # reshape the PSF into a (out_channels, in_channels, H, W) conv kernel
        self.__psf = self.__psf.view((1, 1, *self.__psf.shape))
        # 'same' convolution implementing the blur operator
        # (assumes an odd, square PSF — TODO confirm with callers)
        self.__conv_op = torch.nn.Conv2d(1, 1,
                                         kernel_size=self.__psf.shape[2],
                                         stride=1,
                                         padding=int((self.__psf.shape[2] - 1) / 2),
                                         bias=False)
        with torch.no_grad():
            # the PSF is a fixed operator: freeze the kernel weights
            self.__conv_op.weight = torch.nn.Parameter(self.__psf, requires_grad=False)
        self.__conv_op.requires_grad_(False)

    def forward(self, input_image: torch.Tensor, target: torch.Tensor):
        """Deconvolution L2 data-term

        Compute the L2 error between the original image (input) and the
        convoluted reconstructed image (target)

        :param input_image: Tensor of shape BCYX containing the original blurry image
        :param target: Tensor of shape BCYX containing the estimated deblurred image
        :return: weighted sum of the data term and the Hessian regularisation
        """
        conv_img = self.__conv_op(input_image)
        mse = torch.nn.MSELoss()
        return self.regularization*mse(target, conv_img) + \
            (1-self.regularization)*hv_loss(input_image, weighting=self.weighting)
57 |
58 |
class SelfSupervisedNNDeconv(NNModule):
    """Deconvolution using a neural network trained using the Spitfire loss

    Training is self-supervised: the network output is re-blurred with the
    PSF and compared to the input image, regularised by the sparse Hessian
    term (no ground-truth deblurred images required).
    """
    def fit(self,
            train_directory: Path,
            val_directory: Path,
            n_channel_in: int = 1,
            n_channels_layer: list[int] = (32, 64, 128),
            patch_size: int = 32,
            n_epoch: int = 25,
            learning_rate: float = 1e-3,
            out_dir: Path = None,
            weight: float = 0.9,
            reg: float = 0.95,
            psf: torch.Tensor = None
            ):
        """Train a model on a dataset

        :param train_directory: Directory containing the images used
                                for training. One file per image,
        :param val_directory: Directory containing the images used for validation of
                              the training. One file per image,
        :param psf: Point spread function for deconvolution,
        :param n_channel_in: Number of channels in the input images
        :param n_channels_layer: Number of channels for each hidden layers of the model,
        :param patch_size: Size of square patches used for training the model,
        :param n_epoch: Number of epochs,
        :param learning_rate: Adam optimizer learning rate
        :param out_dir: Directory where the training outputs are written
        :param weight: Spitfire weighting between hessian and sparsity
        :param reg: Spitfire regularization (data-term) weight
        """
        self._init_model(n_channel_in, n_channel_in, n_channels_layer)
        self._out_dir = out_dir
        self._loss_fn = DeconSpitfireLoss(psf.to(self.device()), reg, weight)
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate)
        train_dataset = SelfSupervisedPatchDataset(train_directory,
                                                   patch_size=patch_size,
                                                   stride=int(patch_size/2))
        val_dataset = SelfSupervisedDataset(val_directory)
        self._train_data_loader = DataLoader(train_dataset,
                                             batch_size=300,
                                             shuffle=True,
                                             drop_last=True,
                                             num_workers=0)
        self._val_data_loader = DataLoader(val_dataset,
                                           batch_size=1,
                                           shuffle=False,
                                           drop_last=False,
                                           num_workers=0)
        self._train_loop(n_epoch)

    def _train_step(self):
        """Runs one training epoch over the train data loader

        :return: dictionary with the mean training loss of the epoch
        """
        size = len(self._train_data_loader.dataset)
        self._model.train()
        step_loss = 0
        count_step = 0
        tic = timer()
        for batch, (x, _) in enumerate(self._train_data_loader):
            count_step += 1

            x = x.to(self.device())

            # Compute prediction error (the loss re-blurs the prediction)
            prediction = self._model(x)
            loss = self._loss_fn(prediction, x)
            step_loss += loss

            # Backpropagation
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            # count time
            toc = timer()
            full_time = toc - tic
            total_batch = int(size / len(x))
            remains = full_time * (total_batch - (batch+1)) / (batch+1)

            # notify observers (progress bars, loggers) with the batch metrics
            self._after_train_batch({'loss': loss,
                                     'id_batch': batch+1,
                                     'total_batch': total_batch,
                                     'remain_time': int(remains+0.5),
                                     'full_time': int(full_time+0.5)
                                     })
        if count_step > 0:
            step_loss /= count_step
        self._current_loss = step_loss
        return {'train_loss': step_loss}
145 |
--------------------------------------------------------------------------------
/sdeconv/deconv/spitfire.py:
--------------------------------------------------------------------------------
1 | """Implements the Spitfire deconvolution algorithms for 2D and 3D images"""
2 | import torch
3 |
4 | from ..core import SSettings
5 | from .interface import SDeconvFilter
6 | from ._utils import pad_2d, pad_3d, unpad_3d, np_to_torch
7 |
8 |
def hv_loss(img: torch.Tensor, weighting: float = 0.5) -> torch.Tensor:
    """Sparse Hessian regularization term

    :param img: Tensor of shape BCYX containing the estimated image
    :param weighting: Sparse weighting parameter in [0, 1]. 0 sparse, and 1 not sparse
    :return: the loss value in torch.Tensor
    """
    center = img[:, :, 1:-1, 1:-1]
    # second-order finite differences (Hessian entries) on interior pixels
    d_xx = -img[:, :, 2:, 1:-1] + 2 * center - img[:, :, :-2, 1:-1]
    d_yy = -img[:, :, 1:-1, 2:] + 2 * center - img[:, :, 1:-1, :-2]
    d_xy = (img[:, :, 2:, 2:] - img[:, :, 2:, 1:-1] - img[:, :, 1:-1, 2:] +
            center)
    # weighted Frobenius norm of the Hessian plus the sparsity term
    hessian_term = weighting * weighting * (d_xx ** 2 + d_yy ** 2 + 2 * d_xy ** 2)
    sparsity_term = (1 - weighting) * (1 - weighting) * center ** 2
    return torch.mean(torch.sqrt(hessian_term + sparsity_term))
23 |
24 |
def hv_loss_3d(img: torch.Tensor,
               delta: float = 1,
               weighting: float = 0.5
               ) -> torch.Tensor:
    """Sparse Hessian regularization term

    :param img: Tensor of shape BCZYX containing the estimated image
    :param delta: Resolution delta between XY and Z
    :param weighting: Sparse weighting parameter in [0, 1]. 0 sparse, and 1 not sparse
    :return: the loss value in torch.Tensor
    """
    center = img[:, :, 1:-1, 1:-1, 1:-1]

    # pure second derivatives along x, y and (delta-scaled) z
    d_xx = -img[:, :, 1:-1, 1:-1, 2:] + 2 * center - img[:, :, 1:-1, 1:-1, :-2]
    d_yy = -img[:, :, 1:-1, 2:, 1:-1] + 2 * center - img[:, :, 1:-1, :-2, 1:-1]
    d_zz = delta * delta * (-img[:, :, 2:, 1:-1, 1:-1] + 2 * center
                            - img[:, :, :-2, 1:-1, 1:-1])

    # mixed second derivatives (each counted twice in the Hessian norm)
    d_xy = (img[:, :, 1:-1, 2:, 2:] - img[:, :, 1:-1, 1:-1, 2:]
            - img[:, :, 1:-1, 2:, 1:-1] + center)
    d_xz = delta * (img[:, :, 2:, 1:-1, 2:] - img[:, :, 1:-1, 1:-1, 2:]
                    - img[:, :, 2:, 1:-1, 1:-1] + center)
    d_yz = delta * (img[:, :, 2:, 2:, 1:-1] - img[:, :, 1:-1, 2:, 1:-1]
                    - img[:, :, 2:, 1:-1, 1:-1] + center)

    # weighted Hessian Frobenius norm plus the sparsity term
    hessian_norm = ((weighting * d_xx) ** 2 + (weighting * d_yy) ** 2
                    + (weighting * d_zz) ** 2
                    + 2 * (weighting * d_xy) ** 2
                    + 2 * (weighting * d_xz) ** 2
                    + 2 * (weighting * d_yz) ** 2
                    + ((1 - weighting) * center) ** 2)

    return torch.mean(torch.sqrt(hessian_norm))
51 |
52 |
def dataterm_deconv(blurry_image: torch.Tensor,
                    deblurred_image: torch.Tensor,
                    psf: torch.Tensor
                    ) -> torch.Tensor:
    """Deconvolution L2 data-term

    Compute the L2 error between the original image and the convoluted reconstructed image

    :param blurry_image: Tensor of shape BCYX containing the original blurry image
    :param deblurred_image: Tensor of shape BCYX containing the estimated deblurred image
    :param psf: Tensor containing the point spread function
    :return: the loss value in torch.Tensor
    """
    kernel_size = psf.shape[2]
    # 'same' convolution implementing the blur operator
    blur_op = torch.nn.Conv2d(1, 1,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=int((kernel_size - 1) / 2),
                              bias=False)
    with torch.no_grad():
        blur_op.weight = torch.nn.Parameter(psf)
    return torch.nn.MSELoss()(blurry_image, blur_op(deblurred_image))
74 |
75 |
class DataTermDeconv3D(torch.autograd.Function):
    """Deconvolution L2 data term

    This class manually implements the 3D data term backward pass: the
    gradient is evaluated in the Fourier domain with the adjoint OTF
    instead of relying on autograd through an FFT-based convolution.
    """
    @staticmethod
    def forward(ctx,
                deblurred_image: torch.Tensor,
                blurry_image: torch.Tensor,
                fft_blurry_image: torch.Tensor,
                fft_psf: torch.Tensor,
                adjoint_otf: torch.Tensor
                ) -> torch.Tensor:
        """Pytorch forward method

        :param deblurred_image: Candidate deblurred image
        :param blurry_image: Original image
        :param fft_blurry_image: FFT of the original image
        :param fft_psf: FFT of the Point Spread Function
        :param adjoint_otf: Fourier adjoint of PSF
        :return: The loss value
        """
        fft_deblurred_image = torch.fft.fftn(deblurred_image)
        # keep every tensor needed to evaluate the gradient in backward
        ctx.save_for_backward(deblurred_image, fft_deblurred_image, fft_blurry_image, fft_psf,
                              adjoint_otf)

        # convolution computed as a pointwise product in the Fourier domain
        conv_deblured_image = torch.real(torch.fft.ifftn(fft_deblurred_image * fft_psf))
        mse = torch.nn.MSELoss()
        return mse(blurry_image, conv_deblured_image)

    @staticmethod
    def backward(ctx, grad_output):
        """Pytorch backward method

        Computes IFFT(adjoint_otf * (fft_psf * fft_deblurred - fft_blurry))
        divided by the number of elements, then scales by grad_output.
        Gradients are returned only for the first forward input.
        """
        deblurred_image, fft_deblurred_image, fft_blurry_image, \
            fft_psf, adjoint_otf = ctx.saved_tensors

        # complex residue fft_psf * fft_deblurred - fft_blurry, expanded on
        # real/imaginary parts
        real_tmp = fft_psf.real * fft_deblurred_image.real - \
            fft_psf.imag * fft_deblurred_image.imag - fft_blurry_image.real
        imag_tmp = fft_psf.real * fft_deblurred_image.imag + \
            fft_psf.imag * fft_deblurred_image.real - fft_blurry_image.imag

        # complex product of the residue with the adjoint OTF
        residue_image_real = adjoint_otf.real * real_tmp - adjoint_otf.imag * imag_tmp
        residue_image_imag = adjoint_otf.real * imag_tmp + adjoint_otf.imag * real_tmp

        grad_ = torch.real(torch.fft.ifftn(
            torch.complex(residue_image_real,
                          residue_image_imag))) / torch.numel(deblurred_image)
        # one gradient slot per forward input; only deblurred_image needs one
        return grad_output * grad_, None, None, None, None
124 |
125 |
126 | class Spitfire(SDeconvFilter):
127 | """Variational deconvolution using the Spitfire algorithm
128 |
129 | :param psf: Point spread function
130 | :param weight: model weight between hessian and sparsity. Value is in ]0, 1[
131 | :param delta: For 3D images resolution delta between xy and z
132 | :param reg: Regularization weight. Value is in [0, 1]
133 | :param gradient_step: Gradient descent step
134 | :param precision: Stop criterion. Stop gradient descent when the
135 | loss decrease less than precision
136 | :param pad: Image padding to avoid Fourier artefacts
137 | """
138 | def __init__(self,
139 | psf: torch.Tensor,
140 | weight: float = 0.6,
141 | delta: float = 1,
142 | reg: float = 0.995,
143 | gradient_step: float = 0.01,
144 | precision: float = 1e-7,
145 | pad: int | tuple[int, int] | tuple[int, int, int] = 0):
146 | super().__init__()
147 | self.psf = psf.to(SSettings.instance().device)
148 | self.weight = weight
149 | self.reg = reg
150 | self.precision = precision
151 | self.delta = delta
152 | self.pad = pad
153 | self.niter_ = 0
154 | self.max_iter_ = 2500
155 | self.gradient_step_ = gradient_step
156 | self.loss_ = None
157 |
158 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
159 | """Run the Spitfire deconvolution
160 |
161 | :param image: Blurry image for a single channel time point [(Z) Y X]
162 | :return: deblurred image [(Z) Y X]
163 | """
164 | if image.ndim == 2:
165 | return self.run_2d(image)
166 | if image.ndim == 3:
167 | return self.run_3d(image)
168 | raise ValueError("Spitfire can process only 2D or 3D tensors")
169 |
170 | def run_2d(self, image: torch.Tensor) -> torch.Tensor:
171 | """Implements Spitfire for 2D images
172 |
173 | :param image: Blurry 2D image tensor
174 | :return: Deblurred image (2D torch.Tensor)
175 | """
176 | self.progress(0)
177 | mini = torch.min(image) + 1e-5
178 | maxi = torch.max(image)
179 | image = (image-mini)/(maxi-mini)
180 |
181 | image_pad, psf_pad, padding = pad_2d(image, self.psf/torch.sum(self.psf), self.pad)
182 | psf_pad = self.psf/torch.sum(self.psf)
183 |
184 | image_pad = image_pad.view(1, 1, image_pad.shape[0], image_pad.shape[1])
185 | psf_pad = psf_pad.view(1, 1, psf_pad.shape[0], psf_pad.shape[1])
186 | deconv_image = image_pad.detach().clone()
187 | deconv_image.requires_grad = True
188 | optimizer = torch.optim.Adam([deconv_image], lr=self.gradient_step_)
189 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
190 | previous_loss = 9e12
191 | count_eq = 0
192 | self.niter_ = 0
193 | loss = 0
194 |
195 | for i in range(self.max_iter_):
196 | self.progress(int(100*i/self.max_iter_))
197 | self.niter_ += 1
198 | optimizer.zero_grad()
199 | loss = self.reg * dataterm_deconv(image_pad, deconv_image, psf_pad) + \
200 | (1-self.reg) * hv_loss(deconv_image, self.weight)
201 | #print('iter:', self.niter_, ' loss:', loss.item())
202 | if abs(loss - previous_loss) < self.precision:
203 | count_eq += 1
204 | else:
205 | previous_loss = loss
206 | count_eq = 0
207 | if count_eq > 5:
208 | break
209 | loss.backward()
210 | optimizer.step()
211 | scheduler.step()
212 | self.loss_ = loss
213 | self.progress(100)
214 | deconv_image = (maxi-mini)*deconv_image + mini
215 | deconv_image = deconv_image.view(deconv_image.shape[2], deconv_image.shape[3])
216 | if image_pad.shape[2] != image.shape[0] and image_pad.shape[3] != image.shape[1]:
217 | return deconv_image[padding[0]: -padding[0], padding[1]: -padding[1]]
218 | return deconv_image
219 |
220 | @staticmethod
221 | def otf_3d(psf: torch.Tensor) -> torch.Tensor:
222 | """Calculate the OTF of a PSF
223 |
224 | :param psf: 3D Point Spread Function
225 | :return: the OTF in a torch.Tensor 3D
226 | """
227 | psf_roll = torch.roll(psf, int(-psf.shape[0] / 2), dims=0)
228 | psf_roll = torch.roll(psf_roll, int(-psf.shape[1] / 2), dims=1)
229 | psf_roll = torch.roll(psf_roll, int(-psf.shape[2] / 2), dims=2)
230 | psf_roll.view(1, psf.shape[0], psf.shape[1], psf.shape[2])
231 | fft_psf = torch.fft.fftn(psf_roll)
232 | return fft_psf
233 |
234 | @staticmethod
235 | def adjoint_otf(psf: torch.Tensor) -> torch.Tensor:
236 | """Calculate the adjoint OTF of a PSF
237 |
238 | :param psf: 3D Point Spread Function
239 | :return: the OTF in a torch.Tensor 3D
240 | """
241 | adjoint_psf = torch.flip(psf, [0, 1, 2])
242 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[0] - 1) % 2, dims=0)
243 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[1] - 1) % 2, dims=1)
244 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[2] - 1) % 2, dims=2)
245 |
246 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[0] / 2), dims=0)
247 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[1] / 2), dims=1)
248 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[2] / 2), dims=2)
249 | return torch.fft.fftn(adjoint_psf)
250 |
251 | def run_3d(self, image: torch.Tensor) -> torch.Tensor:
252 | """Implements Spitfire for 3D images
253 |
254 | :param image: Blurry 2D image tensor
255 | :return: Deblurred image (2D torch.Tensor)
256 | """
257 | self.progress(0)
258 | mini = torch.min(image) + 1e-5
259 | maxi = torch.max(image)
260 | image = (image-mini)/(maxi-mini)
261 | image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad)
262 |
263 | deconv_image = image_pad.detach().clone()
264 | image_pad = image_pad.view(1, 1, image_pad.shape[0], image_pad.shape[1], image_pad.shape[2])
265 | deconv_image = deconv_image.view(1, 1, deconv_image.shape[0],
266 | deconv_image.shape[1], deconv_image.shape[2])
267 | deconv_image.requires_grad = True
268 | optimizer = torch.optim.Adam([deconv_image], lr=self.gradient_step_)
269 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
270 | previous_loss = 9e12
271 | count_eq = 0
272 | self.niter_ = 0
273 |
274 | fft_psf = Spitfire.otf_3d(psf_pad)
275 | adjoint_otf = Spitfire.adjoint_otf(psf_pad)
276 | fft_image = torch.fft.fftn(image_pad)
277 | dataterm_ = DataTermDeconv3D.apply
278 | loss = 0
279 | for i in range(self.max_iter_):
280 | self.progress(int(100*i/self.max_iter_))
281 | self.niter_ += 1
282 | optimizer.zero_grad()
283 | loss = self.reg * dataterm_(deconv_image, image_pad,
284 | fft_image, fft_psf, adjoint_otf) + (
285 | 1 - self.reg) * hv_loss_3d(deconv_image, self.delta, self.weight)
286 | print('iter:', self.niter_, ' loss:', loss.item())
287 | if loss > previous_loss:
288 | break
289 | if abs(loss - previous_loss) < self.precision:
290 | count_eq += 1
291 | else:
292 | previous_loss = loss
293 | count_eq = 0
294 | if count_eq > 5:
295 | break
296 | loss.backward()
297 | optimizer.step()
298 | scheduler.step()
299 | self.loss_ = loss
300 | self.progress(100)
301 | deconv_image = deconv_image.view(image_pad.shape[2],
302 | image_pad.shape[3],
303 | image_pad.shape[4])
304 | deconv_image = (maxi-mini)*deconv_image + mini
305 | if image_pad.shape[2] != image.shape[0] and image_pad.shape[3] != image.shape[1] and \
306 | image_pad.shape[4] != image.shape[2]:
307 | return unpad_3d(deconv_image, padding)
308 | return deconv_image
309 |
310 |
def spitfire(image: torch.Tensor,
             psf: torch.Tensor,
             weight: float = 0.6,
             delta: float = 1,
             reg: float = 0.995,
             gradient_step: float = 0.01,
             precision: float = 1e-7,
             pad: int | tuple[int, int] | tuple[int, int, int] = 13,
             observers=None
             ) -> torch.Tensor:
    """Convenient function to call Spitfire with numpy

    :param image: Image to deblur
    :param psf: Point spread function
    :param weight: model weight between hessian and sparsity. Value is in ]0, 1[
    :param delta: For 3D images resolution delta between xy and z
    :param reg: Regularization weight. Value is in [0, 1]
    :param gradient_step: Gradient descent step
    :param precision: Stop criterion. Stop gradient descent when the loss decrease less than
                      precision
    :param pad: Image padding to avoid Fourier artefacts
    :param observers: Optional list of observers attached to the filter
    :return: the deblurred image
    """
    deconv_filter = Spitfire(np_to_torch(psf), weight, delta, reg,
                             gradient_step, precision, pad)
    # attach progress observers before running the filter
    for observer in (observers or []):
        deconv_filter.add_observer(observer)
    return deconv_filter(np_to_torch(image))
342 |
343 |
# Plugin metadata describing the Spitfire filter: entry point, typed
# inputs with defaults/ranges ('advanced' flags hide expert parameters),
# and outputs (presumably consumed by the sdeconv api/factory layer —
# TODO confirm consumer)
metadata = {
    'name': 'Spitfire',
    'label': 'Spitfire',
    'fnc': spitfire,
    'inputs': {
        'image': {
            'type': 'Image',
            'label': 'Image',
            'help': 'Input image'
        },
        'psf': {
            'type': 'Image',
            'label': 'PSF',
            'help': 'Point Spread Function'
        },
        'weight': {
            'type': 'float',
            'label': 'weight',
            'help': 'Model weight between hessian and sparsity. Value is in ]0, 1[',
            'default': 0.6,
            'range': (0, 1)
        },
        'reg': {
            'type': 'float',
            'label': 'Regularization',
            'help': 'Regularization weight. Value is in [0, 1]',
            'default': 0.995,
            'range': (0, 1)
        },
        'delta': {
            'type': 'float',
            'label': 'Delta',
            'help': 'For 3D images resolution delta between xy and z',
            'default': 1,
            'advanced': True,
        },
        'gradient_step': {
            'type': 'float',
            'label': 'Gradient Step',
            'help': 'Step for ADAM gradient descente optimization',
            'default': 0.01,
            'advanced': True,
        },
        'precision': {
            'type': 'float',
            'label': 'Precision',
            'help': 'Stop criterion. Stop gradient descent when the loss '
                    'decrease less than precision',
            'default': 1e-7,
            'advanced': True,
        },
        'pad': {
            'type': 'int',
            'label': 'Padding',
            'help': 'Padding to avoid spectrum artifacts',
            'default': 13,
            'range': (0, 999999),
            'advanced': True
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'Spitfire'
        },
    }
}
411 |
--------------------------------------------------------------------------------
/sdeconv/deconv/wiener.py:
--------------------------------------------------------------------------------
1 | """Implements the Wiener deconvolution for 2D and 3D images"""
2 | import torch
3 | from sdeconv.core import SSettings
4 | from .interface import SDeconvFilter
5 | from ._utils import pad_2d, pad_3d, unpad_3d, np_to_torch
6 |
7 |
def laplacian_2d(shape: tuple[int, int]) -> torch.Tensor:
    """Build the discrete 2D Laplacian kernel embedded in a full-size image

    The returned tensor has the requested shape, is zero everywhere except a
    5-point Laplacian stencil (4 at the center, -1 at the four neighbours).

    :param shape: 2D image shape [Y, X]
    :return: torch.Tensor of the given shape holding the centered stencil
    """
    kernel = torch.zeros(shape).to(SSettings.instance().device)

    row_c = int(shape[0] / 2)
    col_c = int(shape[1] / 2)

    kernel[row_c, col_c] = 4
    # -1 on each of the four direct neighbours of the center pixel
    for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
        kernel[row_c + d_row, col_c + d_col] = -1
    return kernel
25 |
26 |
def laplacian_3d(shape: tuple[int, int, int]) -> torch.Tensor:
    """Define the 3D laplacian matrix

    Builds a zero image of the given shape with the 7-point Laplacian stencil
    (6 at the center, -1 at the six direct neighbours) at its center.

    :param shape: 3D image shape [Z, Y, X]
    :return: torch.Tensor with size defined by shape and laplacian coefficients at its center
    """
    image = torch.zeros(shape).to(SSettings.instance().device)

    x_c = int(shape[2] / 2)
    y_c = int(shape[1] / 2)
    z_c = int(shape[0] / 2)

    image[z_c, y_c, x_c] = 6
    image[z_c - 1, y_c, x_c] = -1
    image[z_c + 1, y_c, x_c] = -1
    image[z_c, y_c - 1, x_c] = -1
    image[z_c, y_c + 1, x_c] = -1
    image[z_c, y_c, x_c - 1] = -1
    # bug fix: the original wrote x_c - 1 twice, leaving the +x neighbour at 0
    image[z_c, y_c, x_c + 1] = -1
    return image
47 |
48 |
class SWiener(SDeconvFilter):
    """Apply a Wiener deconvolution

    The deconvolution is computed in the Fourier domain with a Laplacian
    regularisation term weighted by ``beta``.

    :param psf: Point spread function
    :param beta: Regularisation weight
    :param pad: Padding in each dimension
    """
    def __init__(self,
                 psf: torch.Tensor,
                 beta: float = 1e-5,
                 pad: int | tuple[int, int] | tuple[int, int, int] = 0):
        super().__init__()
        self.psf = psf
        self.beta = beta
        self.pad = pad

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Run the Wiener deconvolution

        Dispatches to the 2D or 3D implementation based on the tensor rank.

        :param image: Blurry image for a single channel time point [(Z) Y X]
        :return: deblurred image [(Z) Y X]
        :raises ValueError: if the image is neither 2D nor 3D
        """
        if image.ndim == 2:
            return self._wiener_2d(image)
        if image.ndim == 3:
            return self._wiener_3d(image)
        raise ValueError('Wiener can only deblur 2D or 3D tensors')

    def _wiener_2d(self, image: torch.Tensor) -> torch.Tensor:
        """Compute the 2D wiener deconvolution

        :param image: 2D image tensor
        :return: torch.Tensor of the 2D deblurred image
        """
        # Normalise the PSF to unit sum before padding image and PSF together
        image_pad, psf_pad, padding = pad_2d(image, self.psf / torch.sum(self.psf), self.pad)

        fft_source = torch.fft.fft2(image_pad)
        # Circularly shift the PSF so its center sits at the array origin
        psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 2), dims=0)
        psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1)
        fft_psf = torch.fft.fft2(psf_roll)
        fft_laplacian = torch.fft.fft2(laplacian_2d(image_pad.shape))

        # Wiener estimate: conj(H).Y / (|H|^2 + beta * |L|^2)
        den = fft_psf * torch.conj(fft_psf) + self.beta * fft_laplacian * torch.conj(fft_laplacian)
        out_image = torch.real(torch.fft.ifftn((fft_source * torch.conj(fft_psf)) / den))
        # Crop the padding back off when any was added
        if image_pad.shape != image.shape:
            return out_image[padding[0]: -padding[0], padding[1]: -padding[1]]
        return out_image

    def _wiener_3d(self, image: torch.Tensor) -> torch.Tensor:
        """Compute the 3D wiener deconvolution

        :param image: 3D image tensor
        :return: torch.Tensor of the 3D deblurred image
        """
        image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad)

        fft_source = torch.fft.fftn(image_pad)
        # Circularly shift the PSF so its center sits at the array origin
        psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 2), dims=0)
        psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1)
        psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[2] / 2), dims=2)

        fft_psf = torch.fft.fftn(psf_roll)
        fft_laplacian = torch.fft.fftn(laplacian_3d(image_pad.shape))

        # Wiener estimate: conj(H).Y / (|H|^2 + beta * |L|^2)
        den = fft_psf * torch.conj(fft_psf) + self.beta * fft_laplacian * torch.conj(fft_laplacian)
        out_image = torch.real(torch.fft.ifftn((fft_source * torch.conj(fft_psf)) / den))
        if image_pad.shape != image.shape:
            return unpad_3d(out_image, padding)
        return out_image
118 |
119 |
def swiener(image: torch.Tensor,
            psf: torch.Tensor,
            beta: float = 1e-5,
            pad: int | tuple[int, int] | tuple[int, int, int] = 0
            ):
    """Convenience wrapper running the SWiener filter on numpy-compatible input

    :param image: Image to deblur
    :param psf: Point spread function
    :param beta: Regularisation weight
    :param pad: Padding in each dimension
    :return: the deblurred image
    """
    deconv_filter = SWiener(np_to_torch(psf), beta, pad)
    return deconv_filter(np_to_torch(image))
136 |
137 |
# Plugin-registration metadata: describes the swiener entry point and its
# parameters so the sdeconv API/factory layer can expose them in a UI.
metadata = {
    'name': 'SWiener',
    'label': 'Wiener',
    'fnc': swiener,
    'inputs': {
        'image': {
            'type': 'Image',
            'label': 'Image',
            'help': 'Input image'
        },
        'psf': {
            'type': 'Image',
            'label': 'PSF',
            'help': 'Point Spread Function'
        },
        'beta': {
            'type': 'float',
            'label': 'Beta',
            'help': 'Regularisation parameter',
            'default': 1e-5,
            'range': (0, 999999)
        },
        'pad': {
            'type': 'int',
            'label': 'Padding',
            'help': 'Padding to avoid spectrum artifacts',
            'default': 13,
            'range': (0, 999999)
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'Wiener'
        },
    }
}
175 |
--------------------------------------------------------------------------------
/sdeconv/psfs/__init__.py:
--------------------------------------------------------------------------------
1 | """Module that implements Point Spread Function generators"""
2 | from .interface import SPSFGenerator
3 | from .gaussian import SPSFGaussian, spsf_gaussian
4 | from .gibson_lanni import SPSFGibsonLanni, spsf_gibson_lanni
5 | from .lorentz import SPSFLorentz, spsf_lorentz
6 |
# Public names re-exported by ``sdeconv.psfs``: the generator interface plus
# the class-based and convenience-function entry points of each PSF model.
__all__ = ['SPSFGenerator',
           'SPSFGaussian',
           'spsf_gaussian',
           'SPSFGibsonLanni',
           'spsf_gibson_lanni',
           'SPSFLorentz',
           'spsf_lorentz']
14 |
--------------------------------------------------------------------------------
/sdeconv/psfs/gaussian.py:
--------------------------------------------------------------------------------
1 | """Implements the Gaussian Point Spread Function generator"""
2 | import math
3 | import torch
4 | from sdeconv.core import SSettings
5 | from .interface import SPSFGenerator
6 |
7 |
class SPSFGaussian(SPSFGenerator):
    """Generate a Gaussian PSF

    :param sigma: width of the PSF. [Z, Y, X] in 3D, [Y, X] in 2D
    :param shape: Shape of the PSF support image. [Z, Y, X] in 3D, [Y, X] in 2D
    """
    def __init__(self,
                 sigma: tuple[float, float] | tuple[float, float, float],
                 shape: tuple[int, int] | tuple[int, int, int]):
        super().__init__()
        self.sigma = sigma
        self.shape = shape
        # Cache of the last generated PSF (filled by __call__)
        self.psf_ = None

    @staticmethod
    def _normalize_inputs(sigma: tuple[float, float] | tuple[float, float, float],
                          shape: tuple[int, int] | tuple[int, int, int]
                          ) -> tuple:
        """Remove the batch dimension if it exists

        :param sigma: Width of the PSF
        :param shape: Shape of the PSF
        :return: The modified sigma and shape
        """
        if len(shape) == 3 and shape[0] == 1:
            return sigma[1:], shape[1:]
        return sigma, shape

    def __call__(self) -> torch.Tensor:
        """Calculate the PSF image

        :return: The PSF image in a Tensor, normalized to unit sum
        :raises ValueError: if the shape is neither 2D nor 3D
        """
        self.sigma, self.shape = SPSFGaussian._normalize_inputs(self.sigma, self.shape)
        device = SSettings.instance().device
        if len(self.shape) == 2:
            x_0 = math.floor(self.shape[0] / 2)
            y_0 = math.floor(self.shape[1] / 2)
            sigma_x2 = 0.5 / (self.sigma[0] * self.sigma[0])
            sigma_y2 = 0.5 / (self.sigma[1] * self.sigma[1])

            # Create the coordinate grids directly on the configured device:
            # previously the grids (and thus the final PSF) were always built
            # on the CPU, discarding the device-allocated zeros tensor.
            xx_, yy_ = torch.meshgrid(torch.arange(0, self.shape[0], device=device),
                                      torch.arange(0, self.shape[1], device=device),
                                      indexing='ij')
            self.psf_ = torch.exp(- torch.pow(xx_ - x_0, 2) * sigma_x2
                                  - torch.pow(yy_ - y_0, 2) * sigma_y2)
            self.psf_ = self.psf_ / torch.sum(self.psf_)
        elif len(self.shape) == 3:
            x_0 = math.floor(self.shape[2] / 2)
            y_0 = math.floor(self.shape[1] / 2)
            z_0 = math.floor(self.shape[0] / 2)
            sigma_x2 = 0.5 / (self.sigma[2] * self.sigma[2])
            sigma_y2 = 0.5 / (self.sigma[1] * self.sigma[1])
            sigma_z2 = 0.5 / (self.sigma[0] * self.sigma[0])

            # Same device fix as the 2D branch
            zzz, yyy, xxx = torch.meshgrid(torch.arange(0, self.shape[0], device=device),
                                           torch.arange(0, self.shape[1], device=device),
                                           torch.arange(0, self.shape[2], device=device),
                                           indexing='ij')
            self.psf_ = torch.exp(- torch.pow(xxx - x_0, 2) * sigma_x2
                                  - torch.pow(yyy - y_0, 2) * sigma_y2
                                  - torch.pow(zzz - z_0, 2) * sigma_z2)

            self.psf_ = self.psf_ / torch.sum(self.psf_)
        else:
            raise ValueError('PSFGaussian: can generate only 2D or 3D PSFs')
        return self.psf_
77 |
78 |
def spsf_gaussian(sigma: tuple[float, float] | tuple[float, float, float],
                  shape: tuple[int, int] | tuple[int, int, int]
                  ) -> torch.Tensor:
    """Generate a Gaussian PSF image

    :param sigma: Width of the PSF. [Z, Y, X] in 3D, [Y, X] in 2D
    :param shape: Shape of the PSF support image. [Z, Y, X] in 3D, [Y, X] in 2D
    :return: The PSF image
    """
    generator = SPSFGaussian(sigma, shape)
    return generator()
90 |
91 |
# Plugin-registration metadata: describes the spsf_gaussian entry point and
# its parameters so the sdeconv API/factory layer can expose them in a UI.
# Default shape [1, 128, 128] has a batch dim of 1, which the generator
# strips, so the default configuration produces a 2D PSF.
metadata = {
    'name': 'SPSFGaussian',
    'label': 'Gaussian PSF',
    'fnc': spsf_gaussian,
    'inputs': {
        'sigma': {
            'type': 'zyx_float',
            'label': 'Sigma',
            'help': 'Gaussian standard deviation in each direction',
            'default': [0, 1.5, 1.5]
        },
        'shape': {
            'type': 'zyx_int',
            'label': 'Size',
            'help': 'PSF image shape',
            'default': [1, 128, 128]
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'PSF Gaussian'
        },
    }
}
117 |
--------------------------------------------------------------------------------
/sdeconv/psfs/gibson_lanni.py:
--------------------------------------------------------------------------------
1 | """Implementation of Gibson Lanni Point Spread Function model
2 |
3 | This implementation is an adaptation of
4 | https://kmdouglass.github.io/posts/implementing-a-fast-gibson-lanni-psf-solver-in-python/
5 |
6 | """
7 | from math import sqrt
8 | import numpy as np
9 | import scipy.special
10 | from scipy.interpolate import interp1d
11 | import torch
12 |
13 |
14 | from sdeconv.core import SSettings
15 | from .interface import SPSFGenerator
16 |
17 |
class SPSFGibsonLanni(SPSFGenerator):
    """Generate a Gibson-Lanni PSF

    :param shape: Size of the PSF array in each dimension [(Z), Y, X],
    :param NA: Numerical aperture,
    :param wavelength: Wavelength in microns,
    :param M: Magnification,
    :param ns: Specimen refractive index (RI),
    :param ng0: Coverslip RI design value,
    :param ng: Coverslip RI experimental value,
    :param ni0: Immersion medium RI design value,
    :param ni: Immersion medium RI experimental value,
    :param ti0: microns, working distance (immersion medium thickness) design value,
    :param tg0: microns, coverslip thickness design value,
    :param tg: microns, coverslip thickness experimental value,
    :param res_lateral: Lateral resolution in microns,
    :param res_axial: Axial resolution in microns,
    :param pZ: microns, particle distance from coverslip
    :param use_square: If true, calculate the square of the Gibson-Lanni model to simulate a
                       pinhole. It then gives a PSF for a confocal image
    """
    def __init__(self,
                 shape: tuple[int, int] | tuple[int, int, int],
                 NA: float = 1.4,
                 wavelength: float = 0.610,
                 M: float = 100,
                 ns: float = 1.33,
                 ng0: float = 1.5,
                 ng: float = 1.5,
                 ni0: float = 1.5,
                 ni: float = 1.5,
                 ti0: float = 150,
                 tg0: float = 170,
                 tg: float = 170,
                 res_lateral: float = 0.1,
                 res_axial: float = 0.25,
                 pZ: float = 0,
                 use_square: bool = False):
        super().__init__()
        self.shape = shape

        # Microscope parameters
        self.NA = NA
        self.wavelength = wavelength
        # NOTE(review): self.M (magnification) is stored but never read in
        # __call__ — confirm whether it should enter the model.
        self.M = M
        self.ns = ns
        self.ng0 = ng0
        self.ng = ng
        self.ni0 = ni0
        self.ni = ni
        self.ti0 = ti0
        self.tg0 = tg0
        self.tg = tg
        self.res_lateral = res_lateral
        self.res_axial = res_axial
        self.pZ = pZ
        self.use_square = use_square
        # output
        self.psf_ = None

    def __call__(self) -> torch.Tensor:
        """Calculate the PSF

        Computes the radial PSF profile via a Fourier-Bessel (least squares)
        approximation of the pupil phase, then interpolates it onto the
        cartesian output grid slice by slice.

        :return: The PSF image as a Tensor
        """
        # Precision control
        num_basis = 100  # Number of rescaled Bessels that approximate the phase function
        num_samples = 1000  # Number of pupil samples along radial direction
        oversampling = 2  # Defines the sampling ratio on the image space grid for computations

        # assumes self.shape is [Z, Y, X] here; a 2D shape would raise IndexError
        size_x = self.shape[2]
        size_y = self.shape[1]
        size_z = self.shape[0]
        min_wavelength = 0.436  # microns
        scaling_factor = (
            self.NA * (3 * np.arange(1, num_basis + 1) - 2) * min_wavelength / self.wavelength)

        # Place the origin at the center of the final PSF array
        x0 = (size_x - 1) / 2
        y0 = (size_y - 1) / 2
        # Find the maximum possible radius coordinate of the PSF array by finding the distance
        # from the center of the array to a corner
        max_radius = round(sqrt((size_x - x0) * (size_x - x0) + (size_y - y0) * (size_y - y0))) + 1
        # Radial coordinates, image space
        r = self.res_lateral * np.arange(0, oversampling * max_radius) / oversampling
        # Radial coordinates, pupil space
        a = min([self.NA, self.ns, self.ni, self.ni0, self.ng, self.ng0]) / self.NA
        rho = np.linspace(0, a, num_samples)
        # Stage displacements away from best focus
        z = self.res_axial * np.arange(-size_z / 2, size_z / 2) + self.res_axial / 2

        # Define the wavefront aberration
        OPDs = self.pZ * np.sqrt(self.ns * self.ns - self.NA * self.NA * rho * rho)
        OPDi = (z.reshape(-1, 1) + self.ti0) * np.sqrt(self.ni * self.ni - self.NA * self.NA * rho * rho) - self.ti0 * np.sqrt(
            self.ni0 * self.ni0 - self.NA * self.NA * rho * rho)  # OPD in the immersion medium
        OPDg = self.tg * np.sqrt(self.ng * self.ng - self.NA * self.NA * rho * rho) - self.tg0 * np.sqrt(
            self.ng0 * self.ng0 - self.NA * self.NA * rho * rho)  # OPD in the coverslip
        W = 2 * np.pi / self.wavelength * (OPDs + OPDi + OPDg)

        # Sample the phase
        # Shape is (number of z samples by number of rho samples)
        phase = np.cos(W) + 1j * np.sin(W)

        # Define the basis of Bessel functions
        # Shape is (number of basis functions by number of rho samples)
        J = scipy.special.jv(0, scaling_factor.reshape(-1, 1) * rho)

        # Compute the approximation to the sampled pupil phase by finding the least squares
        # solution to the complex coefficients of the Fourier-Bessel expansion.
        # Shape of C is (number of basis functions by number of z samples).
        # Note the matrix transposes to get the dimensions correct.
        # (residuals from the fit are computed but intentionally unused)
        C, residuals, _, _ = np.linalg.lstsq(J.T, phase.T, rcond=None)

        # compute the PSF
        b = 2 * np. pi * r.reshape(-1, 1) * self.NA / self.wavelength

        # Convenience functions for J0 and J1 Bessel functions
        J0 = lambda x: scipy.special.jv(0, x)
        J1 = lambda x: scipy.special.jv(1, x)

        # See equation 5 in Li, Xue, and Blu
        denom = scaling_factor * scaling_factor - b * b
        R = scaling_factor * J1(scaling_factor * a) * J0(b * a) * a - b * J0(scaling_factor * a) * J1(b * a) * a
        R /= denom

        # The transpose places the axial direction along the first dimension of the array, i.e. rows
        # This is only for convenience.
        PSF_rz = (np.abs(R.dot(C)) ** 2).T

        # Normalize to the maximum value
        PSF_rz /= np.max(PSF_rz)

        # cartesian PSF
        # Create the fleshed-out xy grid of radial distances from the center
        xy = np.mgrid[0:size_y, 0:size_x]
        r_pixel = np.sqrt((xy[1] - x0) * (xy[1] - x0) + (xy[0] - y0) * (xy[0] - y0)) * self.res_lateral

        self.psf_ = np.zeros((size_z, size_y, size_x))

        for z_index in range(size_z):
            # Interpolate the radial PSF function
            PSF_interp = interp1d(r, PSF_rz[z_index, :])

            # Evaluate the PSF at each value of r_pixel
            self.psf_[z_index, :, :] = PSF_interp(r_pixel.ravel()).reshape(size_y, size_x)

        if self.use_square:
            # Squaring simulates a pinhole (confocal PSF)
            self.psf_ = np.square(self.psf_)

        return torch.from_numpy(self.psf_).to(SSettings.instance().device)
168 |
169 |
def spsf_gibson_lanni(shape: tuple[int, int] | tuple[int, int, int],
                      NA: float = 1.4,
                      wavelength: float = 0.610,
                      M: float = 100,
                      ns: float = 1.33,
                      ng0: float = 1.5,
                      ng: float = 1.5,
                      ni0: float = 1.5,
                      ni: float = 1.5,
                      ti0: float = 150,
                      tg0: float = 170,
                      tg: float = 170,
                      res_lateral: float = 0.1,
                      res_axial: float = 0.25,
                      pZ: float = 0,
                      use_square: bool = False
                      ) -> torch.Tensor:
    """Generate a Gibson-Lanni PSF image

    :param shape: Size of the PSF array in each dimension [(Z), Y, X],
    :param NA: Numerical aperture,
    :param wavelength: Wavelength in microns,
    :param M: Magnification,
    :param ns: Specimen refractive index (RI),
    :param ng0: Coverslip RI design value,
    :param ng: Coverslip RI experimental value,
    :param ni0: Immersion medium RI design value,
    :param ni: Immersion medium RI experimental value,
    :param ti0: microns, working distance (immersion medium thickness) design value,
    :param tg0: microns, coverslip thickness design value,
    :param tg: microns, coverslip thickness experimental value,
    :param res_lateral: Lateral resolution in microns,
    :param res_axial: Axial resolution in microns,
    :param pZ: microns, particle distance from coverslip
    :param use_square: If true, calculate the square of the Gibson-Lanni model to simulate a
                       pinhole. It then gives a PSF for a confocal image
    """
    generator = SPSFGibsonLanni(shape, NA=NA, wavelength=wavelength, M=M,
                                ns=ns, ng0=ng0, ng=ng, ni0=ni0, ni=ni,
                                ti0=ti0, tg0=tg0, tg=tg,
                                res_lateral=res_lateral, res_axial=res_axial,
                                pZ=pZ, use_square=use_square)
    return generator()
211 |
212 |
# Plugin-registration metadata: describes the spsf_gibson_lanni entry point
# and its parameters so the sdeconv API/factory layer can expose them in a UI.
metadata = {
    'name': 'SPSFGibsonLanni',
    'label': 'Gibson Lanni PSF',
    'fnc': spsf_gibson_lanni,
    'inputs': {
        'shape': {
            'type': 'zyx_int',
            'label': 'Size',
            # fixed copy-paste error: help previously said 'Regularisation parameter'
            'help': 'PSF image shape',
            'default': [11, 128, 128]
        },
        'NA': {
            'type': 'float',
            'label': 'Numerical aperture',
            'help': 'Numerical aperture',
            'default': 1.4
        },
        'wavelength': {
            'type': 'float',
            'label': 'Wavelength',
            'help': 'Wavelength',
            'default': 0.610
        },
        'M': {
            'type': 'float',
            'label': 'Magnification',
            'help': 'Magnification',
            'default': 100
        },
        'ns': {
            'type': 'float',
            'label': 'ns',
            'help': 'Specimen refractive index (RI)',
            'default': 1.33
        },
        'ng0': {
            'type': 'float',
            'label': 'ng0',
            'help': 'Coverslip RI design value',
            'default': 1.5
        },
        'ng': {
            'type': 'float',
            'label': 'ng',
            'help': 'coverslip RI experimental value',
            'default': 1.5
        },
        'ni0': {
            'type': 'float',
            'label': 'ni0',
            'help': 'Immersion medium RI design value',
            'default': 1.5
        },
        'ni': {
            'type': 'float',
            # fixed copy-paste error: label previously said 'ni0'
            'label': 'ni',
            'help': 'Immersion medium RI experimental value',
            'default': 1.5
        },
        'ti0': {
            'type': 'float',
            'label': 'ti0',
            'help': 'microns, working distance (immersion medium thickness) design value',
            'default': 150
        },
        'tg0': {
            'type': 'float',
            'label': 'tg0',
            'help': 'microns, coverslip thickness design value',
            'default': 170
        },
        'tg': {
            'type': 'float',
            'label': 'tg',
            'help': 'microns, coverslip thickness experimental value',
            'default': 170
        },
        'res_lateral': {
            'type': 'float',
            'label': 'Lateral resolution',
            'help': 'Lateral resolution in microns',
            'default': 0.1
        },
        'res_axial': {
            'type': 'float',
            'label': 'Axial resolution',
            'help': 'Axial resolution in microns',
            'default': 0.25
        },
        'pZ': {
            'type': 'float',
            'label': 'Particle position',
            'help': 'Particle distance from coverslip in microns',
            'default': 0
        },
        'use_square': {
            'type': 'bool',
            'label': 'Confocal',
            'help': 'Check for confocal PSF, uncheck for widefield',
            # NOTE(review): UI default (True) differs from the function default
            # (use_square=False) — confirm this is intentional for the GUI
            'default': True
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'PSF Gibson-Lanni'
        },
    }
}
322 |
--------------------------------------------------------------------------------
/sdeconv/psfs/interface.py:
--------------------------------------------------------------------------------
1 | """Interface for a psf generator"""
2 | import torch
3 | from sdeconv.core import SObservable
4 |
5 |
class SPSFGenerator(SObservable):
    """Interface for a psf generator

    Concrete generators (Gaussian, Gibson-Lanni, Lorentz, ...) implement
    ``__call__`` to build and return the PSF image.
    """
    def __init__(self):
        super().__init__()
        # Tag identifying the object family (presumably read by the API
        # factory layer — confirm against sdeconv.api)
        self.type = 'SPSFGenerator'

    def __call__(self) -> torch.Tensor:
        """Generate the PSF

        :return: PSF image [(Z) Y X]
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError('SPSFGenerator is an interface. Please implement the'
                                  ' __call__ method')
19 |
--------------------------------------------------------------------------------
/sdeconv/psfs/lorentz.py:
--------------------------------------------------------------------------------
1 | """Implements the Lorentz Point Spread Function generator"""
2 | import math
3 | import torch
4 | from ..core import SSettings
5 | from .interface import SPSFGenerator
6 |
7 |
class SPSFLorentz(SPSFGenerator):
    """Generate a Lorentz PSF

    :param gamma: Width of the Lorentz in each dimension [(Z), Y, X]
    :param shape: Size of the PSF image in each dimension [(Z), Y, X]
    """
    def __init__(self,
                 gamma: tuple[float, float] | tuple[float, float, float],
                 shape: tuple[int, int] | tuple[int, int, int]):
        super().__init__()
        self.gamma = gamma
        self.shape = shape
        # Cache of the last generated PSF (filled by __call__)
        self.psf_ = None

    @staticmethod
    def _normalize_inputs(gamma: tuple[float, float] | tuple[float, float, float],
                          shape: tuple[int, int] | tuple[int, int, int]
                          ) -> tuple:
        """Remove the batch dimension if it exists

        :param gamma: Width of the PSF
        :param shape: Shape of the PSF
        :return: The modified gamma and shape
        """
        if len(shape) == 3 and shape[0] == 1:
            return gamma[1:], shape[1:]
        return gamma, shape

    def __call__(self) -> torch.Tensor:
        """Calculate the PSF image

        :return: The PSF image as a Tensor, normalized to unit sum
        :raises ValueError: if the shape is neither 2D nor 3D
        """
        self.gamma, self.shape = SPSFLorentz._normalize_inputs(self.gamma, self.shape)
        device = SSettings.instance().device
        if len(self.shape) == 2:
            x_0 = math.floor(self.shape[0] / 2)
            y_0 = math.floor(self.shape[1] / 2)
            # Create the coordinate grids directly on the configured device:
            # previously the grids (and thus the final PSF) were always built
            # on the CPU, discarding the device-allocated zeros tensor.
            xx_, yy_ = torch.meshgrid(torch.arange(0, self.shape[0], device=device),
                                      torch.arange(0, self.shape[1], device=device),
                                      indexing='ij')
            self.psf_ = 1 / (1 + torch.pow((xx_ - x_0)/(0.5*self.gamma[0]), 2) +
                             torch.pow((yy_ - y_0)/(0.5*self.gamma[1]), 2))
            self.psf_ = self.psf_ / torch.sum(self.psf_)
        elif len(self.shape) == 3:
            x_0 = math.floor(self.shape[2] / 2)
            y_0 = math.floor(self.shape[1] / 2)
            z_0 = math.floor(self.shape[0] / 2)
            # Same device fix as the 2D branch
            zzz, yyy, xxx = torch.meshgrid(torch.arange(0, self.shape[0], device=device),
                                           torch.arange(0, self.shape[1], device=device),
                                           torch.arange(0, self.shape[2], device=device),
                                           indexing='ij')
            self.psf_ = 1 / (1 + torch.pow((xxx - x_0)/(0.5*self.gamma[2]), 2) +
                             torch.pow((yyy - y_0)/(0.5*self.gamma[1]), 2) +
                             torch.pow((zzz - z_0)/(0.5*self.gamma[0]), 2))
            self.psf_ = self.psf_ / torch.sum(self.psf_)
        else:
            raise ValueError('SPSFLorentz: can generate only 2D or 3D PSFs')
        return self.psf_
69 |
70 |
def spsf_lorentz(gamma: tuple[float, float] | tuple[float, float, float],
                 shape: tuple[int, int] | tuple[int, int, int]
                 ) -> torch.Tensor:
    """Function to generate a Lorentz PSF

    :param gamma: Width of the Lorentz in each dimension [(Z), Y, X],
    :param shape: Size of the PSF image in each dimension [(Z), Y, X],
    :return: The PSF image
    """
    # fixed annotation: 2D shapes are accepted (the original hint listed the
    # 3-tuple form twice), and fixed the "PSF mage" docstring typo
    filter_ = SPSFLorentz(gamma, shape)
    return filter_()
82 |
83 |
# Plugin-registration metadata: describes the spsf_lorentz entry point and
# its parameters so the sdeconv API/factory layer can expose them in a UI.
# Default shape [1, 128, 128] has a batch dim of 1, which the generator
# strips, so the default configuration produces a 2D PSF.
metadata = {
    'name': 'SPSFLorentz',
    'label': 'Lorentz PSF',
    'fnc': spsf_lorentz,
    'inputs': {
        'gamma': {
            'type': 'zyx_float',
            'label': 'Gamma',
            'help': 'PSF width in each direction',
            'default': [0, 1.5, 1.5]
        },
        'shape': {
            'type': 'zyx_int',
            'label': 'Size',
            'help': 'PSF image shape',
            'default': [1, 128, 128]
        }
    },
    'outputs': {
        'image': {
            'type': 'Image',
            'label': 'PSF Lorentz'
        },
    }
}
109 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = sdeconv
3 | version = 1.0.4
4 | author = Sylvain Prigent
5 | author_email = meriadec.prigent@gmail.com
6 | url = https://github.com/sylvainprigent/sdeconv
7 | license = BSD 3-Clause
8 | description = Implementation of 2D and 3D scientific image deconvolution
9 | long_description = file: README.md
10 | long_description_content_type = text/markdown
11 | classifiers =
12 | Development Status :: 2 - Pre-Alpha
13 | Intended Audience :: Developers
14 | Topic :: Software Development :: Testing
15 | Programming Language :: Python
16 | Programming Language :: Python :: 3
17 | Programming Language :: Python :: 3.10
18 | Programming Language :: Python :: 3.11
19 | Operating System :: OS Independent
20 | License :: OSI Approved :: BSD License
21 |
22 | [options]
23 | packages = find:
python_requires = >=3.10
25 |
26 | # add your package requirements here
27 | install_requires =
28 | scipy>=1.8.1
29 | numpy>=1.22.4
30 | torch>=1.11.0
31 | torchvision>=0.12.0
32 | scikit-image>=0.19.2
33 |
34 | [options.package_data]
35 | * = */*.tif
36 |
37 | [options.entry_points]
38 | console_scripts =
39 | spsf = sdeconv.cli.spsf:main
40 | sdeconv = sdeconv.cli.sdeconv:main
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

# All project metadata and options are declared in setup.cfg; this stub keeps
# legacy `python setup.py ...` and editable-install workflows functional.
setup()
4 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/__init__.py
--------------------------------------------------------------------------------
/tests/deconv/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/__init__.py
--------------------------------------------------------------------------------
/tests/deconv/celegans_richardson_lucy.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_richardson_lucy.tif
--------------------------------------------------------------------------------
/tests/deconv/celegans_spitfire.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_spitfire.tif
--------------------------------------------------------------------------------
/tests/deconv/celegans_wiener.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_wiener.tif
--------------------------------------------------------------------------------
/tests/deconv/pollen_richardson_lucy.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_richardson_lucy.tif
--------------------------------------------------------------------------------
/tests/deconv/pollen_spitfire.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_spitfire.tif
--------------------------------------------------------------------------------
/tests/deconv/pollen_wiener.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_wiener.tif
--------------------------------------------------------------------------------
/tests/deconv/test_richardson_lucy.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Richardson-Lucy deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 |
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import SRichardsonLucy
8 | from sdeconv.psfs import SPSFGaussian
9 |
10 |
def test_richardson_lucy_2d():
    """Unit testing the 2D Richardson-Lucy implementation"""
    # Reference images live next to this test file
    root_dir = os.path.dirname(os.path.abspath(__file__))
    image = celegans()

    # 13x13 Gaussian PSF with sigma = 1.5 pixels in Y and X
    psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
    psf = psf_generator()

    filter_ = SRichardsonLucy(psf, niter=30, pad=13)
    out_image = filter_(image)

    # Uncomment to regenerate the reference image:
    # imsave(os.path.join(root_dir, 'celegans_richardson_lucy_gpu.tif'),
    #        out_image.detach().cpu().numpy())
    ref_image = imread(os.path.join(root_dir, 'celegans_richardson_lucy.tif'))

    # Loose tolerance (1 decimal) absorbs backend/hardware numeric differences
    np.testing.assert_almost_equal(out_image.detach().cpu().numpy(), ref_image, decimal=1)
27 |
28 |
def test_richardson_lucy_3d():
    """Unit testing the 3D Richardson-Lucy implementation"""
    # Deconvolve the 3D pollen sample using its measured PSF.
    sample = pollen_poison_noise_blurred()
    measured_psf = pollen_psf()

    deconvolved = SRichardsonLucy(measured_psf, niter=30, pad=(16, 64, 64))(sample)

    # imsave(os.path.join(ref_dir, 'pollen_richardson_lucy_linux.tif'),
    #        deconvolved.detach().cpu().numpy())
    ref_dir = os.path.dirname(os.path.abspath(__file__))
    expected = imread(os.path.join(ref_dir, 'pollen_richardson_lucy.tif'))

    # Only z-slice 20 is compared against the reference.
    np.testing.assert_almost_equal(deconvolved[20, ...].detach().cpu().numpy(),
                                   expected[20, ...], decimal=1)
44 |
--------------------------------------------------------------------------------
/tests/deconv/test_spitfire.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Spitfire deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 |
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import Spitfire
8 | from sdeconv.psfs import SPSFGaussian
9 |
10 |
def test_spitfire_2d():
    """Unit testing Spitfire 2D deconvolution"""
    sample = celegans()

    # 2D Gaussian PSF with the same settings used to build the reference.
    psf = SPSFGaussian((1.5, 1.5), (15, 15))()

    spitfire = Spitfire(psf, weight=0.6, reg=0.995, gradient_step=0.01,
                        precision=1e-7, pad=13)
    result = spitfire(sample)

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'celegans_spitfire.tif'), result.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'celegans_spitfire.tif'))

    # Both sides are scaled by 1/10 so decimal=0 tolerates small numeric drift.
    np.testing.assert_almost_equal(result.detach().cpu().numpy() / 10,
                                   expected / 10, decimal=0)
26 |
27 |
def test_spitfire_3d():
    """Unit testing Spitfire 3D deconvolution"""
    # 3D deconvolution of the pollen sample with its measured PSF.
    sample = pollen_poison_noise_blurred()
    measured_psf = pollen_psf()

    spitfire = Spitfire(measured_psf, weight=0.6, reg=0.99995,
                        gradient_step=0.01, precision=1e-7)
    result = spitfire(sample)

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'pollen_spitfire.tif'), result.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'pollen_spitfire.tif'))

    np.testing.assert_almost_equal(result.detach().cpu().numpy(), expected, decimal=0)
41 |
--------------------------------------------------------------------------------
/tests/deconv/test_wiener.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Wiener deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 |
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import SWiener
8 | from sdeconv.psfs import SPSFGaussian
9 |
10 |
def test_wiener_2d():
    """Unit testing wiener 2D deconvolution"""
    sample = celegans()

    # Gaussian PSF matching the settings used for the stored reference.
    psf = SPSFGaussian((1.5, 1.5), (13, 13))()

    result = SWiener(psf, beta=0.005, pad=13)(sample)

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'celegans_wiener.tif'), result.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'celegans_wiener.tif'))

    np.testing.assert_almost_equal(result.detach().cpu().numpy(), expected, decimal=1)
26 |
27 |
def test_wiener_3d():
    """Unit testing wiener 3D deconvolution"""
    # 3D Wiener deconvolution of the pollen sample with its measured PSF.
    sample = pollen_poison_noise_blurred()
    measured_psf = pollen_psf()

    result = SWiener(measured_psf, beta=0.0005, pad=(16, 64, 64))(sample)

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'pollen_wiener.tif'), result.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'pollen_wiener.tif'))

    np.testing.assert_almost_equal(result.detach().cpu().numpy(), expected, decimal=1)
41 |
--------------------------------------------------------------------------------
/tests/psfs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/__init__.py
--------------------------------------------------------------------------------
/tests/psfs/gaussian2d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gaussian2d.tif
--------------------------------------------------------------------------------
/tests/psfs/gaussian3d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gaussian3d.tif
--------------------------------------------------------------------------------
/tests/psfs/gibsonlanni.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gibsonlanni.tif
--------------------------------------------------------------------------------
/tests/psfs/lorentz2d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/lorentz2d.tif
--------------------------------------------------------------------------------
/tests/psfs/lorentz3d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/lorentz3d.tif
--------------------------------------------------------------------------------
/tests/psfs/test_gausian.py:
--------------------------------------------------------------------------------
1 | """Unit testing for Gaussian PSFs"""
2 |
3 | import os
4 | import numpy as np
5 | from skimage.io import imread
6 |
7 | from sdeconv.psfs.gaussian import SPSFGaussian
8 |
9 |
def test_psf_gaussian_2d():
    """Unit testing the 2D Gaussian PSF generator"""
    # Generate a 15x15 PSF with sigma (1.5, 1.5) and compare it to the
    # stored reference produced by a previous run.
    generator = SPSFGaussian((1.5, 1.5), (15, 15))
    psf = generator()

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'gaussian2d.tif'), psf.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'gaussian2d.tif'))

    np.testing.assert_almost_equal(psf.detach().cpu().numpy(), expected, decimal=5)
21 |
22 |
def test_psf_gaussian_3d():
    """Unit testing the 3D Gaussian PSF generator"""
    # Generate an 11x15x15 PSF with sigma (0.5, 1.5, 1.5) and compare it
    # to the stored reference produced by a previous run.
    generator = SPSFGaussian((0.5, 1.5, 1.5), (11, 15, 15))
    psf = generator()

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'gaussian3d.tif'), psf.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'gaussian3d.tif'))

    np.testing.assert_almost_equal(psf.detach().cpu().numpy(), expected, decimal=5)
34 |
--------------------------------------------------------------------------------
/tests/psfs/test_gibsonlanni.py:
--------------------------------------------------------------------------------
1 | """Module to test the Gibson-Lanni PSF implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 | from sdeconv.psfs.gibson_lanni import SPSFGibsonLanni
6 |
7 |
def test_gibson_lanni():
    """Unit testing the Gibson-Lanni PSF generator against a stored reference.

    The generated volume is compared element-wise (decimal=5) to
    ``gibsonlanni.tif``, which was produced by a previous run of this
    generator with the same parameters.
    """
    # Optical parameters; meanings follow the Gibson-Lanni model
    # convention — confirm against the SPSFGibsonLanni documentation.
    shape = (18, 128, 128)   # (z, y, x) size of the generated PSF volume
    NA = 1.4                 # numerical aperture
    wavelength = 0.610       # emission wavelength (presumably um — TODO confirm units)
    M = 100                  # magnification
    ns = 1.33                # sample refractive index
    ng0 = 1.5                # coverslip refractive index (design)
    ng = 1.5                 # coverslip refractive index (actual)
    ni0 = 1.5                # immersion medium refractive index (design)
    ni = 1.5                 # immersion medium refractive index (actual)
    ti0 = 150                # working distance (design)
    tg0 = 170                # coverslip thickness (design)
    tg = 170                 # coverslip thickness (actual)
    res_lateral = 0.1        # lateral pixel size
    res_axial = 0.25         # axial step between z-slices
    pZ = 0                   # particle axial position
    use_square = False       # do not square the PSF intensity

    psf_generator = SPSFGibsonLanni(shape, NA, wavelength, M, ns,
                                    ng0, ng, ni0, ni, ti0, tg0, tg,
                                    res_lateral, res_axial, pZ, use_square)
    psf = psf_generator()

    root_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(root_dir, 'gibsonlanni.tif'), psf.detach().cpu().numpy())
    ref_psf = imread(os.path.join(root_dir, 'gibsonlanni.tif'))

    np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
38 |
--------------------------------------------------------------------------------
/tests/psfs/test_lorentz.py:
--------------------------------------------------------------------------------
1 | """Unit testing for Lorentz PSFs"""
2 |
3 | import os
4 | import numpy as np
5 | from skimage.io import imread, imsave
6 |
7 | from sdeconv.psfs.lorentz import SPSFLorentz
8 |
9 |
def test_psf_lorentz_2d():
    """Unit testing the 2D Lorentz PSF generator"""
    # Generate a 15x15 Lorentz PSF and compare it to the stored reference.
    generator = SPSFLorentz((1.5, 1.5), (15, 15))
    psf = generator()

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'lorentz2d.tif'), psf.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'lorentz2d.tif'))

    np.testing.assert_almost_equal(psf.detach().cpu().numpy(), expected, decimal=5)
21 |
22 |
def test_psf_lorentz_3d():
    """Unit testing the 3D Lorentz PSF generator"""
    # Generate an 11x15x15 Lorentz PSF and compare it to the stored reference.
    generator = SPSFLorentz((0.5, 1.5, 1.5), (11, 15, 15))
    psf = generator()

    ref_dir = os.path.dirname(os.path.abspath(__file__))
    # imsave(os.path.join(ref_dir, 'lorentz3d.tif'), psf.detach().cpu().numpy())
    expected = imread(os.path.join(ref_dir, 'lorentz3d.tif'))

    np.testing.assert_almost_equal(psf.detach().cpu().numpy(), expected, decimal=5)
34 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/
2 | [tox]
3 | envlist = py{310, 311}-{linux,macos,windows}
4 |
5 | [gh-actions]
6 | python =
7 | 3.10: py310
8 | 3.11: py311
9 |
10 | [gh-actions:env]
11 | PLATFORM =
12 | ubuntu-latest: linux
13 | macos-latest: macos
14 | windows-latest: windows
15 |
16 | [testenv]
17 | platform =
18 | macos: darwin
19 | linux: linux
20 | windows: win32
21 | passenv =
22 | CI
23 | GITHUB_ACTIONS
    DISPLAY
    XAUTHORITY
25 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
26 | PYVISTA_OFF_SCREEN
27 | deps =
28 | pytest # https://docs.pytest.org/en/latest/contents.html
29 | pytest-cov # https://pytest-cov.readthedocs.io/en/latest/
30 | pytest-xvfb ; sys_platform == 'linux'
31 | numpy
32 | torch
33 | torchvision
34 | scikit-image
35 | commands = pytest -v --color=yes --cov=sdeconv --cov-report=xml
--------------------------------------------------------------------------------