├── .github └── workflows │ └── test_and_deploy.yml ├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── build.rst ├── docs ├── .gitignore ├── Makefile ├── make.bat └── source │ ├── _static │ └── .gitkeep │ ├── conf.py │ ├── index.rst │ ├── install.rst │ ├── interfaces.rst │ ├── intro.rst │ ├── modules.rst │ ├── quickstart.rst │ └── references.rst ├── examples ├── README.rst ├── api.py ├── gaussian_psf_2d.py ├── gaussian_psf_3d.py ├── gibsonlanni.py ├── richardson_lucy.py ├── spitfire.py └── wiener.py ├── pyproject.toml ├── requirements.txt ├── sdeconv ├── __init__.py ├── api │ ├── __init__.py │ ├── api.py │ └── factory.py ├── cli │ ├── __init__.py │ ├── sdeconv.py │ └── spsf.py ├── core │ ├── __init__.py │ ├── _observers.py │ ├── _progress_logger.py │ ├── _settings.py │ └── _timing.py ├── data │ ├── __init__.py │ ├── celegans.tif │ ├── pollen.tif │ ├── pollen_poisson_noise_blurred.tif │ └── pollen_psf.tif ├── deconv │ ├── __init__.py │ ├── _datasets.py │ ├── _transforms.py │ ├── _unet_2d.py │ ├── _utils.py │ ├── interface.py │ ├── interface_nn.py │ ├── nn_deconv.py │ ├── noise2void.py │ ├── richardson_lucy.py │ ├── self_supervised_nn.py │ ├── spitfire.py │ └── wiener.py └── psfs │ ├── __init__.py │ ├── gaussian.py │ ├── gibson_lanni.py │ ├── interface.py │ └── lorentz.py ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── deconv │ ├── __init__.py │ ├── celegans_richardson_lucy.tif │ ├── celegans_spitfire.tif │ ├── celegans_wiener.tif │ ├── pollen_richardson_lucy.tif │ ├── pollen_spitfire.tif │ ├── pollen_wiener.tif │ ├── test_richardson_lucy.py │ ├── test_spitfire.py │ └── test_wiener.py └── psfs │ ├── __init__.py │ ├── gaussian2d.tif │ ├── gaussian3d.tif │ ├── gibsonlanni.tif │ ├── lorentz2d.tif │ ├── lorentz3d.tif │ ├── test_gausian.py │ ├── test_gibsonlanni.py │ └── test_lorentz.py └── tox.ini /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a 
Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: 9 | - master 10 | - main 11 | tags: 12 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 13 | pull_request: 14 | branches: 15 | - master 16 | - main 17 | workflow_dispatch: 18 | 19 | jobs: 20 | test: 21 | name: ${{ matrix.platform }} py${{ matrix.python-version }} 22 | runs-on: ${{ matrix.platform }} 23 | strategy: 24 | matrix: 25 | platform: [ubuntu-latest, windows-latest, macos-latest] 26 | python-version: ['3.10', '3.11'] 27 | 28 | steps: 29 | - uses: actions/checkout@v2 30 | 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v2 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | 36 | # these libraries, along with pytest-xvfb (added in the `deps` in tox.ini), 37 | # enable testing on Qt on linux 38 | - name: Install Linux libraries 39 | if: runner.os == 'Linux' 40 | run: | 41 | sudo apt-get install -y libdbus-1-3 libxkbcommon-x11-0 libxcb-icccm4 \ 42 | libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 \ 43 | libxcb-xinerama0 libxcb-xinput0 libxcb-xfixes0 44 | # note: if you need dependencies from conda, considering using 45 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda 46 | # and 47 | # tox-conda: https://github.com/tox-dev/tox-conda 48 | - name: Install dependencies 49 | run: | 50 | python -m pip install --upgrade pip 51 | pip install setuptools tox numpy scipy tox-gh-actions 52 | # this runs the platform-specific tests declared in tox.ini 53 | - name: Test with tox 54 | run: tox 55 | env: 56 | PLATFORM: ${{ matrix.platform }} 57 | 58 | - name: Coverage 59 | uses: codecov/codecov-action@v1 60 | 61 | deploy: 62 | # this will run when you have tagged a commit, starting with "v*" 63 | # and 
requires that you have put your twine API key in your 64 | # github secrets (see readme for details) 65 | needs: [test] 66 | runs-on: ubuntu-latest 67 | if: contains(github.ref, 'tags') 68 | steps: 69 | - uses: actions/checkout@v2 70 | - name: Set up Python 71 | uses: actions/setup-python@v2 72 | with: 73 | python-version: "3.x" 74 | - name: Install dependencies 75 | run: | 76 | python -m pip install --upgrade pip 77 | pip install numpy scipy torch torchvision scikit-image 78 | pip install -U setuptools setuptools_scm wheel twine 79 | - name: Build and publish 80 | env: 81 | TWINE_USERNAME: __token__ 82 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 83 | run: | 84 | git tag 85 | python setup.py sdist bdist_wheel 86 | twine upload dist/* 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | *.pyc 4 | *.pyo 5 | *~ 6 | \.#* 7 | /build 8 | /dist 9 | *.egg-info 10 | *.so 11 | __pycache__ 12 | .DS_Store 13 | _build 14 | docs/source/generated/ 15 | runs/ 16 | demo_*.* -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code. 6 | extension-pkg-allow-list= 7 | 8 | # A comma-separated list of package or module names from where C extensions may 9 | # be loaded. Extensions are loading into the active Python interpreter and may 10 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list 11 | # for backward compatibility.) 
12 | extension-pkg-whitelist= 13 | 14 | # Return non-zero exit code if any of these messages/categories are detected, 15 | # even if score is above --fail-under value. Syntax same as enable. Messages 16 | # specified are enabled, while categories only check already-enabled messages. 17 | fail-on= 18 | 19 | # Specify a score threshold to be exceeded before program exits with error. 20 | fail-under=10.0 21 | 22 | # Files or directories to be skipped. They should be base names, not paths. 23 | ignore=CVS 24 | 25 | # Add files or directories matching the regex patterns to the ignore-list. The 26 | # regex matches against paths. 27 | ignore-paths= 28 | 29 | # Files or directories matching the regex patterns are skipped. The regex 30 | # matches against base names, not paths. 31 | ignore-patterns= 32 | 33 | # Python code to execute, usually for sys.path manipulation such as 34 | # pygtk.require(). 35 | #init-hook= 36 | 37 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 38 | # number of processors available to use. 39 | jobs=1 40 | 41 | # Control the amount of potential inferred values when inferring a single 42 | # object. This can help the performance when dealing with large functions or 43 | # complex, nested conditions. 44 | limit-inference-results=100 45 | 46 | # List of plugins (as comma separated values of python module names) to load, 47 | # usually to register additional checkers. 48 | load-plugins= 49 | 50 | # Pickle collected data for later comparisons. 51 | persistent=yes 52 | 53 | # When enabled, pylint would attempt to guess common misconfiguration and emit 54 | # user-friendly hints instead of false-positive error messages. 55 | suggestion-mode=yes 56 | 57 | # Allow loading of arbitrary C extensions. Extensions are imported into the 58 | # active Python interpreter and may run arbitrary code. 59 | unsafe-load-any-extension=no 60 | 61 | 62 | [MESSAGES CONTROL] 63 | 64 | # Only show warnings with the listed confidence levels. 
Leave empty to show 65 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 66 | confidence= 67 | 68 | # Disable the message, report, category or checker with the given id(s). You 69 | # can either give multiple identifiers separated by comma (,) or put this 70 | # option multiple times (only on the command line, not in the configuration 71 | # file where it should appear only once). You can also use "--disable=all" to 72 | # disable everything first and then reenable specific checks. For example, if 73 | # you want to run only the similarities checker, you can use "--disable=all 74 | # --enable=similarities". If you want to run only the classes checker, but have 75 | # no Warning level messages displayed, use "--disable=all --enable=classes 76 | # --disable=W". 77 | disable=too-many-arguments, duplicate-code, too-few-public-methods, too-many-instance-attributes, consider-using-with 78 | 79 | # Enable the message, report, category or checker with the given id(s). You can 80 | # either give multiple identifier separated by comma (,) or put this option 81 | # multiple time (only on the command line, not in the configuration file where 82 | # it should appear only once). See also the "--disable" option for examples. 83 | enable=c-extension-no-member 84 | 85 | 86 | [REPORTS] 87 | 88 | # Python expression which should return a score less than or equal to 10. You 89 | # have access to the variables 'error', 'warning', 'refactor', and 'convention' 90 | # which contain the number of messages in each category, as well as 'statement' 91 | # which is the total number of statements analyzed. This score is used by the 92 | # global evaluation report (RP0004). 93 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 94 | 95 | # Template used to display messages. This is a python new-style format string 96 | # used to format the message information. See doc for all details. 
97 | #msg-template= 98 | 99 | # Set the output format. Available formats are text, parseable, colorized, json 100 | # and msvs (visual studio). You can also give a reporter class, e.g. 101 | # mypackage.mymodule.MyReporterClass. 102 | output-format=text 103 | 104 | # Tells whether to display a full report or only the messages. 105 | reports=no 106 | 107 | # Activate the evaluation score. 108 | score=yes 109 | 110 | 111 | [REFACTORING] 112 | 113 | # Maximum number of nested blocks for function / method body 114 | max-nested-blocks=5 115 | 116 | # Complete name of functions that never return. When checking for 117 | # inconsistent-return-statements if a never returning function is called then 118 | # it will be considered as an explicit return statement and no message will be 119 | # printed. 120 | never-returning-functions=sys.exit,argparse.parse_error 121 | 122 | 123 | [LOGGING] 124 | 125 | # The type of string formatting that logging methods do. `old` means using % 126 | # formatting, `new` is for `{}` formatting. 127 | logging-format-style=old 128 | 129 | # Logging modules to check that the string format arguments are in logging 130 | # function parameter format. 131 | logging-modules=logging 132 | 133 | 134 | [SPELLING] 135 | 136 | # Limits count of emitted suggestions for spelling mistakes. 137 | max-spelling-suggestions=4 138 | 139 | # Spelling dictionary name. Available dictionaries: none. To make it work, 140 | # install the 'python-enchant' package. 141 | spelling-dict= 142 | 143 | # List of comma separated words that should be considered directives if they 144 | # appear at the beginning of a comment and should not be checked. 145 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: 146 | 147 | # List of comma separated words that should not be checked. 148 | spelling-ignore-words= 149 | 150 | # A path to a file that contains the private dictionary; one word per line.
151 | spelling-private-dict-file= 152 | 153 | # Tells whether to store unknown words to the private dictionary (see the 154 | # --spelling-private-dict-file option) instead of raising a message. 155 | spelling-store-unknown-words=no 156 | 157 | 158 | [MISCELLANEOUS] 159 | 160 | # List of note tags to take in consideration, separated by a comma. 161 | notes=FIXME, 162 | XXX, 163 | TODO 164 | 165 | # Regular expression of note tags to take in consideration. 166 | #notes-rgx= 167 | 168 | 169 | [TYPECHECK] 170 | 171 | # List of decorators that produce context managers, such as 172 | # contextlib.contextmanager. Add to this list to register other decorators that 173 | # produce valid context managers. 174 | contextmanager-decorators=contextlib.contextmanager 175 | 176 | # List of members which are set dynamically and missed by pylint inference 177 | # system, and so shouldn't trigger E1101 when accessed. Python regular 178 | # expressions are accepted. 179 | generated-members= 180 | 181 | # Tells whether missing members accessed in mixin class should be ignored. A 182 | # mixin class is detected if its name ends with "mixin" (case insensitive). 183 | ignore-mixin-members=yes 184 | 185 | # Tells whether to warn about missing members when the owner of the attribute 186 | # is inferred to be None. 187 | ignore-none=yes 188 | 189 | # This flag controls whether pylint should warn about no-member and similar 190 | # checks whenever an opaque object is returned when inferring. The inference 191 | # can return multiple potential results while evaluating a Python object, but 192 | # some branches might not be evaluated, which results in partial inference. In 193 | # that case, it might be useful to still emit no-member and other checks for 194 | # the rest of the inferred objects. 195 | ignore-on-opaque-inference=yes 196 | 197 | # List of class names for which member attributes should not be checked (useful 198 | # for classes with dynamically set attributes). 
This supports the use of 199 | # qualified names. 200 | ignored-classes=optparse.Values,thread._local,_thread._local,numpy,torch 201 | 202 | # List of module names for which member attributes should not be checked 203 | # (useful for modules/projects where namespaces are manipulated during runtime 204 | # and thus existing member attributes cannot be deduced by static analysis). It 205 | # supports qualified module names, as well as Unix pattern matching. 206 | ignored-modules=numpy,torch 207 | 208 | # Show a hint with possible names when a member name was not found. The aspect 209 | # of finding the hint is based on edit distance. 210 | missing-member-hint=yes 211 | 212 | # The minimum edit distance a name should have in order to be considered a 213 | # similar match for a missing member name. 214 | missing-member-hint-distance=1 215 | 216 | # The total number of similar names that should be taken in consideration when 217 | # showing a hint for a missing member. 218 | missing-member-max-choices=1 219 | 220 | # List of decorators that change the signature of a decorated function. 221 | signature-mutators= 222 | 223 | 224 | [VARIABLES] 225 | 226 | # List of additional names supposed to be defined in builtins. Remember that 227 | # you should avoid defining new builtins when possible. 228 | additional-builtins= 229 | 230 | # Tells whether unused global variables should be treated as a violation. 231 | allow-global-unused-variables=yes 232 | 233 | # List of names allowed to shadow builtins 234 | allowed-redefined-builtins= 235 | 236 | # List of strings which can identify a callback function by name. A callback 237 | # name must start or end with one of those strings. 238 | callbacks=cb_, 239 | _cb 240 | 241 | # A regular expression matching the name of dummy variables (i.e. expected to 242 | # not be used). 243 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 244 | 245 | # Argument names that match this expression will be ignored. 
Defaults to names 246 | # with a leading underscore. 247 | ignored-argument-names=_.*|^ignored_|^unused_ 248 | 249 | # Tells whether we should check for unused import in __init__ files. 250 | init-import=no 251 | 252 | # List of qualified module names which can have objects that can redefine 253 | # builtins. 254 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 255 | 256 | 257 | [FORMAT] 258 | 259 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 260 | expected-line-ending-format= 261 | 262 | # Regexp for a line that is allowed to be longer than the limit. 263 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 264 | 265 | # Number of spaces of indent required inside a hanging or continued line. 266 | indent-after-paren=4 267 | 268 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 269 | # tab). 270 | indent-string=' ' 271 | 272 | # Maximum number of characters on a single line. 273 | max-line-length=100 274 | 275 | # Maximum number of lines in a module. 276 | max-module-lines=1000 277 | 278 | # Allow the body of a class to be on the same line as the declaration if body 279 | # contains single statement. 280 | single-line-class-stmt=no 281 | 282 | # Allow the body of an if to be on the same line as the test if there is no 283 | # else. 284 | single-line-if-stmt=no 285 | 286 | 287 | [SIMILARITIES] 288 | 289 | # Comments are removed from the similarity computation 290 | ignore-comments=yes 291 | 292 | # Docstrings are removed from the similarity computation 293 | ignore-docstrings=yes 294 | 295 | # Imports are removed from the similarity computation 296 | ignore-imports=no 297 | 298 | # Signatures are removed from the similarity computation 299 | ignore-signatures=no 300 | 301 | # Minimum number of lines of a similarity. 302 | min-similarity-lines=4 303 | 304 | 305 | [BASIC] 306 | 307 | # Naming style matching correct argument names.
308 | argument-naming-style=snake_case 309 | 310 | # Regular expression matching correct argument names. Overrides argument- 311 | # naming-style. 312 | #argument-rgx= 313 | 314 | # Naming style matching correct attribute names. 315 | attr-naming-style=snake_case 316 | 317 | # Regular expression matching correct attribute names. Overrides attr-naming- 318 | # style. 319 | #attr-rgx= 320 | 321 | # Bad variable names which should always be refused, separated by a comma. 322 | bad-names=foo, 323 | bar, 324 | baz, 325 | toto, 326 | tutu, 327 | tata 328 | 329 | # Bad variable names regexes, separated by a comma. If names match any regex, 330 | # they will always be refused 331 | bad-names-rgxs= 332 | 333 | # Naming style matching correct class attribute names. 334 | class-attribute-naming-style=any 335 | 336 | # Regular expression matching correct class attribute names. Overrides class- 337 | # attribute-naming-style. 338 | #class-attribute-rgx= 339 | 340 | # Naming style matching correct class constant names. 341 | class-const-naming-style=UPPER_CASE 342 | 343 | # Regular expression matching correct class constant names. Overrides class- 344 | # const-naming-style. 345 | #class-const-rgx= 346 | 347 | # Naming style matching correct class names. 348 | class-naming-style=PascalCase 349 | 350 | # Regular expression matching correct class names. Overrides class-naming- 351 | # style. 352 | #class-rgx= 353 | 354 | # Naming style matching correct constant names. 355 | const-naming-style=UPPER_CASE 356 | 357 | # Regular expression matching correct constant names. Overrides const-naming- 358 | # style. 359 | #const-rgx= 360 | 361 | # Minimum line length for functions/classes that require docstrings, shorter 362 | # ones are exempt. 363 | docstring-min-length=-1 364 | 365 | # Naming style matching correct function names. 366 | function-naming-style=snake_case 367 | 368 | # Regular expression matching correct function names. Overrides function- 369 | # naming-style. 
370 | #function-rgx= 371 | 372 | # Good variable names which should always be accepted, separated by a comma. 373 | good-names=i, 374 | j, 375 | k, 376 | m, 377 | n, 378 | x, 379 | y, 380 | z, 381 | ex, 382 | Run, 383 | _ 384 | 385 | # Good variable names regexes, separated by a comma. If names match any regex, 386 | # they will always be accepted 387 | good-names-rgxs= 388 | 389 | # Include a hint for the correct naming format with invalid-name. 390 | include-naming-hint=no 391 | 392 | # Naming style matching correct inline iteration names. 393 | inlinevar-naming-style=any 394 | 395 | # Regular expression matching correct inline iteration names. Overrides 396 | # inlinevar-naming-style. 397 | #inlinevar-rgx= 398 | 399 | # Naming style matching correct method names. 400 | method-naming-style=snake_case 401 | 402 | # Regular expression matching correct method names. Overrides method-naming- 403 | # style. 404 | #method-rgx= 405 | 406 | # Naming style matching correct module names. 407 | module-naming-style=snake_case 408 | 409 | # Regular expression matching correct module names. Overrides module-naming- 410 | # style. 411 | #module-rgx= 412 | 413 | # Colon-delimited sets of names that determine each other's naming style when 414 | # the name regexes allow several styles. 415 | name-group= 416 | 417 | # Regular expression which should only match function or class names that do 418 | # not require a docstring. 419 | no-docstring-rgx=^_ 420 | 421 | # List of decorators that produce properties, such as abc.abstractproperty. Add 422 | # to this list to register other decorators that produce valid properties. 423 | # These decorators are taken in consideration only for invalid-name. 424 | property-classes=abc.abstractproperty 425 | 426 | # Naming style matching correct variable names. 427 | variable-naming-style=snake_case 428 | 429 | # Regular expression matching correct variable names. Overrides variable- 430 | # naming-style. 
431 | #variable-rgx= 432 | 433 | 434 | [STRING] 435 | 436 | # This flag controls whether inconsistent-quotes generates a warning when the 437 | # character used as a quote delimiter is used inconsistently within a module. 438 | check-quote-consistency=no 439 | 440 | # This flag controls whether the implicit-str-concat should generate a warning 441 | # on implicit string concatenation in sequences defined over several lines. 442 | check-str-concat-over-line-jumps=no 443 | 444 | 445 | [IMPORTS] 446 | 447 | # List of modules that can be imported at any level, not just the top level 448 | # one. 449 | allow-any-import-level= 450 | 451 | # Allow wildcard imports from modules that define __all__. 452 | allow-wildcard-with-all=no 453 | 454 | # Analyse import fallback blocks. This can be used to support both Python 2 and 455 | # 3 compatible code, which means that the block might have code that exists 456 | # only in one or another interpreter, leading to false positives when analysed. 457 | analyse-fallback-blocks=no 458 | 459 | # Deprecated modules which should not be used, separated by a comma. 460 | deprecated-modules= 461 | 462 | # Output a graph (.gv or any supported image format) of external dependencies 463 | # to the given file (report RP0402 must not be disabled). 464 | ext-import-graph= 465 | 466 | # Output a graph (.gv or any supported image format) of all (i.e. internal and 467 | # external) dependencies to the given file (report RP0402 must not be 468 | # disabled). 469 | import-graph= 470 | 471 | # Output a graph (.gv or any supported image format) of internal dependencies 472 | # to the given file (report RP0402 must not be disabled). 473 | int-import-graph= 474 | 475 | # Force import order to recognize a module as part of the standard 476 | # compatibility libraries. 477 | known-standard-library= 478 | 479 | # Force import order to recognize a module as part of a third party library. 
480 | known-third-party=enchant 481 | 482 | # Couples of modules and preferred modules, separated by a comma. 483 | preferred-modules= 484 | 485 | 486 | [CLASSES] 487 | 488 | # Warn about protected attribute access inside special methods 489 | check-protected-access-in-special-methods=no 490 | 491 | # List of method names used to declare (i.e. assign) instance attributes. 492 | defining-attr-methods=__init__, 493 | __new__, 494 | setUp, 495 | __post_init__ 496 | 497 | # List of member names, which should be excluded from the protected access 498 | # warning. 499 | exclude-protected=_asdict, 500 | _fields, 501 | _replace, 502 | _source, 503 | _make 504 | 505 | # List of valid names for the first argument in a class method. 506 | valid-classmethod-first-arg=cls 507 | 508 | # List of valid names for the first argument in a metaclass class method. 509 | valid-metaclass-classmethod-first-arg=cls 510 | 511 | 512 | [DESIGN] 513 | 514 | # List of qualified class names to ignore when counting class parents (see 515 | # R0901) 516 | ignored-parents= 517 | 518 | # Maximum number of arguments for function / method. 519 | max-args=5 520 | 521 | # Maximum number of attributes for a class (see R0902). 522 | max-attributes=7 523 | 524 | # Maximum number of boolean expressions in an if statement (see R0916). 525 | max-bool-expr=5 526 | 527 | # Maximum number of branches for function / method body. 528 | max-branches=12 529 | 530 | # Maximum number of locals for function / method body. 531 | max-locals=15 532 | 533 | # Maximum number of parents for a class (see R0901). 534 | max-parents=7 535 | 536 | # Maximum number of public methods for a class (see R0904). 537 | max-public-methods=20 538 | 539 | # Maximum number of return / yield for function / method body. 540 | max-returns=6 541 | 542 | # Maximum number of statements in function / method body. 543 | max-statements=50 544 | 545 | # Minimum number of public methods for a class (see R0903).
546 | min-public-methods=2 547 | 548 | 549 | [EXCEPTIONS] 550 | 551 | # Exceptions that will emit a warning when being caught. Defaults to 552 | # "BaseException, Exception". 553 | overgeneral-exceptions=builtins.BaseException, 554 | builtins.Exception 555 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, STracking 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SDeconv 2 | 3 | **SDeconv** is a python framework to develop scientific image deconvolution algorithms. This 4 | library has been developed for microscopy 2D and 3D images, but can be used for any image 5 | deconvolution application. 6 | 7 | # System Requirements 8 | 9 | ## Software Requirements 10 | 11 | ### OS Requirements 12 | 13 | The `SDeconv` development version is tested on *Windows 10*, *MacOS* and *Linux* operating systems. 14 | The developmental version of the package has been tested on the following systems: 15 | 16 | - Linux: 20.04.4 17 | - Mac OSX: Mac OS Catalina 10.15.7 18 | - Windows: 10 19 | 20 | # Install 21 | 22 | ## Library installation from PyPI 23 | 24 | 1. Install an [Anaconda](https://www.anaconda.com/download/) distribution of Python -- Choose **Python 3.9** and your operating system. Note you might need to use an anaconda prompt if you did not add anaconda to the path. 25 | 2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path 26 | 3. Create a new environment with `conda create --name sdeconv python=3.9`. 27 | 4. To activate this new environment, run `conda activate sdeconv` 28 | 5. To install the `SDeconv` library, run `python -m pip install sdeconv`. 29 | 30 | If you need to update to a new release, use: 31 | ~~~sh 32 | python -m pip install sdeconv --upgrade 33 | ~~~ 34 | 35 | ## Library installation from source 36 | 37 | This installation is for developers or people who want the latest features in the ``main`` branch. 38 | 39 | 1. Install an [Anaconda](https://www.anaconda.com/download/) distribution of Python -- Choose **Python 3.9** and your operating system. Note you might need to use an anaconda prompt if you did not add anaconda to the path. 40 | 2.
Open an anaconda prompt / command prompt with `conda` for **python 3** in the path 41 | 3. Create a new environment with `conda create --name sdeconv python=3.9`. 42 | 4. To activate this new environment, run `conda activate sdeconv` 43 | 5. Clone the source code from git with `git clone https://github.com/sylvainprigent/sdeconv.git` 44 | 6. Then install the `SDeconv` library from your local directory with: `python -m pip install -e ./sdeconv`. 45 | 46 | ## Use SDeconv with napari 47 | 48 | The SDeconv library is embedded in a napari plugin that allows using ``SDeconv`` with a graphical interface. 49 | Please refer to the [`SDeconv` napari plugin](https://www.napari-hub.org/plugins/napari-sdeconv) documentation to install and use it. 50 | 51 | # SDeconv documentation 52 | 53 | The full documentation with tutorials and docstrings is available [here](https://sylvainprigent.github.io/sdeconv/) 54 | -------------------------------------------------------------------------------- /build.rst: -------------------------------------------------------------------------------- 1 | Build locally 2 | ============= 3 | 4 | do:: 5 | python3 setup.py build_ext --inplace 6 | 7 | 8 | Building from source 9 | ==================== 10 | 11 | Building from source is required to work on a contribution (bug fix, new 12 | feature, code or documentation improvement). 13 | 14 | .. _git_repo: 15 | 16 | #. Use `Git <https://git-scm.com/>`_ to check out the latest source from the 17 | `sdeconv repository <https://github.com/sylvainprigent/sdeconv>`_ on 18 | GitHub:: 19 | 20 | git clone https://github.com/sylvainprigent/sdeconv.git # add --depth 1 if your connection is slow 21 | cd sdeconv 22 | 23 | If you plan on submitting a pull-request, you should clone from your fork 24 | instead. 25 | 26 | #. Install a compiler with OpenMP_ support for your platform. 27 | 28 | #. Optional (but recommended): create and activate a dedicated virtualenv_ 29 | or `conda environment`_. 30 | 31 | #.
Install Cython_ and build the project with pip in :ref:`editable_mode`:: 32 | 33 | pip install cython 34 | pip install --verbose --no-build-isolation --editable . 35 | 36 | #. Check that the installed sdeconv has a version number ending with 37 | `.dev0`:: 38 | 39 | python -c "import sdeconv; print(sdeconv.__version__)" 40 | 41 | 42 | .. note:: 43 | 44 | You will have to run the ``pip install --no-build-isolation --editable .`` 45 | command every time the source code of a Cython file is updated 46 | (ending in `.pyx` or `.pxd`). Use the ``--no-build-isolation`` flag to 47 | avoid compiling the whole project each time, only the files you have 48 | modified. 49 | 50 | Create a wheel 51 | ============== 52 | 53 | do:: 54 | 55 | python3 setup.py bdist_wheel 56 | 57 | Testing 58 | ======= 59 | 60 | Run the tests with:: 61 | 62 | pytest tests 63 | 64 | or:: 65 | 66 | python3 -m pytest tests 67 | 68 | 69 | Profiling 70 | ========= 71 | 72 | python -m cProfile -o out_profile script_name.py 73 | cprofilev -f out_profile 74 | 75 | Build documentation 76 | =================== 77 | 78 | without example gallery:: 79 | 80 | cd docs 81 | make 82 | 83 | with the example gallery (may take a while):: 84 | 85 | cd docs 86 | make html 87 | 88 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help".
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/docs/source/_static/.gitkeep -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'SDeconv' 21 | copyright = '2022-2024, SDeconv' 22 | author = 'Sylvain Prigent' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '1.0.4' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. 
They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = ['sphinx.ext.autodoc', 34 | 'sphinx.ext.autosummary'] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = [] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 49 | # 50 | html_theme = 'sphinx_book_theme' 51 | 52 | # Add any paths that contain custom static files (such as style sheets) here, 53 | # relative to this directory. They are copied after the builtin static files, 54 | # so a file named "default.css" will overwrite the builtin "default.css". 55 | html_static_path = ['_static'] 56 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. SDeconv documentation master file. 2 | 3 | SDeconv's documentation 4 | ======================= 5 | 6 | .. 
toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | intro 11 | install 12 | quickstart 13 | modules 14 | references 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | Install 2 | ======= 3 | 4 | This section contains the instructions to install ``SDeconv``. 5 | 6 | Using PyPI 7 | ---------- 8 | 9 | Releases are available in the PyPI repository. We recommend using a virtual environment: 10 | 11 | .. code-block:: shell 12 | 13 | python -m venv .sdeconv-env 14 | source .sdeconv-env/bin/activate 15 | pip install sdeconv 16 | 17 | 18 | From source 19 | ----------- 20 | 21 | If you plan to develop ``SDeconv``, we recommend installing it locally in editable mode: 22 | 23 | .. code-block:: shell 24 | 25 | python -m venv .sdeconv-env 26 | source .sdeconv-env/bin/activate 27 | git clone https://github.com/sylvainprigent/sdeconv.git 28 | cd sdeconv 29 | pip install -e . 30 | -------------------------------------------------------------------------------- /docs/source/interfaces.rst: -------------------------------------------------------------------------------- 1 | Interfaces 2 | ========== 3 | 4 | 5 | .. autoclass:: sdeconv.psfs.interface.SPSFGenerator 6 | :members: 7 | 8 | .. autoclass:: sdeconv.deconv.interface.SDeconvFilter 9 | :members: 10 | 11 | .. autoclass:: sdeconv.deconv.interface_nn.NNModule 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | SDeconv is a library for 2D and 3D deconvolution of scientific images. 5 | 6 | Context 7 | ------- 8 | SDeconv has been developed in the `Serpico `_ research team.
The goal is to provide a 9 | modular library to perform deconvolution of microscopy images. A typical application in our team is to deconvolve 3D+t 10 | images depicting endosomes acquired with lattice light-sheet microscopy, in order to ease their analysis. 11 | 12 | Library components 13 | ------------------ 14 | SDeconv is written in Python 3 with PyTorch. The SDeconv library provides a module for each component of deconvolution algorithms: 15 | 16 | * **psfs**: this module defines the interface to implement Point Spread Function generators. 17 | * **deconv**: this module defines the interfaces to implement a deconvolution algorithm with or without neural networks. 18 | 19 | Furthermore, the library provides sample data, a command line interface, and an application 20 | programming interface to ease the integration of the SDeconv deconvolution algorithms into other software. 21 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Point Spread Functions 5 | ---------------------- 6 | 7 | .. currentmodule:: sdeconv.psfs 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | SPSFGaussian 14 | SPSFGibsonLanni 15 | SPSFLorentz 16 | 17 | 18 | Deconvolution algorithms 19 | ------------------------ 20 | 21 | .. currentmodule:: sdeconv.deconv 22 | 23 | .. autosummary:: 24 | :toctree: generated 25 | :nosignatures: 26 | 27 | SWiener 28 | SRichardsonLucy 29 | Spitfire 30 | Noise2VoidDeconv 31 | SelfSupervisedNNDeconv 32 | NNDeconv 33 | 34 | 35 | Interfaces 36 | ---------- 37 | 38 | Available interfaces to create a new PSF generator or a new deconvolution algorithm are: 39 | 40 | ..
list-table:: Interfaces 41 | :widths: 25 75 42 | 43 | * - :class:`SPSFGenerator <sdeconv.psfs.interface.SPSFGenerator>` 44 | - Interface for creating a new PSF generator 45 | * - :class:`SDeconvFilter <sdeconv.deconv.interface.SDeconvFilter>` 46 | - Interface for creating a deconvolution filter that does not need a neural network 47 | * - :class:`NNModule <sdeconv.deconv.interface_nn.NNModule>` 48 | - Interface for creating a deconvolution filter using a neural network 49 | -------------------------------------------------------------------------------- /docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quick start 2 | =========== 3 | 4 | This is a quick start example of how to use the **SDeconv** library. This section assumes that you know the principles 5 | of deconvolution. If this is not the case, please refer to the 6 | :doc:`References <references>`. 7 | 8 | Input images 9 | ------------ 10 | Input images are 2D or 3D grayscale images. 2D images are represented as torch tensors with the following 11 | axis ordering ``[Y, X]``, and 3D images are represented as torch tensors with the ``[Z, Y, X]`` axis ordering. 12 | 13 | Sample images can be loaded using the data module: 14 | 15 | .. code-block:: python3 16 | 17 | from sdeconv import data 18 | 19 | image = data.celegans() 20 | 21 | 22 | Deconvolution using the API 23 | --------------------------- 24 | 25 | Below is an example of how to write a deconvolution script with the API. In this example, we run the Wiener deconvolution algorithm: 26 | 27 | ..
code-block:: python3 28 | 29 | from sdeconv import data 30 | from sdeconv.api import SDeconvAPI 31 | import matplotlib.pyplot as plt 32 | 33 | # Instantiate the API 34 | api = SDeconvAPI() 35 | 36 | # Load image 37 | image = data.celegans() 38 | 39 | # Generate a PSF 40 | psf = api.generate_psf('SPSFGaussian', sigma=[1.5, 1.5], shape=[13, 13]) 41 | 42 | # Deconvolution with API 43 | image_decon = api.deconvolve(image, "SWiener", plane_by_plane=False, psf=psf, beta=0.005, pad=13) 44 | 45 | # Plot the results 46 | plt.figure() 47 | plt.subplot(131) 48 | plt.title('Original') 49 | plt.imshow(image.detach().cpu().numpy(), cmap='gray') 50 | plt.axis('off') 51 | 52 | plt.subplot(132) 53 | plt.title('PSF') 54 | plt.imshow(psf.detach().cpu().numpy(), cmap='gray') 55 | plt.axis('off') 56 | 57 | plt.subplot(133) 58 | plt.title('Wiener deconvolution') 59 | plt.imshow(image_decon.detach().cpu().numpy(), cmap='gray') 60 | plt.axis('off') 61 | 62 | plt.show() 63 | 64 | 65 | The advantage of using the API is that it implements several strategies to deconvolve a 3D or 3D+t image: either plane 66 | by plane within each frame, or frame by frame in 3D. 67 | 68 | 69 | Deconvolution using the library classes 70 | --------------------------------------- 71 | 72 | When only one method is needed, the easiest way may be to directly call the class that implements the deconvolution 73 | algorithm: 74 | 75 | ..
code-block:: python3 76 | 77 | import matplotlib.pyplot as plt 78 | from sdeconv.data import celegans 79 | from sdeconv.psfs import SPSFGaussian 80 | from sdeconv.deconv import SWiener 81 | 82 | # Load a 2D sample 83 | image = celegans() 84 | 85 | # Generate a 2D PSF 86 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13)) 87 | psf = psf_generator() 88 | 89 | # Apply Wiener filter 90 | wiener = SWiener(psf, beta=0.005, pad=13) 91 | out_image = wiener(image) 92 | 93 | # Display results 94 | plt.figure() 95 | plt.title('PSF') 96 | plt.imshow(psf.detach().numpy(), cmap='gray') 97 | 98 | plt.figure() 99 | plt.title('C. elegans original') 100 | plt.imshow(image.detach().numpy(), cmap='gray') 101 | 102 | plt.figure() 103 | plt.title('C. elegans Wiener') 104 | plt.imshow(out_image.detach().numpy(), cmap='gray') 105 | 106 | plt.show() 107 | 108 | Please refer to :doc:`Modules ` for more details on the interfaces and the list of available PSFs and deconvolution methods. 109 | -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | 4 | .. rubric:: References 5 | 6 | .. [Siba05] Sibarita, J.-B. Deconvolution Microscopy 201-243 (Springer Berlin Heidelberg, Berlin Heidelberg, 2005). https://doi.org/10.1007/b102215 7 | .. [Rich72] Richardson, W. H. Bayesian-based iterative method of image restoration. J. Opt. Soc. Am. 62, 55-59 (1972). 8 | .. [Lucy74] Lucy, L. B. An iterative technique for the rectification of observed distributions. Astron. J. 79, 745-754 (1974). 9 | .. [VdSt95] Van Der Voort, H. T. M. & Strasters, K. C. Restoration of confocal images for quantitative image analysis. J. Microsc. 178, 165-181. https://doi.org/10.1111/j.1365-2818.1995.tb03593.x (1995). 10 | .. [VkVB96] van Kempen, G. M. P., van der Voort, H. T. M., Bauman, J. G. J. & Strasters, K. C. 
Comparing maximum likelihood estimation and constrained Tikhonov-Miller restoration. IEEE Engineering in Medicine and Biology Magazine 15, 76-83 (1996). 11 | .. [VkVV97] van Kempen, G. M. P., van Vliet, L. J., Verveer, P. J. & van der Voort, H. T. M. A quantitative comparison of image restoration methods for confocal microscopy. J. Microscopy 185, 354-365. https://doi.org/10.1046/j.1365-2818.1997.d01-629.x (1997). 12 | .. [PNLV22] Prigent, S., Nguyen, HN., Leconte, L. et al. SPITFIR(e): a supermaneuverable algorithm for fast denoising and deconvolution of 3D fluorescence microscopy images and videos. Sci Rep 13, 1489 (2023). https://doi.org/10.1038/s41598-022-26178-y 13 | .. [KrBF19] Krull, A., Buchholz, T-O, Jug, F. Noise2Void - Learning Denoising From Single Noisy Images. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). June, 2019 14 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== -------------------------------------------------------------------------------- /examples/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | SDeconv API 3 | 4 | This example shows how to use the SDeconv application programming interface 5 | """ 6 | 7 | from sdeconv import data 8 | from sdeconv.api import SDeconvAPI 9 | import matplotlib.pyplot as plt 10 | 11 | # instantiate the API 12 | api = SDeconvAPI() 13 | 14 | # load image 15 | image = data.celegans() 16 | 17 | # Generate a PSF 18 | psf = api.generate_psf('SPSFGaussian', sigma=[1.5, 1.5], shape=[13, 13]) 19 | 20 | # deconvolution with API 21 | image_decon = api.deconvolve(image, "SWiener", plane_by_plane=False, psf=psf, beta=0.005, pad=13) 22 | 23 | # plot the result 24 | plt.figure() 25 | plt.subplot(131) 26 | plt.title('Original') 27 | plt.imshow(image.detach().cpu().numpy(),
cmap='gray') 28 | plt.axis('off') 29 | 30 | plt.subplot(132) 31 | plt.title('PSF') 32 | plt.imshow(psf.detach().cpu().numpy(), cmap='gray') 33 | plt.axis('off') 34 | 35 | plt.subplot(133) 36 | plt.title('Wiener deconvolution') 37 | plt.imshow(image_decon.detach().cpu().numpy(), cmap='gray') 38 | plt.axis('off') 39 | 40 | plt.show() 41 | -------------------------------------------------------------------------------- /examples/gaussian_psf_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gaussian PSF 3 | 4 | This example shows how to generate a Gaussian PSF 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from sdeconv.psfs import SPSFGaussian 9 | 10 | 11 | psf_generator = SPSFGaussian(sigma=(1.5, 1.5), shape=(13, 13)) 12 | psf = psf_generator() 13 | 14 | plt.figure() 15 | plt.title('Gaussian PSF') 16 | plt.imshow(psf.detach().numpy(), cmap='gray') 17 | plt.show() -------------------------------------------------------------------------------- /examples/gaussian_psf_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gaussian PSF 3 | 4 | This example shows how to generate a Gaussian PSF 5 | """ 6 | 7 | from sdeconv.psfs import SPSFGaussian 8 | import napari 9 | import time 10 | 11 | psf_generator = SPSFGaussian(sigma=(0.5, 1.5, 1.5), shape=(25, 128, 128)) 12 | t = time.time() 13 | psf = psf_generator() 14 | elapsed = time.time() - t 15 | print('elapsed = ', elapsed) 16 | 17 | viewer = napari.view_image(psf.detach().numpy(), scale=[200, 100, 100]) 18 | napari.run() 19 | -------------------------------------------------------------------------------- /examples/gibsonlanni.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gibson Lanni PSF 3 | 4 | This example shows how to generate a Gibson Lanni PSF 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from sdeconv.psfs import SPSFGibsonLanni 9 | import napari 10 | import 
time 11 | 12 | psf_generator = SPSFGibsonLanni((11, 128, 128), use_square=True) 13 | t = time.time() 14 | 15 | psf = psf_generator() 16 | elapsed = time.time() - t 17 | print('elapsed = ', elapsed) 18 | 19 | viewer = napari.view_image(psf.detach().numpy(), scale=[200, 100, 100]) 20 | napari.run() 21 | -------------------------------------------------------------------------------- /examples/richardson_lucy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Richardson-Lucy deconvolution 3 | 4 | This example shows how to use the Richardson-Lucy deconvolution on a 2D image 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from sdeconv.data import celegans 9 | from sdeconv.psfs import SPSFGaussian 10 | from sdeconv.deconv import SRichardsonLucy 11 | 12 | 13 | # load a 2D sample 14 | image = celegans() 15 | 16 | # Generate a 2D PSF 17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13)) 18 | psf = psf_generator() 19 | 20 | # apply Richardson-Lucy deconvolution 21 | filter_ = SRichardsonLucy(psf, niter=30, pad=13) 22 | out_image = filter_(image) 23 | 24 | # display results 25 | plt.figure() 26 | plt.title('PSF') 27 | plt.imshow(psf.detach().numpy(), cmap='gray') 28 | 29 | plt.figure() 30 | plt.title('C. elegans original') 31 | plt.imshow(image.detach().numpy(), cmap='gray') 32 | 33 | plt.figure() 34 | plt.title('C.
elegans Richardson-Lucy') 35 | plt.imshow(out_image.detach().numpy(), cmap='gray') 36 | 37 | plt.show() 38 | 39 | -------------------------------------------------------------------------------- /examples/spitfire.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spitfire deconvolution 3 | 4 | This example shows how to use the Spitfire deconvolution on a 2D image 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from sdeconv.data import celegans 9 | from sdeconv.psfs import SPSFGaussian 10 | from sdeconv.deconv import Spitfire 11 | 12 | 13 | # load a 2D sample 14 | image = celegans() 15 | 16 | # Generate a 2D PSF 17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13)) 18 | psf = psf_generator() 19 | 20 | # apply Spitfire filter 21 | filter_ = Spitfire(psf, weight=0.6, reg=0.995, gradient_step=0.01, precision=1e-7, pad=13) 22 | out_image = filter_(image) 23 | 24 | # display results 25 | plt.figure() 26 | plt.title('PSF') 27 | plt.imshow(psf.detach().numpy(), cmap='gray') 28 | 29 | plt.figure() 30 | plt.title('C. elegans original') 31 | plt.imshow(image.detach().numpy(), cmap='gray') 32 | 33 | plt.figure() 34 | plt.title('C. 
elegans Spitfire') 35 | plt.imshow(out_image.detach().numpy(), cmap='gray') 36 | 37 | plt.show() 38 | 39 | -------------------------------------------------------------------------------- /examples/wiener.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wiener deconvolution 3 | 4 | This example shows how to use the Wiener deconvolution on a 2D image 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from sdeconv.data import celegans 9 | from sdeconv.psfs import SPSFGaussian 10 | from sdeconv.deconv import SWiener 11 | 12 | 13 | # load a 2D sample 14 | image = celegans() 15 | 16 | # Generate a 2D PSF 17 | psf_generator = SPSFGaussian((1.5, 1.5), (13, 13)) 18 | psf = psf_generator() 19 | 20 | # apply Wiener filter 21 | filter_ = SWiener(psf, beta=0.005, pad=13) 22 | out_image = filter_(image) 23 | 24 | # display results 25 | plt.figure() 26 | plt.title('PSF') 27 | plt.imshow(psf.detach().numpy(), cmap='gray') 28 | 29 | plt.figure() 30 | plt.title('C. elegans original') 31 | plt.imshow(image.detach().numpy(), cmap='gray') 32 | 33 | plt.figure() 34 | plt.title('C. 
elegans Wiener') 35 | plt.imshow(out_image.detach().numpy(), cmap='gray') 36 | 37 | plt.show() 38 | 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | 7 | [tool.black] 8 | line-length = 79 9 | 10 | [tool.isort] 11 | profile = "black" 12 | line_length = 79 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy>=1.14.1 2 | numpy==1.26.4 3 | torch==2.2.1 4 | torchvision>=0.17.1 5 | scikit-image>=0.24.0 6 | pylint>=3.2.6 7 | pytest>=8.3.2 8 | sphinx~=8.0.2 9 | sphinx_book_theme~=1.1.3 -------------------------------------------------------------------------------- /sdeconv/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bio-image analysis module for Python 3 | ==================================== 4 | sdeconv is a Python module for microscopy image deconvolution based on the scientific 5 | ecosystem in python (numpy, torch, torchvision, scipy, scikit-image). 6 | See http://github.com/sylvainprigent/sdeconv for complete documentation. 
7 | """ 8 | 9 | __version__ = '1.0.4' 10 | 11 | __all__ = [] 12 | -------------------------------------------------------------------------------- /sdeconv/api/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining the SDeconv Application Programming Interface (API)""" 2 | from .factory import SDeconvModuleFactory, SDeconvModuleBuilder 3 | from .api import SDeconvAPI 4 | 5 | __all__ = ['SDeconvAPI'] 6 | -------------------------------------------------------------------------------- /sdeconv/api/api.py: -------------------------------------------------------------------------------- 1 | """Application programming interface for SDeconv""" 2 | 3 | import os 4 | import importlib 5 | import torch 6 | 7 | from ..psfs import SPSFGenerator 8 | from ..deconv import SDeconvFilter 9 | from .factory import SDeconvModuleFactory, SDeconvFactoryError 10 | 11 | 12 | class SDeconvAPI: 13 | """Main API to call SDeconv methods. 14 | 15 | On construction, the API imports every module found in the ``deconv`` and ``psfs`` sub-packages and registers its ``metadata`` dictionary, keyed by ``metadata['name']``, in the ``filters`` and ``psfs`` factories used to instantiate methods by name 16 | """ 17 | def __init__(self): 18 | self.psfs = SDeconvModuleFactory() 19 | self.filters = SDeconvModuleFactory() 20 | for name in self._find_modules('deconv'): 21 | mod = importlib.import_module(name) 22 | self.filters.register(mod.metadata['name'], mod.metadata) 23 | for name in self._find_modules('psfs'): 24 | mod = importlib.import_module(name) 25 | self.psfs.register(mod.metadata['name'], mod.metadata) 26 | 27 | @staticmethod 28 | def _find_modules(directory: str) -> list[str]: 29 | """Search the Python modules of an sdeconv sub-package directory 30 | 31 | :param directory: Name of the sub-package directory inside the sdeconv package (e.g. 'deconv' or 'psfs') 32 | :return: The found modules as dotted import paths; files starting with '_' or whose name contains 'setup', 'interface' or '__init__' are skipped 33 | """ 34 | path = os.path.abspath(os.path.dirname(__file__)) 35 | path = os.path.dirname(path) 36 | modules = [] 37 | for parent in [directory]: 38 | path_ = os.path.join(path, parent) 39 | for module_path in os.listdir(path_): 40 | if module_path.endswith( 41 | ".py") and 'setup' not in module_path and \ 42 |
'interface' not in module_path and \ 43 | '__init__' not in module_path and not module_path.startswith( 44 | "_"): 45 | modules.append(f"sdeconv.{parent}.{module_path.split('.')[0]}") 46 | return modules 47 | 48 | def filter(self, name: str, **kwargs) -> SDeconvFilter | None: 49 | """Instantiate a deconvolution filter 50 | 51 | :param name: Unique name of the filter to instantiate, or 'None' for no filter 52 | :param kwargs: Arguments of the filter 53 | :return: An instance of the filter, or None when name is 'None' 54 | """ 55 | if name == 'None': 56 | return None 57 | return self.filters.get(name, **kwargs) 58 | 59 | def psf(self, method_name, **kwargs) -> SPSFGenerator | None: 60 | """Instantiate a PSF generator, raising SDeconvFactoryError if the named module is not a PSF generator 61 | 62 | :param method_name: Name of the PSF generator, or 'None' for no generator 63 | :param kwargs: Parameters of the PSF generator 64 | :return: An instance of the generator, or None when method_name is 'None' 65 | """ 66 | if method_name == 'None': 67 | return None 68 | filter_ = self.psfs.get(method_name, **kwargs) 69 | if filter_.type == 'SPSFGenerator': 70 | return filter_ 71 | raise SDeconvFactoryError(f'The method {method_name} is not a PSF generator') 72 | 73 | def generate_psf(self, method_name, **kwargs) -> torch.Tensor: 74 | """Generate a Point Spread Function 75 | 76 | :param method_name: Name of the PSF generator method 77 | :param kwargs: Parameters of the PSF generator 78 | :return: The generated PSF 79 | """ 80 | # NOTE(review): self.psf returns None when method_name is 'None', so the call below would then raise TypeError 81 | generator = self.psf(method_name, **kwargs) 82 | return generator() 83 | 84 | def deconvolve(self, 85 | image: torch.Tensor, 86 | method_name: str, 87 | plane_by_plane: bool, 88 | **kwargs 89 | ) -> torch.Tensor: 90 | """Run the deconvolution on an image 91 | 92 | :param image: Image to deconvolve.
Can be 2D to 5D 93 | :param method_name: Name of the deconvolution method to use 94 | :param plane_by_plane: True to process the image plane by plane 95 | when dimension is more than 2 96 | :param kwargs: Parameters of the deconvolution method 97 | :return: The deblurred image; SDeconvFactoryError is raised if method_name is not a deconvolution filter 98 | """ 99 | filter_ = self.filter(method_name, **kwargs)  # NOTE(review): None for method_name 'None', which makes filter_.type below raise AttributeError 100 | if filter_.type == 'SDeconvFilter': 101 | return self._deconv_dims(image, filter_, plane_by_plane=plane_by_plane) 102 | raise SDeconvFactoryError(f'The method {method_name} is not a deconvolution filter') 103 | 104 | @staticmethod 105 | def _deconv_3d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor: 106 | """Apply the deconvolution filter independently to each plane of a 3D image 107 | 108 | :param image: 3D image tensor; planes are taken along the first dimension 109 | :param filter_: Deconvolution filter instance, called on each 2D plane 110 | :return: The deblurred image, allocated with torch.zeros(image.shape), i.e. with torch's default dtype and device 111 | """ 112 | out_image = torch.zeros(image.shape) 113 | for i in range(image.shape[0]): 114 | out_image[i, ...] = filter_(image[i, ...]) 115 | return out_image 116 | 117 | @staticmethod 118 | def _deconv_4d(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor: 119 | """Apply the deconvolution filter to each 3D frame of a 3D+t image 120 | 121 | :param image: 3D+t image tensor; frames are taken along the first dimension 122 | :param filter_: Deconvolution filter instance, called on each 3D frame 123 | :return: The deblurred stack 124 | """ 125 | out_image = torch.zeros(image.shape) 126 | for i in range(image.shape[0]): 127 | out_image[i, ...] = filter_(image[i, ...]) 128 | return out_image 129 | 130 | @staticmethod 131 | def _deconv_4d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor: 132 | """Apply the deconvolution filter to each 2D plane of each frame of a 3D+t image 133 | 134 | :param image: 3D+t image tensor 135 | :param filter_: Deconvolution filter instance, called on each 2D plane image[i, j] 136 | :return: The deblurred stack 137 | """ 138 | out_image = torch.zeros(image.shape) 139 | for i in range(image.shape[0]): 140 | for j in range(image.shape[1]): 141 | out_image[i, j, ...]
= filter_(image[i, j, ...]) 142 | return out_image 143 | 144 | @staticmethod 145 | def _deconv_5d(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor: 146 | """Apply the deconvolution filter to each 3D volume of a 3D+t multi-channel image 147 | 148 | :param image: 5D image tensor; the filter is called on each 3D sub-volume image[i, j] 149 | :param filter_: Deconvolution filter instance 150 | :return: The deblurred hyper-stack 151 | """ 152 | out_image = torch.zeros(image.shape) 153 | for i in range(image.shape[0]): 154 | for j in range(image.shape[1]): 155 | out_image[i, j, ...] = filter_(image[i, j, ...]) 156 | return out_image 157 | 158 | @staticmethod 159 | def _deconv_5d_by_plane(image: torch.Tensor, filter_: SDeconvFilter) -> torch.Tensor: 160 | """Apply the deconvolution filter to each 2D plane of a 3D+t multi-channel image 161 | 162 | :param image: 5D image tensor; the filter is called on each 2D plane image[batch, channel, plane] 163 | :param filter_: Deconvolution filter instance 164 | :return: The deblurred hyper-stack 165 | """ 166 | out_image = torch.zeros(image.shape) 167 | for batch in range(image.shape[0]): 168 | for channel in range(image.shape[1]): 169 | for plane in range(image.shape[2]): 170 | out_image[batch, channel, plane, ...]
= \ 171 | filter_(image[batch, channel, plane, ...]) 172 | return out_image 173 | 174 | @staticmethod 175 | def _deconv_dims(image: torch.Tensor, 176 | filter_: SDeconvFilter, 177 | plane_by_plane: bool = False) -> torch.Tensor: 178 | """Dispatch the deconvolution according to the number of image dimensions 179 | 180 | :param image: Image tensor with 2 to 5 dimensions 181 | :param filter_: Deconvolution filter instance 182 | :param plane_by_plane: True to deblur each 2D plane (last two dimensions) independently; ignored for 2D images 183 | :return: The deblurred image, stack or hyper-stack; SDeconvFactoryError is raised when the image does not have 2 to 5 dimensions 184 | """ 185 | out_image = None 186 | if image.ndim == 2: 187 | out_image = filter_(image) 188 | elif image.ndim == 3 and plane_by_plane: 189 | out_image = SDeconvAPI._deconv_3d_by_plane(image, filter_) 190 | elif image.ndim == 3 and not plane_by_plane: 191 | out_image = filter_(image) 192 | elif image.ndim == 4 and not plane_by_plane: 193 | out_image = SDeconvAPI._deconv_4d(image, filter_) 194 | elif image.ndim == 4 and plane_by_plane: 195 | out_image = SDeconvAPI._deconv_4d_by_plane(image, filter_) 196 | elif image.ndim == 5 and not plane_by_plane: 197 | out_image = SDeconvAPI._deconv_5d(image, filter_) 198 | elif image.ndim == 5 and plane_by_plane: 199 | out_image = SDeconvAPI._deconv_5d_by_plane(image, filter_) 200 | else: 201 | raise SDeconvFactoryError('SDeconv can process only images up to 5 dims') 202 | return out_image 203 | -------------------------------------------------------------------------------- /sdeconv/api/factory.py: -------------------------------------------------------------------------------- 1 | """Implements the factory for PSF and deconvolution modules""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ..psfs import SPSFGenerator 7 | from ..deconv import SDeconvFilter 8 | 9 | 10 | class SDeconvFactoryError(Exception): 11 | """Raised when an error happens while a module is built by the factory""" 12 | 13 | 14 | class SDeconvModuleFactory: 15 | """Factory for SDeconv modules""" 16 | def __init__(self): 17 | self._data = {} 18 | 19 |
def register(self, key: str, metadata: dict[str, any]): 20 | """Register a new builder to the factory 21 | 22 | :param key: Name of the module to register 23 | :param metadata: Dictionary containing the filter metadata 24 | """ 25 | self._data[key] = metadata 26 | 27 | def get_parameters(self, key: str) -> dict[str, any]: 28 | """Parameters getter method 29 | 30 | :param key: Name of the module builder 31 | :return: The module parameters 32 | """ 33 | return self._data[key]['inputs'] 34 | 35 | def get_keys(self) -> list[str]: 36 | """Get the names of all the registered modules 37 | 38 | :return: The list of all the registered modules names 39 | """ 40 | return self._data.keys() 41 | 42 | def get(self, key: str, **kwargs) -> SPSFGenerator | SDeconvFilter: 43 | """Get an instance of the SDeconv module 44 | 45 | :param key: Name of the module to load 46 | :param kwargs: Dictionary of args for models parameters (ex: number of channels) 47 | :return: The instance of the module 48 | """ 49 | metadata = self._data.get(key) 50 | if not metadata: 51 | raise ValueError(key) 52 | builder = SDeconvModuleBuilder() 53 | return builder.get_instance(metadata, kwargs) 54 | 55 | 56 | class SDeconvModuleBuilder: 57 | """Interface for a SDeconv module builder 58 | 59 | The builder is used by the factory to instantiate a module 60 | 61 | """ 62 | def __init__(self): 63 | self._instance = None 64 | 65 | def get_instance(self, metadata: dict[str, any], args: dict) -> SPSFGenerator | SDeconvFilter: 66 | """Get the instance of the module 67 | 68 | :param metadata: Metadata of the module 69 | :param args: Argument to pass for the module instantiation 70 | :return: Instance of the module 71 | """ 72 | # check the args 73 | instance_args = {} 74 | for key, value in metadata['inputs'].items(): 75 | val = self._get_arg(value, key, args) 76 | instance_args[key] = val 77 | return metadata['class'](**instance_args) 78 | 79 | def _get_arg(self, param_metadata: dict[str, any], key: str, args: 
dict[str, any]) -> any: 80 | """Retrieve the value of a parameter with a type check 81 | 82 | :param param_metadata: Metadata of the parameter, 83 | :param key: Name of the parameter, 84 | :param args: Value of the parameters 85 | :return: The value of the parameter if check is successful 86 | """ 87 | type_ = param_metadata['type'] 88 | range_ = None 89 | if 'range' in param_metadata: 90 | range_ = param_metadata['range'] 91 | arg_value = None 92 | if type_ == 'float': 93 | arg_value = self.get_arg_float(args, key, param_metadata['default'], 94 | range_) 95 | elif type_ == 'int': 96 | arg_value = self.get_arg_int(args, key, param_metadata['default'], 97 | range_) 98 | elif type_ == 'bool': 99 | arg_value = self.get_arg_bool(args, key, param_metadata['default'], 100 | range_) 101 | elif type_ == 'str': 102 | arg_value = self.get_arg_str(args, key, param_metadata['default']) 103 | elif type_ is torch.Tensor: 104 | arg_value = self.get_arg_array(args, key, param_metadata['default']) 105 | elif type_ == 'select': 106 | arg_value = self.get_arg_select(args, key, param_metadata['values']) 107 | elif 'zyx' in type_: 108 | arg_value = self.get_arg_list(args, key, param_metadata['default']) 109 | return arg_value 110 | 111 | @staticmethod 112 | def _error_message(key: str, value_type: str, value_range: tuple | None): 113 | """Throw an exception if an input parameter is not correct 114 | 115 | :param key: Input parameter key 116 | :param value_type: String naming the input type (int, float...) 
117 | :param value_range: Min and max values of the parameter 118 | """ 119 | range_message = '' 120 | if value_range and len(value_range) == 2: 121 | range_message = f' in range [{str(value_range[0]), str(value_range[1])}]' 122 | 123 | message = f'Parameter {key} must be of type `{value_type}` {range_message}' 124 | return message 125 | 126 | def get_arg_int(self, 127 | args: dict[str, any], 128 | key: str, 129 | default_value: int, 130 | value_range: tuple = None 131 | ) -> int: 132 | """Get the value of a parameter from the args list 133 | 134 | The default value of the parameter is returned if the 135 | key is not in args 136 | 137 | :param args: Dictionary of the input args 138 | :param key: Name of the parameters 139 | :param default_value: Default value of the parameter 140 | :param value_range: Min and max value of the parameter 141 | :return: The arg value 142 | """ 143 | value = default_value 144 | if isinstance(args, dict) and key in args: 145 | # cast 146 | try: 147 | value = int(args[key]) 148 | except ValueError as exc: 149 | raise SDeconvFactoryError(self._error_message(key, 'int', value_range)) from exc 150 | # test range 151 | if value_range and len(value_range) == 2: 152 | if value > value_range[1] or value < value_range[0]: 153 | raise SDeconvFactoryError(self._error_message(key, 'int', value_range)) 154 | return value 155 | 156 | def get_arg_float(self, 157 | args: dict[str, any], 158 | key: str, 159 | default_value: float, 160 | value_range: tuple = None 161 | ) -> str: 162 | """Get the value of a parameter from the args list 163 | 164 | The default value of the parameter is returned if the 165 | key is not in args 166 | 167 | :param args: Dictionary of the input args 168 | :param key: Name of the parameters 169 | :param default_value: Default value of the parameter 170 | :param value_range: Min and max value of the parameter 171 | :return: The arg value 172 | """ 173 | value = default_value 174 | if isinstance(args, dict) and key in args: 175 
| # cast 176 | try: 177 | value = float(args[key]) 178 | except ValueError as exc: 179 | raise SDeconvFactoryError(self._error_message(key, 'float', value_range)) from exc 180 | # test range 181 | if value_range and len(value_range) == 2: 182 | if value > value_range[1] or value < value_range[0]: 183 | raise SDeconvFactoryError(self._error_message(key, 'float', value_range)) 184 | return value 185 | 186 | def get_arg_str(self, 187 | args: dict[str, any], 188 | key: str, 189 | default_value: str, 190 | value_range: tuple = None 191 | ) -> str: 192 | """Get the value of a parameter from the args list 193 | 194 | The default value of the parameter is returned if the 195 | key is not in args 196 | 197 | :param args: Dictionary of the input args 198 | :param key: Name of the parameters 199 | :param default_value: Default value of the parameter 200 | :param value_range: Min and max value of the parameter 201 | :return: The arg value 202 | """ 203 | value = default_value 204 | if isinstance(args, dict) and key in args: 205 | # cast 206 | try: 207 | value = str(args[key]) 208 | except ValueError as exc: 209 | raise SDeconvFactoryError(self._error_message(key, 'str', value_range)) from exc 210 | # test range 211 | if value_range and len(value_range) == 2: 212 | if value > value_range[1] or value < value_range[0]: 213 | raise SDeconvFactoryError(self._error_message(key, 'str', value_range)) 214 | return value 215 | 216 | @staticmethod 217 | def _str2bool(value: str) -> bool: 218 | """Convert a string to a boolean 219 | 220 | :param value: String to convert 221 | :return: The boolean conversion 222 | """ 223 | return value.lower() in ("yes", "true", "t", "1") 224 | 225 | def get_arg_bool(self, 226 | args: dict[str, any], 227 | key: str, 228 | default_value: bool, 229 | value_range: tuple = None 230 | ) -> bool: 231 | """Get the value of a parameter from the args list 232 | 233 | The default value of the parameter is returned if the 234 | key is not in args 235 | 236 | :param 
args: Dictionary of the input args 237 | :param key: Name of the parameters 238 | :param default_value: Default value of the parameter 239 | :param value_range: Min and max value of the parameter 240 | :return: The arg value 241 | """ 242 | value = default_value 243 | # cast 244 | if isinstance(args, dict) and key in args: 245 | if isinstance(args[key], str): 246 | value = SDeconvModuleBuilder._str2bool(args[key]) 247 | elif isinstance(args[key], bool): 248 | value = args[key] 249 | else: 250 | raise SDeconvFactoryError(self._error_message(key, 'bool', value_range)) 251 | # test range 252 | if value_range and len(value_range) == 2: 253 | if value > value_range[1] or value < value_range[0]: 254 | raise SDeconvFactoryError(self._error_message(key, 'bool', value_range)) 255 | return value 256 | 257 | def get_arg_array(self, 258 | args: dict[str, any], 259 | key: str, 260 | default_value: torch.Tensor 261 | ) -> torch.Tensor: 262 | """Get the value of a parameter from the args list 263 | 264 | The default value of the parameter is returned if the 265 | key is not in args 266 | 267 | :param args: Dictionary of the input args 268 | :param key: Name of the parameters 269 | :param default_value: Default value of the parameter 270 | :return: The arg value 271 | """ 272 | value = default_value 273 | if isinstance(args, dict) and key in args: 274 | if isinstance(args[key], torch.Tensor): 275 | value = args[key] 276 | elif isinstance(args[key], np.ndarray): 277 | value = torch.Tensor(args[key]) 278 | else: 279 | raise SDeconvFactoryError(self._error_message(key, 'array', None)) 280 | return value 281 | 282 | def get_arg_list(self, 283 | args: dict[str, any], 284 | key: str, 285 | default_value: list 286 | ) -> list: 287 | """Get the value of a parameter from the args list 288 | 289 | The default value of the parameter is returned if the 290 | key is not in args 291 | 292 | :param args: Dictionary of the input args 293 | :param key: Name of the parameters 294 | :param 
default_value: Default value of the parameter 295 | :return: The arg value 296 | """ 297 | value = default_value 298 | if isinstance(args, dict) and key in args: 299 | if isinstance(args[key], (list, tuple)): 300 | value = args[key] 301 | else: 302 | raise SDeconvFactoryError(self._error_message(key, 'list', None)) 303 | return value 304 | 305 | def get_arg_select(self, 306 | args: dict[str, any], 307 | key: str, 308 | values: list 309 | ) -> str: 310 | """Get the value of a parameter from the args list as a select input 311 | 312 | :param args: Dictionary of the input args 313 | :param key: Name of the parameters 314 | :param values: Possible values in select input 315 | :return: The arg value 316 | """ 317 | if isinstance(args, dict) and key in args: 318 | value = str(args[key]) 319 | for val in values: 320 | if str(val) == value: 321 | return val 322 | raise SDeconvFactoryError(self._error_message(key, 'select', None)) 323 | -------------------------------------------------------------------------------- /sdeconv/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/cli/__init__.py -------------------------------------------------------------------------------- /sdeconv/cli/sdeconv.py: -------------------------------------------------------------------------------- 1 | """Command line interface module for 2D to 5D image deconvolution""" 2 | 3 | import argparse 4 | from skimage.io import imread, imsave 5 | import torch 6 | import numpy as np 7 | 8 | from sdeconv.api import SDeconvAPI 9 | 10 | 11 | def add_args_to_parser(parser: argparse.ArgumentParser, api: SDeconvAPI): 12 | """Add all the parameters available in the API as an argument in the parser 13 | 14 | :param parser: Argument parser instance 15 | :param api: SDeconv Application Programming Interface instance 16 | 17 | """ 18 | for filter_name in 
api.filters.get_keys(): 19 | params = api.filters.get_parameters(filter_name) 20 | for key, value in params.items(): 21 | parser.add_argument(f"--{key}", help=value['help'], default=value['default']) 22 | 23 | 24 | def main(): 25 | """Command line interface entrypoint function""" 26 | parser = argparse.ArgumentParser(description='2D to 5D image deconvolution', 27 | conflict_handler='resolve') 28 | 29 | parser.add_argument('-i', '--input', help='Input image file', default='.tif') 30 | parser.add_argument('-m', '--method', help='Deconvolution method', default='wiener') 31 | parser.add_argument('-o', '--output', help='Output image file', default='.tif') 32 | parser.add_argument('-p', '--plane', help='Plane by plane deconvolution', default=False) 33 | 34 | api = SDeconvAPI() 35 | add_args_to_parser(parser, api) 36 | args = parser.parse_args() 37 | 38 | args_dict = vars(args) 39 | 40 | image = torch.Tensor(np.float32(imread(args.input))) 41 | out_image = api.deconvolve(image, args.method, args.plane, **args_dict) 42 | imsave(args.output, out_image.detach().numpy()) 43 | -------------------------------------------------------------------------------- /sdeconv/cli/spsf.py: -------------------------------------------------------------------------------- 1 | """Command line interface module for Point Spread Function generator""" 2 | 3 | import argparse 4 | from skimage.io import imsave 5 | 6 | from sdeconv.api import SDeconvAPI 7 | 8 | 9 | def add_args_to_parser(parser: argparse.ArgumentParser, api: SDeconvAPI): 10 | """Add all the parameters available in the API as an argument in the parser 11 | 12 | :param parser: Argument parser instance 13 | :param api: SDeconv Application Programming Interface instance 14 | """ 15 | for filter_name in api.filters.get_keys(): 16 | params = api.psfs.get_parameters(filter_name) 17 | for key, value in params.items(): 18 | parser.add_argument(f"--{key}", help=value['help'], default=value['default']) 19 | 20 | 21 | def main(): 22 | 
"""Command line interface entrypoint function""" 23 | parser = argparse.ArgumentParser(description='2D to 5D image deconvolution', 24 | conflict_handler='resolve') 25 | 26 | parser.add_argument('-m', '--method', help='Deconvolution method', default='wiener') 27 | parser.add_argument('-o', '--output', help='Output image file', default='.tif') 28 | 29 | api = SDeconvAPI() 30 | add_args_to_parser(parser, api) 31 | args = parser.parse_args() 32 | 33 | args_dict = vars(args) 34 | out_image = api.psf(args.method, **args_dict) 35 | imsave(args.output, out_image.detach().numpy()) 36 | -------------------------------------------------------------------------------- /sdeconv/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Core module for the SDeconv library. It implements settings and observer/observable design 2 | patterns""" 3 | from ._observers import SObservable, SObserver, SObserverConsole 4 | from ._settings import SSettings, SSettingsContainer 5 | from ._timing import seconds2str 6 | from ._progress_logger import SConsoleLogger 7 | 8 | 9 | __all__ = ['SSettings', 10 | 'SSettingsContainer', 11 | 'SObservable', 12 | 'SObserver', 13 | 'SObserverConsole', 14 | 'seconds2str', 15 | 'SConsoleLogger'] 16 | -------------------------------------------------------------------------------- /sdeconv/core/_observers.py: -------------------------------------------------------------------------------- 1 | """Module that implements the observer/observable design pattern to display progress""" 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class SObserver(ABC): 6 | """Interface of observer to notify progress 7 | 8 | An observer must implement the progress and message 9 | """ 10 | @abstractmethod 11 | def notify(self, message: str): 12 | """Notify a progress message 13 | 14 | :param message: Progress message 15 | """ 16 | raise NotImplementedError('SObserver is abstract') 17 | 18 | @abstractmethod 19 | def progress(self, 
value: int): 20 | """Notify progress value 21 | 22 | :param value: Progress value in [0, 100] 23 | """ 24 | raise NotImplementedError('SObserver is abstract') 25 | 26 | 27 | class SObservable: 28 | """Interface for data processing class 29 | 30 | The observable class can notify the observers for progress 31 | """ 32 | def __init__(self): 33 | self._observers = [] 34 | 35 | def add_observer(self, observer: SObserver): 36 | """Add an observer 37 | 38 | :param observer: Observer instance to add 39 | """ 40 | self._observers.append(observer) 41 | 42 | def notify(self, message: str): 43 | """Notify progress to observers 44 | 45 | :param message: Progress message 46 | """ 47 | for obs in self._observers: 48 | obs.notify(message) 49 | 50 | def progress(self, value): 51 | """Notify progress to observers 52 | 53 | :param value: Progress value in [0, 100] 54 | """ 55 | for obs in self._observers: 56 | obs.progress(value) 57 | 58 | 59 | class SObserverConsole(SObserver): 60 | """print message and progress to console""" 61 | 62 | def notify(self, message: str): 63 | """Print message 64 | 65 | :param message: Progress message 66 | """ 67 | print(message) 68 | 69 | def progress(self, value: str): 70 | """Print progress 71 | 72 | :param value: Progress value in [0, 100] 73 | """ 74 | print('progress:', value, '%') 75 | -------------------------------------------------------------------------------- /sdeconv/core/_progress_logger.py: -------------------------------------------------------------------------------- 1 | """Set of classes to log a workflow run""" 2 | COLOR_WARNING = '\033[93m' 3 | COLOR_ERROR = '\033[91m' 4 | COLOR_GREEN = '\033[92m' 5 | COLOR_ENDC = '\033[0m' 6 | 7 | 8 | class SProgressLogger: 9 | """Default logger 10 | 11 | A logger is used by a workflow to print the warnings, errors and progress. 
12 | A logger can be used to print in the console or in a log file 13 | 14 | """ 15 | def __init__(self): 16 | self.prefix = '' 17 | 18 | def new_line(self): 19 | """Print a new line in the log""" 20 | raise NotImplementedError() 21 | 22 | def message(self, message: str): 23 | """Log a default message 24 | 25 | :param message: Message to log 26 | """ 27 | raise NotImplementedError() 28 | 29 | def error(self, message: str): 30 | """Log an error message 31 | 32 | :param message: Message to log 33 | """ 34 | raise NotImplementedError() 35 | 36 | def warning(self, message: str): 37 | """Log a warning 38 | 39 | :param message: Message to log 40 | """ 41 | raise NotImplementedError() 42 | 43 | def progress(self, iteration: int, total: int, prefix: str, suffix: str): 44 | """Log a progress 45 | 46 | :param iteration: Current iteration 47 | :param total: Total number of iteration 48 | :param prefix: Text to print before the progress 49 | :param suffix: Text to print after the message 50 | """ 51 | raise NotImplementedError() 52 | 53 | def close(self): 54 | """Close the logger""" 55 | raise NotImplementedError() 56 | 57 | 58 | class SProgressObservable: 59 | """Observable pattern 60 | 61 | This pattern allows to set multiple progress logger to 62 | one workflow 63 | 64 | """ 65 | def __init__(self): 66 | self._loggers = [] 67 | 68 | def set_prefix(self, prefix: str): 69 | """Set the prefix for all loggers 70 | 71 | The prefix is a printed str ad the beginning of each 72 | line of the logger 73 | 74 | :param prefix: Prefix content 75 | """ 76 | for logger in self._loggers: 77 | logger.prefix = prefix 78 | 79 | def add_logger(self, logger: SProgressLogger): 80 | """Add a logger to the observer 81 | 82 | :param logger: Logger to add to the observer 83 | """ 84 | self._loggers.append(logger) 85 | 86 | def new_line(self): 87 | """Print a new line in the loggers""" 88 | for logger in self._loggers: 89 | logger.new_line() 90 | 91 | def message(self, message: str): 92 | """Log a 
default message 93 | 94 | :param message: Message to log 95 | """ 96 | for logger in self._loggers: 97 | logger.message(message) 98 | 99 | def error(self, message: str): 100 | """Log an error message 101 | 102 | :param message: Message to log 103 | """ 104 | for logger in self._loggers: 105 | logger.error(message) 106 | 107 | def warning(self, message: str): 108 | """Log a warning message 109 | 110 | :param message: Message to log 111 | """ 112 | for logger in self._loggers: 113 | logger.warning(message) 114 | 115 | def progress(self, iteration: int, total: int, prefix: str, suffix: str): 116 | """Log a progress 117 | 118 | :param iteration: Current iteration 119 | :param total: Total number of iteration 120 | :param prefix: Text to print before the progress 121 | :param suffix: Text to print after the message 122 | """ 123 | for logger in self._loggers: 124 | logger.progress(iteration, total, prefix, suffix) 125 | 126 | def close(self): 127 | """Close the loggers""" 128 | for logger in self._loggers: 129 | logger.close() 130 | 131 | 132 | class SConsoleLogger(SProgressLogger): 133 | """Console logger displaying a progress bar 134 | 135 | The progress bar display the basic information of a batch loop (loss, 136 | batch id, time/remaining time) 137 | 138 | """ 139 | def __init__(self): 140 | super().__init__() 141 | self.decimals = 1 142 | self.print_end = "\r" 143 | self.length = 100 144 | self.fill = '█' 145 | 146 | def new_line(self): 147 | print(f"{self.prefix}:\n") 148 | 149 | def message(self, message): 150 | print(f'{self.prefix}: {message}') 151 | 152 | def error(self, message): 153 | print(f'{COLOR_ERROR}{self.prefix} ERROR: ' 154 | f'{message}{COLOR_ENDC}') 155 | 156 | def warning(self, message): 157 | print(f'{COLOR_WARNING}{self.prefix} WARNING: ' 158 | f'{message}{COLOR_ENDC}') 159 | 160 | def progress(self, iteration, total, prefix, suffix): 161 | percent = ("{0:." 
+ str(self.decimals) + "f}").format( 162 | 100 * (iteration / float(total))) 163 | filled_length = int(self.length * iteration // total) 164 | bar_ = self.fill * filled_length + ' ' * (self.length - filled_length) 165 | print(f'\r{prefix} {percent}% |{bar_}| {suffix}', 166 | end=self.print_end) 167 | 168 | def close(self): 169 | pass 170 | -------------------------------------------------------------------------------- /sdeconv/core/_settings.py: -------------------------------------------------------------------------------- 1 | """Implements setting management""" 2 | import torch 3 | 4 | 5 | class SSettingsContainer: 6 | """Container for the SDeconv library settings""" 7 | def __init__(self): 8 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | 10 | def get_device(self) -> str: 11 | """Returns the device name for torch""" 12 | return self.device 13 | 14 | def print(self): 15 | """Display the settings in the console""" 16 | print(f'SDeconv settings: device={self.device}') 17 | 18 | 19 | class SSettings: 20 | """Singleton to access the Settings container 21 | 22 | :raises: Exception: if multiple instantiation of the Settings container is tried 23 | """ 24 | __instance = None 25 | 26 | def __init__(self): 27 | """ Virtually private constructor. """ 28 | if SSettings.__instance is not None: 29 | raise RuntimeError("Settings container can be initialized only once!") 30 | SSettings.__instance = SSettingsContainer() 31 | 32 | @staticmethod 33 | def instance(): 34 | """ Static access method to the Config. 
""" 35 | if SSettings.__instance is None: 36 | SSettings.__instance = SSettingsContainer() 37 | return SSettings.__instance 38 | 39 | @staticmethod 40 | def print(): 41 | """Print the settings to the console""" 42 | SSettings.instance().print() 43 | -------------------------------------------------------------------------------- /sdeconv/core/_timing.py: -------------------------------------------------------------------------------- 1 | """Module to implement useflkl function for time formating""" 2 | 3 | 4 | def seconds2str(sec: int) -> str: 5 | """Convert seconds to printable string in hh:mm:ss 6 | 7 | :param sec: Duration in seconds 8 | :return: human-readable time string 9 | """ 10 | sec_value = sec % (24 * 3600) 11 | hour_value = sec_value // 3600 12 | sec_value %= 3600 13 | min_value = sec_value // 60 14 | sec_value %= 60 15 | if hour_value > 0: 16 | return f"{hour_value:02d}:{min_value:02d}:{sec_value:02d}" 17 | return f"{min_value:02d}:{sec_value:02d}" 18 | -------------------------------------------------------------------------------- /sdeconv/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that provides sample images""" 2 | 3 | import os.path as osp 4 | import os 5 | import numpy as np 6 | import torch 7 | from skimage.io import imread 8 | from sdeconv.core import SSettings 9 | 10 | 11 | __all__ = ['celegans', 12 | 'pollen', 13 | 'pollen_poison_noise_blurred', 14 | 'pollen_psf'] 15 | 16 | legacy_data_dir = osp.abspath(osp.dirname(__file__)) 17 | 18 | 19 | def _fetch(data_filename: str) -> str: 20 | """Fetch a given data file from the local data dir. 21 | 22 | This function provides the path location of the data file given 23 | its name in the scikit-image repository. 
24 | 25 | :param data_filename: Name of the file in the scikit-bioimaging data dir, 26 | :return: Path of the local file as a python string 27 | """ 28 | 29 | filepath = os.path.join(legacy_data_dir, data_filename) 30 | 31 | if os.path.isfile(filepath): 32 | return filepath 33 | raise FileExistsError("Cannot find the file:", filepath) 34 | 35 | 36 | def _load(filename: str) -> np.ndarray: 37 | """Load an image file located in the data directory. 38 | 39 | :param: filename: Path of the file to load. 40 | :return: The data loaded in a numpy array 41 | """ 42 | return torch.tensor(np.float32(imread(_fetch(filename)))).to(SSettings.instance().device) 43 | 44 | 45 | def celegans(): 46 | """2D confocal (Airyscan) image of a c. elegans intestine. 47 | 48 | :return: (310, 310) uint16 ndarray 49 | """ 50 | return _load("celegans.tif")[3:-3, 3:-3] 51 | 52 | 53 | def pollen(): 54 | """3D Pollen image. 55 | 56 | :return: (32, 256, 256) uint16 ndarray 57 | """ 58 | return _load("pollen.tif") 59 | 60 | 61 | def pollen_poison_noise_blurred(): 62 | """3D Pollen image corrupted with Poisson noise and blurred . 63 | 64 | :return: (32, 256, 256) uint16 ndarray 65 | """ 66 | return _load("pollen_poisson_noise_blurred.tif") 67 | 68 | 69 | def pollen_psf(): 70 | """3D PSF to deblur the pollen image. 
71 | 72 | :return: (32, 256, 256) uint16 ndarray 73 | """ 74 | psf = _load("pollen_psf.tif") 75 | return psf / torch.sum(psf) 76 | -------------------------------------------------------------------------------- /sdeconv/data/celegans.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/celegans.tif -------------------------------------------------------------------------------- /sdeconv/data/pollen.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen.tif -------------------------------------------------------------------------------- /sdeconv/data/pollen_poisson_noise_blurred.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen_poisson_noise_blurred.tif -------------------------------------------------------------------------------- /sdeconv/data/pollen_psf.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/sdeconv/data/pollen_psf.tif -------------------------------------------------------------------------------- /sdeconv/deconv/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that implements the image deconvolution algorithms""" 2 | from .interface import SDeconvFilter 3 | from .wiener import SWiener, swiener 4 | from .richardson_lucy import SRichardsonLucy, srichardsonlucy 5 | from .spitfire import Spitfire, spitfire 6 | from .noise2void import Noise2VoidDeconv 7 | from .self_supervised_nn import SelfSupervisedNNDeconv 8 | from 
.nn_deconv import NNDeconv 9 | 10 | __all__ = ['SDeconvFilter', 11 | 'SWiener', 12 | 'swiener', 13 | 'SRichardsonLucy', 14 | 'srichardsonlucy', 15 | 'Spitfire', 16 | 'spitfire', 17 | 'Noise2VoidDeconv', 18 | 'SelfSupervisedNNDeconv', 19 | 'NNDeconv' 20 | ] 21 | -------------------------------------------------------------------------------- /sdeconv/deconv/_datasets.py: -------------------------------------------------------------------------------- 1 | """This module implements the datasets for deep learning training""" 2 | from typing import Callable 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from skimage.io import imread 8 | 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class SelfSupervisedPatchDataset(Dataset): 13 | """Gray scaled image patched dataset for Self supervised learning 14 | 15 | :param images_dir: Directory containing the training images 16 | :param patch_size: Size of the squared training patches 17 | :param stride: Stride used to extract overlapping patches from images 18 | :param transform: Transformation to images before model 19 | """ 20 | def __init__(self, 21 | images_dir: Path, 22 | patch_size: int = 40, 23 | stride: int = 10, 24 | transform: Callable = None): 25 | super().__init__() 26 | self.images_dir = Path(images_dir) 27 | self.patch_size = patch_size 28 | self.stride = stride 29 | self.transform = transform 30 | 31 | source_images = sorted(self.images_dir.glob('*.*')) 32 | 33 | self.nb_images = len(source_images) 34 | image = imread(source_images[0]) 35 | self.n_patches = self.nb_images * ((image.shape[0] - patch_size) // stride) * \ 36 | ((image.shape[1] - patch_size) // stride) 37 | print('num patches = ', self.n_patches) 38 | 39 | # Load all the images in a list 40 | self.images_data = [] 41 | for source in source_images: 42 | self.images_data.append(np.float32(imread(source))) 43 | 44 | def __len__(self): 45 | return self.n_patches 46 | 47 | def __getitem__(self, idx): 48 | 
nb_patch_per_img = self.n_patches // self.nb_images 49 | 50 | img_number = idx // nb_patch_per_img 51 | 52 | img_np = self.images_data[img_number] 53 | 54 | nb_patch_w = (img_np.shape[1] - self.patch_size) // self.stride 55 | idx = idx % nb_patch_per_img 56 | i, j = idx // nb_patch_w, idx % nb_patch_w 57 | img_patch = \ 58 | img_np[i * self.stride:i * self.stride + self.patch_size, 59 | j * self.stride:j * self.stride + self.patch_size] 60 | 61 | if self.transform: 62 | img_patch = self.transform(torch.Tensor(img_patch)) 63 | else: 64 | img_patch = torch.Tensor(img_patch).float() 65 | 66 | return ( 67 | img_patch.view(1, *img_patch.shape), 68 | str(idx) 69 | ) 70 | 71 | 72 | class SelfSupervisedDataset(Dataset): 73 | """Gray scaled image dataset for Self supervised learning 74 | 75 | :param images_dir: Directory containing the training images 76 | :param transform: Transformation to images before model 77 | """ 78 | def __init__(self, 79 | images_dir: Path, 80 | transform: Callable = None): 81 | super().__init__() 82 | self.images_dir = Path(images_dir) 83 | self.transform = transform 84 | 85 | self.source_images = sorted(self.images_dir.glob('*.*')) 86 | 87 | self.nb_images = len(self.source_images) 88 | 89 | # Load all the images in a list 90 | self.images_data = [] 91 | for source in self.source_images: 92 | self.images_data.append(np.float32(imread(source))) 93 | 94 | def __len__(self): 95 | return self.nb_images 96 | 97 | def __getitem__(self, idx): 98 | 99 | img_patch = self.images_data[idx] 100 | if self.transform: 101 | img_patch = self.transform(torch.Tensor(img_patch)) 102 | else: 103 | img_patch = torch.Tensor(img_patch).float() 104 | 105 | return ( 106 | img_patch.view(1, *img_patch.shape), 107 | self.source_images[idx].stem 108 | ) 109 | 110 | 111 | class RestorationDataset(Dataset): 112 | """Dataset for image restoration from full images 113 | 114 | All the training images must be saved as individual images in source and 115 | target folders. 
116 | 117 | :param source_dir: Path of the noisy training images (or patches) 118 | :param target_dir: Path of the ground truth images (or patches) 119 | :param transform: Transformation to apply to the image before model call 120 | """ 121 | def __init__(self, 122 | source_dir: str | Path, 123 | target_dir: str | Path, 124 | transform: Callable = None): 125 | super().__init__() 126 | self.device = None 127 | self.source_dir = Path(source_dir) 128 | self.target_dir = Path(target_dir) 129 | self.transform = transform 130 | 131 | self.source_images = sorted(self.source_dir.glob('*.*')) 132 | self.target_images = sorted(self.target_dir.glob('*.*')) 133 | if len(self.source_images) != len(self.target_images): 134 | raise ValueError("Source and target dirs are not the same length") 135 | 136 | self.nb_images = len(self.source_images) 137 | 138 | def __len__(self): 139 | return self.nb_images 140 | 141 | def __getitem__(self, idx): 142 | source_patch = np.float32(imread(self.source_images[idx])) 143 | target_patch = np.float32(imread(self.target_images[idx])) 144 | 145 | # numpy to tensor 146 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float() 147 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float() 148 | 149 | # data augmentation 150 | if self.transform: 151 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0) 152 | transformed_images = self.transform(both_images) 153 | source_patch = transformed_images[0, ...] 154 | target_patch = transformed_images[1, ...] 155 | 156 | return source_patch, target_patch, self.source_images[idx].stem 157 | 158 | 159 | class RestorationPatchDataset(Dataset): 160 | """Dataset for image restoration using patches 161 | 162 | All the training images must be saved as individual images in source and 163 | target folders. 
164 | 165 | :param source_dir: Path of the noisy training images (or patches) 166 | :param target_dir: Path of the ground truth images (or patches) 167 | :param patch_size: Size of the patches (width=height) 168 | :param stride: Length of the patch overlapping 169 | :param transform: Transformation to apply to the image before model call 170 | 171 | """ 172 | def __init__(self, 173 | source_dir: str | Path, 174 | target_dir: str | Path, 175 | patch_size: int = 40, 176 | stride: int = 10, 177 | transform: Callable = None): 178 | super().__init__() 179 | self.device = None 180 | self.source_dir = Path(source_dir) 181 | self.target_dir = Path(target_dir) 182 | self.patch_size = patch_size 183 | self.stride = stride 184 | self.transform = transform 185 | 186 | self.source_images = sorted(self.source_dir.glob('*.*')) 187 | self.target_images = sorted(self.target_dir.glob('*.*')) 188 | if len(self.source_images) != len(self.target_images): 189 | raise ValueError("Source and target dirs are not the same length") 190 | 191 | self.nb_images = len(self.source_images) 192 | image = imread(self.source_images[0]) 193 | self.n_patches = self.nb_images * ((image.shape[0] - patch_size) // stride) * \ 194 | ((image.shape[1] - patch_size) // stride) 195 | 196 | def __len__(self): 197 | return self.n_patches 198 | 199 | def __getitem__(self, idx): 200 | # Crop a patch from original image 201 | nb_patch_per_img = self.n_patches // self.nb_images 202 | 203 | elt = self.source_images[idx // nb_patch_per_img] 204 | 205 | img_source_np = \ 206 | np.float32(imread(self.source_dir / elt)) 207 | img_target_np = \ 208 | np.float32(imread(self.target_dir / elt)) 209 | 210 | nb_patch_w = (img_source_np.shape[1] - self.patch_size) // self.stride 211 | idx = idx % nb_patch_per_img 212 | i, j = idx // nb_patch_w, idx % nb_patch_w 213 | source_patch = \ 214 | img_source_np[i * self.stride:i * self.stride + self.patch_size, 215 | j * self.stride:j * self.stride + self.patch_size] 216 | target_patch 
= \ 217 | img_target_np[i * self.stride:i * self.stride + self.patch_size, 218 | j * self.stride:j * self.stride + self.patch_size] 219 | 220 | # numpy to tensor 221 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float() 222 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float() 223 | 224 | # data augmentation 225 | if self.transform: 226 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0) 227 | transformed_images = self.transform(both_images) 228 | source_patch = transformed_images[0, ...] 229 | target_patch = transformed_images[1, ...] 230 | 231 | # to tensor 232 | return (source_patch, 233 | target_patch, 234 | str(idx) 235 | ) 236 | 237 | 238 | class RestorationPatchDatasetLoad(Dataset): 239 | """Dataset for image restoration using patches preloaded in memory 240 | 241 | All the training images must be saved as individual images in source and 242 | target folders. 243 | This version load all the dataset in the CPU 244 | 245 | :param source_dir: Path of the noisy training images (or patches) 246 | :param target_dir: Path of the ground truth images (or patches) 247 | :param patch_size: Size of the patches (width=height) 248 | :param stride: Length of the patch overlapping 249 | :param transform: Transformation to apply to the image before model call 250 | 251 | """ 252 | # pylint: disable=too-many-instance-attributes 253 | # pylint: disable=too-many-arguments 254 | def __init__(self, 255 | source_dir: str | Path, 256 | target_dir: str | Path, 257 | patch_size: int = 40, 258 | stride: int = 10, 259 | transform: Callable = None): 260 | super().__init__() 261 | self.source_dir = Path(source_dir) 262 | self.target_dir = Path(target_dir) 263 | self.patch_size = patch_size 264 | self.stride = stride 265 | self.transform = transform 266 | 267 | self.source_images = sorted(self.source_dir.glob('*.*')) 268 | self.target_images = sorted(self.target_dir.glob('*.*')) 269 | 270 | if 
len(self.source_images) != len(self.target_images): 271 | raise ValueError("Source and target dirs are not the same length") 272 | 273 | self.nb_images = len(self.source_images) 274 | image = imread(self.source_images[0]) 275 | self.n_patches = self.nb_images * ((image.shape[0] - patch_size) // stride) * \ 276 | ((image.shape[1] - patch_size) // stride) 277 | print('num patches = ', self.n_patches) 278 | 279 | # Load all the images in a list 280 | self.source_data = [] 281 | for source in self.source_images: 282 | self.source_data.append(np.float32(imread(source))) 283 | self.target_data = [] 284 | for target in self.target_images: 285 | self.target_data.append(np.float32(imread(target))) 286 | 287 | def __len__(self): 288 | return self.n_patches 289 | 290 | def __getitem__(self, idx): 291 | # Crop a patch from original image 292 | nb_patch_per_img = self.n_patches // self.nb_images 293 | 294 | img_number = idx // nb_patch_per_img 295 | 296 | img_source_np = self.source_data[img_number] 297 | img_target_np = self.target_data[img_number] 298 | 299 | nb_patch_w = (img_source_np.shape[1] - self.patch_size) // self.stride 300 | idx = idx % nb_patch_per_img 301 | i, j = idx // nb_patch_w, idx % nb_patch_w 302 | source_patch = \ 303 | img_source_np[i * self.stride:i * self.stride + self.patch_size, 304 | j * self.stride:j * self.stride + self.patch_size] 305 | target_patch = \ 306 | img_target_np[i * self.stride:i * self.stride + self.patch_size, 307 | j * self.stride:j * self.stride + self.patch_size] 308 | 309 | # numpy to tensor 310 | source_patch = torch.from_numpy(source_patch).view(1, *source_patch.shape).float() 311 | target_patch = torch.from_numpy(target_patch).view(1, *target_patch.shape).float() 312 | 313 | # data augmentation 314 | if self.transform: 315 | both_images = torch.cat((source_patch.unsqueeze(0), target_patch.unsqueeze(0)), 0) 316 | transformed_images = self.transform(both_images) 317 | source_patch = transformed_images[0, ...] 
318 | target_patch = transformed_images[1, ...] 319 | 320 | return (source_patch, 321 | target_patch, 322 | str(idx) 323 | ) 324 | -------------------------------------------------------------------------------- /sdeconv/deconv/_transforms.py: -------------------------------------------------------------------------------- 1 | """Data transformation for image restoration workflow""" 2 | import torch 3 | from torchvision.transforms import v2 4 | 5 | 6 | class FlipAugmentation: 7 | """Data augmentation by flipping images""" 8 | def __init__(self): 9 | self.__transform = v2.Compose([ 10 | v2.RandomHorizontalFlip(p=0.5), 11 | v2.RandomVerticalFlip(p=0.5), 12 | v2.ToDtype(torch.float32, scale=False) 13 | ]) 14 | 15 | def __call__(self, image): 16 | return self.__transform(image) 17 | 18 | 19 | class VisionScale: 20 | """Scale images in [-1, 1]""" 21 | 22 | def __init__(self): 23 | self.__transform = v2.Compose([ 24 | v2.ToDtype(torch.float32, scale=True) 25 | ]) 26 | 27 | def __call__(self, image): 28 | return self.__transform(image) 29 | -------------------------------------------------------------------------------- /sdeconv/deconv/_unet_2d.py: -------------------------------------------------------------------------------- 1 | """Module to implement a UNet in 2D wwith pytorch module""" 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class UNetConvBlock(nn.Module): 7 | """Convolution block for UNet architecture 8 | 9 | This block is 2 convolution layers with a ReLU. 
10 | An optional batch norm can be added after each convolution layer 11 | 12 | :param n_channels_in: Number of input channels (or features) 13 | :param n_channels_out: Number of output channels (or features) 14 | :param use_batch_norm: True to use the batch norm layers 15 | """ 16 | def __init__(self, 17 | n_channels_in: int, 18 | n_channels_out: int, 19 | use_batch_norm: bool = True): 20 | super().__init__() 21 | 22 | self.use_batch_norm = use_batch_norm 23 | self.conv1 = nn.Conv2d(n_channels_in, n_channels_out, 24 | kernel_size=3, padding=1) 25 | self.bn1 = nn.BatchNorm2d(n_channels_out) 26 | 27 | self.conv2 = nn.Conv2d(n_channels_out, n_channels_out, 28 | kernel_size=3, padding=1) 29 | self.bn2 = nn.BatchNorm2d(n_channels_out) 30 | 31 | self.relu = nn.ReLU() 32 | 33 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 34 | """Apply the model 35 | 36 | :param inputs: Data to process 37 | :return: The processed data 38 | """ 39 | x = self.conv1(inputs) 40 | if self.use_batch_norm: 41 | x = self.bn1(x) 42 | x = self.relu(x) 43 | 44 | x = self.conv2(x) 45 | if self.use_batch_norm: 46 | x = self.bn2(x) 47 | x = self.relu(x) 48 | 49 | return x 50 | 51 | 52 | class UNetEncoderBlock(nn.Module): 53 | """Encoder block of the UNet architecture 54 | 55 | The encoder block is a convolution block and a max polling layer 56 | 57 | :param n_channels_in: Number of input channels (or features) 58 | :param n_channels_out: Number of output channels (or features) 59 | :param use_batch_norm: True to use the batch norm layers 60 | """ 61 | def __init__(self, 62 | n_channels_in: int, 63 | n_channels_out: int, 64 | use_batch_norm: bool = True): 65 | super().__init__() 66 | 67 | self.conv = UNetConvBlock(n_channels_in, n_channels_out, 68 | use_batch_norm) 69 | self.pool = nn.MaxPool2d((2, 2)) 70 | 71 | def forward(self, inputs: torch.Tensor): 72 | """torch module forward method 73 | 74 | :param inputs: tensor to process 75 | """ 76 | x = self.conv(inputs) 77 | p = self.pool(x) 78 
| 79 | return x, p 80 | 81 | 82 | class UNetDecoderBlock(nn.Module): 83 | """Decoder block of a UNet architecture 84 | 85 | The decoder is an up-sampling concatenation and convolution block 86 | 87 | :param n_channels_in: Number of input channels (or features) 88 | :param n_channels_out: Number of output channels (or features) 89 | :param use_batch_norm: True to use the batch norm layers 90 | """ 91 | def __init__(self, 92 | n_channels_in: int, 93 | n_channels_out: int, 94 | use_batch_norm: bool = True): 95 | super().__init__() 96 | 97 | self.up = nn.Upsample(scale_factor=(2, 2), mode='nearest') 98 | self.conv = UNetConvBlock(n_channels_in+n_channels_out, 99 | n_channels_out, use_batch_norm) 100 | 101 | def forward(self, inputs: torch.Tensor, skip: torch.Tensor): 102 | """Module torch forward 103 | 104 | :param inputs: input tensor 105 | :param skip: skip connection tensor 106 | """ 107 | x = self.up(inputs) 108 | x = torch.cat([x, skip], dim=1) 109 | x = self.conv(x) 110 | 111 | return x 112 | 113 | 114 | class UNet2D(nn.Module): 115 | """Implementation of a UNet network 116 | 117 | :param n_channels_in: Number of input channels (or features), 118 | :param n_channels_out: Number of output channels (or features), 119 | :param n_feature_first: Number of channels (or features) in the first convolution block, 120 | :param use_batch_norm: True to use the batch norm layers 121 | """ 122 | def __init__(self, 123 | n_channels_in: int = 1, 124 | n_channels_out: int = 1, 125 | n_channels_layers: list[int] = (32, 64, 128), 126 | use_batch_norm: bool = False): 127 | super().__init__() 128 | 129 | # Encoder 130 | self.encoder = nn.ModuleList() 131 | for idx, n_channels in enumerate(n_channels_layers[:-1]): 132 | n_in = n_channels_in if idx == 0 else n_channels_layers[idx-1] 133 | n_out = n_channels 134 | self.encoder.append(UNetEncoderBlock(n_in, n_out, use_batch_norm)) 135 | 136 | # Bottleneck 137 | self.bottleneck = UNetConvBlock(n_channels_layers[-2], 138 | 
n_channels_layers[-1], 139 | use_batch_norm) 140 | 141 | # Decoder 142 | self.decoder = nn.ModuleList() 143 | for idx in reversed(range(len(n_channels_layers))): 144 | if idx > 0: 145 | n_in = n_channels_layers[idx] 146 | n_out = n_channels_layers[idx-1] 147 | self.decoder.append(UNetDecoderBlock(n_in, n_out, use_batch_norm)) 148 | 149 | self.outputs = nn.Conv2d(n_channels_layers[0], n_channels_out, 150 | kernel_size=1, padding=0) 151 | 152 | self.num_layers = len(n_channels_layers)-1 153 | 154 | 155 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 156 | """Module torch forward 157 | 158 | :param inputs: input tensor 159 | :return: the tensor processed by the network 160 | """ 161 | # Encoder 162 | skips = [] 163 | p = [None] * (len(self.encoder)+1) 164 | p[0] = inputs 165 | for idx, layer in enumerate(self.encoder): 166 | s, p[idx+1] = layer(p[idx]) 167 | skips.append(s) 168 | 169 | # Bottleneck 170 | d = [None] * (len(self.decoder)+1) 171 | d[0] = self.bottleneck(p[-1]) 172 | 173 | # decoder 174 | for idx, layer in enumerate(self.decoder): 175 | d[idx+1] = layer(d[idx], skips[self.num_layers-idx-1]) 176 | 177 | # Classifier 178 | return self.outputs(d[-1]) 179 | -------------------------------------------------------------------------------- /sdeconv/deconv/_utils.py: -------------------------------------------------------------------------------- 1 | """Implementation of shared methods for all multiple deconvolution algorithms""" 2 | import numpy as np 3 | import torch 4 | from sdeconv.core import SSettings 5 | 6 | 7 | def np_to_torch(image: np.ndarray | torch.Tensor) -> torch.Tensor: 8 | """Convert a numpy array into a torch tensor 9 | 10 | The array is then loaded on the GPU if available 11 | 12 | :param image: Image to convert, 13 | :return: The converted image 14 | """ 15 | if isinstance(image, np.ndarray): 16 | image_ = torch.tensor(image).to(SSettings.instance().device) 17 | else: 18 | image_ = image 19 | return image_ 20 | 21 | 22 | def 
resize_psf_2d(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor: 23 | """Resize a 2D PSF image to the target image size 24 | 25 | :param image: Reference image tensor, 26 | :param psf: Point Spread Function tensor to resize, 27 | :return: the psf tensor padded to get the same shape as image 28 | """ 29 | kernel = torch.zeros(image.shape).to(SSettings.instance().device) 30 | x_start = int(image.shape[0] / 2 - psf.shape[0] / 2) + 1 31 | y_start = int(image.shape[1] / 2 - psf.shape[1] / 2) + 1 32 | kernel[x_start:x_start + psf.shape[0], y_start:y_start + psf.shape[1]] = psf 33 | return kernel 34 | 35 | 36 | def resize_psf_3d(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor: 37 | """Resize a 3D PSF image to the target image size 38 | 39 | :param image: Reference image tensor 40 | :param psf: Point Spread Function tensor to resize 41 | :returns: the psf tensor padded to get the same shape as image 42 | """ 43 | kernel = torch.zeros(image.shape).to(SSettings.instance().device) 44 | x_start = int(image.shape[0] / 2 - psf.shape[0] / 2) + 1 45 | y_start = int(image.shape[1] / 2 - psf.shape[1] / 2) + 1 46 | z_start = int(image.shape[2] / 2 - psf.shape[2] / 2) + 1 47 | kernel[x_start:x_start + psf.shape[0], y_start:y_start + psf.shape[1], 48 | z_start:z_start + psf.shape[2]] = psf 49 | return kernel 50 | 51 | 52 | def pad_2d(image: torch.Tensor, 53 | psf: torch.Tensor, 54 | pad: int | tuple[int, int] 55 | ) -> tuple[torch.Tensor, torch.Tensor, int | tuple[int, int]]: 56 | """Pad an image and it PSF for deconvolution 57 | 58 | :param image: 2D image tensor 59 | :param psf: 2D Point Spread Function. 60 | :param pad: Padding in each dimension. 
61 | :return image: psf, padding: padded versions of the image and the PSF, plus the padding tuple 62 | """ 63 | padding = pad 64 | if isinstance(pad, tuple) and len(pad) != image.ndim: 65 | raise ValueError("Padding must be the same dimension as image") 66 | if isinstance(pad, int): 67 | if pad == 0: 68 | return image, psf, (0, 0) 69 | padding = (pad, pad) 70 | 71 | if padding[0] > 0 and padding[1] > 0: 72 | 73 | pad_fn = torch.nn.ReflectionPad2d((padding[0], padding[0], padding[1], padding[1])) 74 | image_pad = pad_fn(image.detach().clone().to( 75 | SSettings.instance().device).view(1, 1, image.shape[0], image.shape[1])).view( 76 | (image.shape[0] + 2 * padding[0], image.shape[1] + 2 * padding[0])) 77 | else: 78 | image_pad = image.detach().clone().to(SSettings.instance().device) 79 | psf_pad = resize_psf_2d(image_pad, psf) 80 | return image_pad, psf_pad, padding 81 | 82 | 83 | def pad_3d(image: torch.Tensor, 84 | psf: torch.Tensor, 85 | pad: int | tuple[int, int, int] 86 | ) -> tuple[torch.Tensor, torch.Tensor, int | tuple[int, int, int]]: 87 | """Pad an image and it PSF for deconvolution 88 | 89 | :param image: 2D image tensor 90 | :param psf: 2D Point Spread Function. 91 | :param pad: Padding in each dimension. 
92 | :return: image, psf, padding: padded versions of the image and the PSF, plus the padding tuple 93 | """ 94 | padding = pad 95 | if isinstance(pad, tuple) and len(pad) != image.ndim: 96 | raise ValueError("Padding must be the same dimension as image") 97 | if isinstance(pad, int): 98 | if pad == 0: 99 | return image, psf, (0, 0, 0) 100 | padding = (pad, pad, pad) 101 | 102 | if padding[0] > 0 and padding[1] > 0 and padding[2] > 0: 103 | p3d = (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) 104 | pad_fn = torch.nn.ReflectionPad3d(p3d) 105 | image_pad = pad_fn( 106 | image.detach().clone().to(SSettings.instance().device).view(1, 1, image.shape[0], 107 | image.shape[1], 108 | image.shape[2])).view( 109 | (image.shape[0] + 2 * padding[0], image.shape[1] + 2 * padding[1], 110 | image.shape[2] + 2 * padding[2])) 111 | psf_pad = torch.nn.functional.pad(psf, p3d, "constant", 0) 112 | else: 113 | image_pad = image 114 | psf_pad = psf 115 | return image_pad, psf_pad, padding 116 | 117 | 118 | def unpad_3d(image: torch.Tensor, padding: tuple[int, int, int]) -> torch.Tensor: 119 | """Remove the padding of an image 120 | 121 | :param image: 3D image to un-pad 122 | :param padding: Padding in each dimension 123 | :return: a torch.Tensor of the un-padded image 124 | """ 125 | return image[padding[0]:-padding[0], 126 | padding[1]:-padding[1], 127 | padding[2]:-padding[2]] 128 | 129 | 130 | # define the PSF parameter 131 | psf_parameter = { 132 | 'type': 'torch.Tensor', 133 | 'label': 'psf', 134 | 'help': 'Point Spread Function', 135 | 'default': None 136 | } 137 | -------------------------------------------------------------------------------- /sdeconv/deconv/interface.py: -------------------------------------------------------------------------------- 1 | """Interface for a deconvolution filter""" 2 | import torch 3 | from sdeconv.core import SObservable 4 | 5 | 6 | class SDeconvFilter(SObservable): 7 | """Interface for a deconvolution filter 8 | 9 | 
All the algorithm settings must be set in the `__init__` method (PSF included) and the 10 | `__call__` method is used to actually do the calculation 11 | """ 12 | def __init__(self): 13 | super().__init__() 14 | self.type = 'SDeconvFilter' 15 | 16 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 17 | """Do the deconvolution 18 | 19 | :param image: Blurry image for a single channel time point [(Z) Y X] 20 | :return: deblurred image [(Z) Y X] 21 | """ 22 | raise NotImplementedError('SDeconvFilter is an interface. Please implement the' 23 | ' __call__ method') 24 | -------------------------------------------------------------------------------- /sdeconv/deconv/interface_nn.py: -------------------------------------------------------------------------------- 1 | """This module implements interface for deconvolution based on neural network""" 2 | from abc import abstractmethod 3 | from pathlib import Path 4 | 5 | import torch 6 | from skimage.io import imsave 7 | 8 | from ..core import seconds2str 9 | from ..core import SConsoleLogger 10 | 11 | from ._unet_2d import UNet2D 12 | from ._transforms import VisionScale 13 | 14 | 15 | class NNModule(torch.nn.Module): 16 | """Deconvolution using the noise to void algorithm""" 17 | def __init__(self): 18 | super().__init__() 19 | 20 | self._model_args = None 21 | self._model = None 22 | self._loss_fn = None 23 | self._optimizer = None 24 | self._save_all = True 25 | self._device = None 26 | self._out_dir = None 27 | self._val_data_loader = None 28 | self._train_data_loader = None 29 | self._progress = SConsoleLogger() 30 | self._current_epoch = None 31 | self._current_loss = None 32 | 33 | @abstractmethod 34 | def fit(self, 35 | train_directory: Path, 36 | val_directory: Path, 37 | n_channel_in: int = 1, 38 | n_channels_layer: list[int] = (32, 64, 128), 39 | patch_size: int = 32, 40 | n_epoch: int = 25, 41 | learning_rate: float = 1e-3, 42 | out_dir: Path = None 43 | ): 44 | """Train a model on a dataset 45 | 46 | 
:param train_directory: Directory containing the images used for 47 | training. One file per image, 48 | :param val_directory: Directory containing the images used for validation of 49 | the training. One file per image, 50 | :param n_channel_in: Number of channels in the input images 51 | :param n_channels_layer: Number of channels for each hidden layers of the model, 52 | :param patch_size: Size of square patches used for training the model, 53 | :param n_epoch: Number of epochs, 54 | :param learning_rate: Adam optimizer learning rate 55 | """ 56 | raise NotImplementedError('NNModule is abstract') 57 | 58 | def _train_loop(self, n_epoch: int): 59 | """Run the main train loop (should be called in fit) 60 | 61 | :param n_epoch: Number of epochs to run 62 | """ 63 | for epoch in range(n_epoch): 64 | self._current_epoch = epoch 65 | train_data = self._train_step() 66 | self._after_train_step(train_data) 67 | self._after_train() 68 | 69 | @abstractmethod 70 | def _train_step(self): 71 | """Runs one step of training""" 72 | 73 | def _after_train_batch(self, data: dict[str, any]): 74 | """Instructions runs after one batch 75 | 76 | :param data: Dictionary of metadata to log or process 77 | """ 78 | prefix = f"Epoch = {self._current_epoch+1:d}" 79 | loss_str = f"{data['loss']:.7f}" 80 | full_time_str = seconds2str(int(data['full_time'])) 81 | remains_str = seconds2str(int(data['remain_time'])) 82 | suffix = str(data['id_batch']) + '/' + str(data['total_batch']) + \ 83 | ' [' + full_time_str + '<' + remains_str + ', loss=' + \ 84 | loss_str + '] ' 85 | self._progress.progress(data['id_batch'], 86 | data['total_batch'], 87 | prefix=prefix, 88 | suffix=suffix) 89 | 90 | def _after_train_step(self, data: dict): 91 | """Instructions runs after one train step. 
92 | 93 | This method can be used to log data or print console messages 94 | 95 | :param data: Dictionary of metadata to log or process 96 | """ 97 | if self._save_all: 98 | self.save(Path(self._out_dir, f'model_{self._current_epoch}.ml')) 99 | 100 | def _after_train(self): 101 | """Instructions runs after the train.""" 102 | # create the output dir 103 | predictions_dir = self._out_dir / 'predictions' 104 | predictions_dir.mkdir(parents=True, exist_ok=True) 105 | 106 | # predict on all the test set 107 | self._model.eval() 108 | for x, names in self._val_data_loader: 109 | x = x.to(self.device()) 110 | 111 | with torch.no_grad(): 112 | prediction = self._model(x) 113 | for i, name in enumerate(names): 114 | imsave(predictions_dir / f'{name}.tif', 115 | prediction[i, ...].cpu().numpy()) 116 | 117 | def _init_model(self, 118 | n_channel_in: int = 1, 119 | n_channel_out: int = 1, 120 | n_channels_layer: list[int] = (32, 64, 128) 121 | ): 122 | """Initialize the model 123 | 124 | :param n_channel_in: Number of channels for the input image 125 | :param n_channel_in: Number of channels for the output image 126 | :param n_channels_layer: Number of channels for each layers of the UNet 127 | """ 128 | self._model_args = { 129 | "n_channel_in": n_channel_in, 130 | "n_channel_out": n_channel_out, 131 | "n_channels_layer": n_channels_layer 132 | } 133 | self._model = UNet2D(n_channel_in, n_channel_out, n_channels_layer, True) 134 | self._model.to(self.device()) 135 | 136 | def device(self) -> str: 137 | """Get the GPU if exists 138 | 139 | :return: The device name (cuda or CPU) 140 | """ 141 | if self._device is None: 142 | return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 143 | return self._device 144 | 145 | def load(self, filename: Path): 146 | """Load pre-trained model from file 147 | 148 | :param: Path of the model file 149 | """ 150 | params = torch.load(filename, map_location=torch.device(self.device())) 151 | 
self._init_model(**params["model_args"]) 152 | self._model.load_state_dict(params['model_state_dict']) 153 | self._model.to(self.device()) 154 | 155 | def save(self, filename: Path): 156 | """Save the model into file 157 | 158 | :param: Path of the model file 159 | """ 160 | torch.save({ 161 | 'model_args': self._model_args, 162 | 'model_state_dict': self._model.state_dict(), 163 | }, filename) 164 | 165 | 166 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 167 | """Apply the model on a image or batch of image 168 | 169 | :param image: Blurry image for a single channel or batch [(B) Y X] 170 | :return: deblurred image [(B) Y X] 171 | """ 172 | if image.ndim == 2: 173 | image = image.view(1, 1, *image.shape) 174 | elif image.ndim > 2: 175 | raise ValueError("The current implementation of neural network " 176 | "deconvolution works only with 2D images") 177 | 178 | self._model.eval() 179 | scaler = VisionScale() 180 | with torch.no_grad(): 181 | return self._model(scaler(image)) 182 | -------------------------------------------------------------------------------- /sdeconv/deconv/nn_deconv.py: -------------------------------------------------------------------------------- 1 | """This module implements self supervised deconvolution with Spitfire regularisation""" 2 | from pathlib import Path 3 | from timeit import default_timer as timer 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from skimage.io import imsave 8 | 9 | from .interface_nn import NNModule 10 | from ._datasets import RestorationDataset 11 | from ._datasets import RestorationPatchDatasetLoad 12 | from ._datasets import RestorationPatchDataset 13 | from ._transforms import FlipAugmentation, VisionScale 14 | 15 | 16 | class NNDeconv(NNModule): 17 | """Deconvolution using a neural network trained using ground truth""" 18 | def fit(self, 19 | train_directory: Path, 20 | val_directory: Path, 21 | n_channel_in: int = 1, 22 | n_channels_layer: list[int] = (32, 64, 128), 23 | 
patch_size: int = 32, 24 | n_epoch: int = 25, 25 | learning_rate: float = 1e-3, 26 | out_dir: Path = None, 27 | preload: bool = True 28 | ): 29 | """Train a model on a dataset 30 | 31 | :param train_directory: Directory containing the images used for 32 | training. One file per image, 33 | :param val_directory: Directory containing the images used for validation of the 34 | training. One file per image, 35 | :param n_channel_in: Number of channels in the input images 36 | :param n_channels_layer: Number of channels for each hidden layers of the model, 37 | :param patch_size: Size of square patches used for training the model, 38 | :param n_epoch: Number of epochs, 39 | :param learning_rate: Adam optimizer learning rate 40 | """ 41 | self._init_model(n_channel_in, n_channel_in, n_channels_layer) 42 | self._out_dir = out_dir 43 | self._loss_fn = torch.nn.MSELoss() 44 | self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate) 45 | if preload: 46 | train_dataset = RestorationPatchDatasetLoad(train_directory / "source", 47 | train_directory / "target", 48 | patch_size=patch_size, 49 | stride=int(patch_size/2), 50 | transform=FlipAugmentation()) 51 | else: 52 | train_dataset = RestorationPatchDataset(train_directory / "source", 53 | train_directory / "target", 54 | patch_size=patch_size, 55 | stride=int(patch_size/2), 56 | transform=FlipAugmentation()) 57 | val_dataset = RestorationDataset(val_directory / "source", 58 | val_directory / "target", 59 | transform=VisionScale()) 60 | self._train_data_loader = DataLoader(train_dataset, 61 | batch_size=300, 62 | shuffle=True, 63 | drop_last=True, 64 | num_workers=0) 65 | self._val_data_loader = DataLoader(val_dataset, 66 | batch_size=1, 67 | shuffle=False, 68 | drop_last=False, 69 | num_workers=0) 70 | self._train_loop(n_epoch) 71 | 72 | def _train_step(self): 73 | """Runs one step of training""" 74 | size = len(self._train_data_loader.dataset) 75 | self._model.train() 76 | step_loss = 0 77 | count_step = 
0 78 | tic = timer() 79 | for batch, (x, y, _) in enumerate(self._train_data_loader): 80 | count_step += 1 81 | 82 | x = x.to(self.device()) 83 | 84 | # Compute prediction error 85 | prediction = self._model(x) 86 | loss = self._loss_fn(prediction, y) 87 | step_loss += loss 88 | 89 | # Backpropagation 90 | self._optimizer.zero_grad() 91 | loss.backward() 92 | self._optimizer.step() 93 | 94 | # count time 95 | toc = timer() 96 | full_time = toc - tic 97 | total_batch = int(size / len(x)) 98 | remains = full_time * (total_batch - (batch+1)) / (batch+1) 99 | 100 | self._after_train_batch({'loss': loss, 101 | 'id_batch': batch+1, 102 | 'total_batch': total_batch, 103 | 'remain_time': int(remains+0.5), 104 | 'full_time': int(full_time+0.5) 105 | }) 106 | if count_step > 0: 107 | step_loss /= count_step 108 | self._current_loss = step_loss 109 | return {'train_loss': step_loss} 110 | 111 | def _after_train(self): 112 | """Instructions runs after the train.""" 113 | # create the output dir 114 | predictions_dir = self._out_dir / 'predictions' 115 | predictions_dir.mkdir(parents=True, exist_ok=True) 116 | 117 | # predict on all the test set 118 | self._model.eval() 119 | for x, _, names in self._val_data_loader: 120 | x = x.to(self.device()) 121 | 122 | with torch.no_grad(): 123 | prediction = self._model(x) 124 | for i, name in enumerate(names): 125 | imsave(predictions_dir / f'{name}.tif', 126 | prediction[i, ...].cpu().numpy()) 127 | -------------------------------------------------------------------------------- /sdeconv/deconv/noise2void.py: -------------------------------------------------------------------------------- 1 | """Implementation of deconvolution using noise2void algorithm""" 2 | from pathlib import Path 3 | from timeit import default_timer as timer 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from .interface_nn import NNModule 10 | from ._datasets import SelfSupervisedPatchDataset 11 | from ._datasets 
import SelfSupervisedDataset 12 | from ._transforms import FlipAugmentation, VisionScale 13 | 14 | 15 | def generate_2d_points(shape: tuple[int, int], n_point: int) -> tuple[np.ndarray, np.ndarray]: 16 | """Generate random 2D coordinates to mask 17 | 18 | :param shape: Shape of the image to mask 19 | :param n_point: Number of coordinates to mask 20 | :return: (y, x) coordinates to mask 21 | """ 22 | idy_msk = np.random.randint(0, int(shape[0]/2), n_point) 23 | idx_msk = np.random.randint(0, int(shape[1]/2), n_point) 24 | 25 | idy_msk = 2*idy_msk 26 | idx_msk = 2*idx_msk 27 | if np.random.randint(2) == 1: 28 | idy_msk += 1 29 | if np.random.randint(2) == 1: 30 | idx_msk += 1 31 | 32 | return idy_msk, idx_msk 33 | 34 | 35 | def generate_mask_n2v(image: torch.Tensor, ratio: float) -> tuple[torch.Tensor, torch.Tensor]: 36 | """Generate a blind spots mask fot the patch image by randomly switch pixels values 37 | 38 | :param image: Image patch to add blind spots 39 | :param ratio: Ratio of blind spots for input patch masking 40 | :return: the transformed image and the mask image 41 | """ 42 | img_shape = image.shape 43 | size_window = (5, 5) 44 | num_sample = int(img_shape[-2] * img_shape[-1] * ratio) 45 | 46 | mask = torch.zeros((img_shape[-2], img_shape[-1]), dtype=torch.float32) 47 | output = image.clone() 48 | 49 | idy_msk, idx_msk = generate_2d_points((img_shape[-2], img_shape[-1]), num_sample) 50 | num_sample = len(idy_msk) 51 | 52 | idy_neigh = np.random.randint(-size_window[0] // 2 + size_window[0] % 2, 53 | size_window[0] // 2 + size_window[0] % 2, 54 | num_sample) 55 | idx_neigh = np.random.randint(-size_window[1] // 2 + size_window[1] % 2, 56 | size_window[1] // 2 + size_window[1] % 2, 57 | num_sample) 58 | 59 | idy_msk_neigh = idy_msk + idy_neigh 60 | idx_msk_neigh = idx_msk + idx_neigh 61 | 62 | idy_msk_neigh = (idy_msk_neigh + (idy_msk_neigh < 0) * size_window[0] - 63 | (idy_msk_neigh >= img_shape[-2]) * size_window[0]) 64 | idx_msk_neigh = (idx_msk_neigh + 
(idx_msk_neigh < 0) * size_window[1] - 65 | (idx_msk_neigh >= img_shape[-1]) * size_window[1]) 66 | 67 | id_msk = (idy_msk, idx_msk) 68 | #id_msk_neigh = (idy_msk_neigh, idx_msk_neigh) 69 | 70 | output[:, :, idy_msk, idx_msk] = image[:, :, idy_msk_neigh, idx_msk_neigh] 71 | mask[id_msk] = 1.0 72 | 73 | return output, mask 74 | 75 | 76 | class N2VDeconLoss(torch.nn.Module): 77 | """MSE Loss with mask for Noise2Void deconvolution 78 | 79 | :param psf_file: File image containing the Point Spread Function 80 | :return: Loss tensor 81 | """ 82 | def __init__(self, 83 | psf_image: torch.Tensor): 84 | super().__init__() 85 | 86 | self.__psf = psf_image 87 | if self.__psf.ndim > 2: 88 | raise ValueError('N2VDeconLoss PSF must be a gray scaled 2D image') 89 | 90 | self.__psf = self.__psf.view((1, 1, *self.__psf.shape)) 91 | print('psf shape=', self.__psf.shape) 92 | self.__conv_op = torch.nn.Conv2d(1, 1, 93 | kernel_size=self.__psf.shape[2], 94 | stride=1, 95 | padding=int((self.__psf.shape[2] - 1) / 2), 96 | bias=False) 97 | with torch.no_grad(): 98 | self.__conv_op.weight = torch.nn.Parameter(self.__psf) 99 | 100 | def forward(self, predict: torch.Tensor, target: torch.Tensor, mask: torch.Tensor): 101 | """Calculate forward loss 102 | 103 | :param predict: tensor predicted by the model 104 | :param target: Reference target tensor 105 | :param mask: Mask to select pixels of interest 106 | """ 107 | conv_img = self.__conv_op(predict) 108 | 109 | num = torch.sum((conv_img*mask - target*mask)**2) 110 | den = torch.sum(mask) 111 | return num/den 112 | 113 | 114 | class Noise2VoidDeconv(NNModule): 115 | """Deconvolution using the noise to void algorithm""" 116 | def fit(self, 117 | train_directory: Path, 118 | val_directory: Path, 119 | n_channel_in: int = 1, 120 | n_channels_layer: list[int] = (32, 64, 128), 121 | patch_size: int = 32, 122 | n_epoch: int = 25, 123 | learning_rate: float = 1e-3, 124 | out_dir: Path = None, 125 | psf: torch.Tensor = None 126 | ): 127 | """Train 
a model on a dataset 128 | 129 | :param train_directory: Directory containing the images used for 130 | training. One file per image, 131 | :param val_directory: Directory containing the images used for validation of 132 | the training. One file per image, 133 | :param psf: Point spread function for deconvolution, 134 | :param n_channel_in: Number of channels in the input images 135 | :param n_channels_layer: Number of channels for each hidden layers of the model, 136 | :param patch_size: Size of square patches used for training the model, 137 | :param n_epoch: Number of epochs, 138 | :param learning_rate: Adam optimizer learning rate 139 | """ 140 | self._init_model(n_channel_in, n_channel_in, n_channels_layer) 141 | self._out_dir = out_dir 142 | self._loss_fn = N2VDeconLoss(psf.to(self.device())) 143 | self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate) 144 | train_dataset = SelfSupervisedPatchDataset(train_directory, 145 | patch_size=patch_size, 146 | stride=int(patch_size/2), 147 | transform=FlipAugmentation()) 148 | val_dataset = SelfSupervisedDataset(val_directory, transform=VisionScale()) 149 | self._train_data_loader = DataLoader(train_dataset, 150 | batch_size=300, 151 | shuffle=True, 152 | drop_last=True, 153 | num_workers=0) 154 | self._val_data_loader = DataLoader(val_dataset, 155 | batch_size=1, 156 | shuffle=False, 157 | drop_last=False, 158 | num_workers=0) 159 | 160 | self._train_loop(n_epoch) 161 | 162 | def _train_step(self): 163 | """Runs one step of training""" 164 | size = len(self._train_data_loader.dataset) 165 | self._model.train() 166 | step_loss = 0 167 | count_step = 0 168 | tic = timer() 169 | for batch, (x, _) in enumerate(self._train_data_loader): 170 | count_step += 1 171 | 172 | masked_x, mask = generate_mask_n2v(x, 0.1) 173 | x, masked_x, mask = (x.to(self.device()), 174 | masked_x.to(self.device()), 175 | mask.to(self.device())) 176 | 177 | # Compute prediction error 178 | prediction = 
self._model(masked_x) 179 | loss = self._loss_fn(prediction, x, mask) 180 | step_loss += loss 181 | 182 | # Backpropagation 183 | self._optimizer.zero_grad() 184 | loss.backward() 185 | self._optimizer.step() 186 | 187 | # count time 188 | toc = timer() 189 | full_time = toc - tic 190 | total_batch = int(size / len(x)) 191 | remains = full_time * (total_batch - (batch+1)) / (batch+1) 192 | 193 | self._after_train_batch({'loss': loss, 194 | 'id_batch': batch+1, 195 | 'total_batch': total_batch, 196 | 'remain_time': int(remains+0.5), 197 | 'full_time': int(full_time+0.5) 198 | }) 199 | 200 | if count_step > 0: 201 | step_loss /= count_step 202 | self._current_loss = step_loss 203 | return {'train_loss': step_loss} 204 | -------------------------------------------------------------------------------- /sdeconv/deconv/richardson_lucy.py: -------------------------------------------------------------------------------- 1 | """Implementation of Richardson-Lucy deconvolution for 2D and 3D images""" 2 | import torch 3 | import numpy as np 4 | from .interface import SDeconvFilter 5 | from ._utils import pad_2d, pad_3d, np_to_torch 6 | 7 | 8 | class SRichardsonLucy(SDeconvFilter): 9 | """Implements the Richardson-Lucy deconvolution 10 | 11 | :param psf: Point spread function 12 | :param niter: Number of iterations 13 | :param pad: image padding size 14 | """ 15 | def __init__(self, 16 | psf: torch.Tensor, 17 | niter: int = 30, 18 | pad: int | tuple[int, int] | tuple[int, int, int] = 0): 19 | super().__init__() 20 | self.psf = psf 21 | self.niter = niter 22 | self.pad = pad 23 | 24 | @staticmethod 25 | def _resize_psf(psf, width, height) -> torch.Tensor: 26 | """Resize the PSF to match the image size for Fourier transform 27 | 28 | :param psf: Point spread function 29 | :param width: Width of the resized PSF 30 | :param height: Height of the resized PSF 31 | :return: The resized PSF 32 | """ 33 | kernel = torch.zeros((width, height)) 34 | x_start = int(width / 2 - psf.shape[0] 
/ 2) + 1 35 | y_start = int(height / 2 - psf.shape[1] / 2) + 1 36 | kernel[x_start:x_start+psf.shape[0], y_start:y_start+psf.shape[1]] = psf 37 | return kernel 38 | 39 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 40 | """Apply the Richardson-Lucy deconvolution 41 | 42 | :param image: Blurry image for a single channel time point [(Z) Y X] 43 | :return: deblurred image [(Z) Y X] 44 | """ 45 | if image.ndim == 2: 46 | return self._deconv_2d(image) 47 | if image.ndim == 3: 48 | return self._deconv_3d(image) 49 | raise ValueError('Richardson-Lucy can only deblur 2D or 3D tensors') 50 | 51 | def _deconv_2d(self, image: torch.Tensor) -> torch.Tensor: 52 | """Implements Richardson-Lucy for 2D images 53 | 54 | :param image: 2D image tensor 55 | :return: 2D deblurred image 56 | """ 57 | image_pad, psf_pad, padding = pad_2d(image, self.psf / torch.sum(self.psf), self.pad) 58 | 59 | psf_roll = torch.roll(psf_pad, [int(-psf_pad.shape[0] / 2), 60 | int(-psf_pad.shape[1] / 2)], dims=(0, 1)) 61 | fft_psf = torch.fft.fft2(psf_roll) 62 | fft_psf_mirror = torch.fft.fft2(torch.flip(psf_roll, dims=[0, 1])) 63 | 64 | out_image = image_pad.detach().clone() 65 | for _ in range(self.niter): 66 | fft_out = torch.fft.fft2(out_image) 67 | fft_tmp = fft_out * fft_psf 68 | tmp = torch.real(torch.fft.ifft2(fft_tmp)) 69 | tmp = image_pad / tmp 70 | fft_tmp = torch.fft.fft2(tmp) 71 | fft_tmp = fft_tmp * fft_psf_mirror 72 | tmp = torch.real(torch.fft.ifft2(fft_tmp)) 73 | out_image = out_image * tmp 74 | 75 | if image_pad.shape != image.shape: 76 | return out_image[padding[0]:-padding[0], padding[1]:-padding[1]] 77 | return out_image 78 | 79 | def _deconv_3d(self, image: torch.Tensor) -> torch.Tensor: 80 | """Implements Richardson-Lucy for 3D images 81 | 82 | :param image: 3D image tensor 83 | :return: 3D deblurred image 84 | """ 85 | image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad) 86 | 87 | psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 
2), dims=0) 88 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1) 89 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[2] / 2), dims=2) 90 | 91 | fft_psf = torch.fft.fftn(psf_roll) 92 | fft_psf_mirror = torch.fft.fftn(torch.flip(psf_roll, dims=[0, 1])) 93 | 94 | out_image = image_pad.detach().clone() 95 | for _ in range(self.niter): 96 | fft_out = torch.fft.fftn(out_image) 97 | fft_tmp = fft_out * fft_psf 98 | tmp = torch.real(torch.fft.ifftn(fft_tmp)) 99 | tmp = image_pad / tmp 100 | fft_tmp = torch.fft.fftn(tmp) 101 | fft_tmp = fft_tmp * fft_psf_mirror 102 | tmp = torch.real(torch.fft.ifftn(fft_tmp)) 103 | out_image = out_image * tmp 104 | 105 | if image_pad.shape != image.shape: 106 | return out_image[padding[0]:-padding[0], 107 | padding[1]:-padding[1], 108 | padding[2]:-padding[2]] 109 | return out_image 110 | 111 | 112 | def srichardsonlucy(image: torch.Tensor, 113 | psf: torch.Tensor, 114 | niter: int = 30, 115 | pad: int | tuple[int, int] | tuple[int, int, int] = 0 116 | ) -> torch.Tensor: 117 | """Convenient function to call the SRichardsonLucy using numpy array 118 | 119 | :param image: Image to deblur 120 | :param psf: Point spread function 121 | :param niter: Number of iterations 122 | :param pad: image padding size 123 | :return: the deblurred image 124 | """ 125 | psf_ = np_to_torch(psf) 126 | image_ = np_to_torch(image) 127 | filter_ = SRichardsonLucy(psf_, niter, pad) 128 | if isinstance(image, np.ndarray): 129 | return filter_(image_) 130 | return filter_(image) 131 | 132 | 133 | metadata = { 134 | 'name': 'SRichardsonLucy', 135 | 'label': 'Richardson-Lucy', 136 | 'fnc': srichardsonlucy, 137 | 'inputs': { 138 | 'image': { 139 | 'type': 'Image', 140 | 'label': 'Image', 141 | 'help': 'Input image' 142 | }, 143 | 'psf': { 144 | 'type': 'Image', 145 | 'label': 'PSF', 146 | 'help': 'Point Spread Function' 147 | }, 148 | 'niter': { 149 | 'type': 'int', 150 | 'label': 'niter', 151 | 'help': 'Number of iterations', 152 | 'default': 
30, 153 | 'range': (0, 999999) 154 | }, 155 | 'pad': { 156 | 'type': 'int', 157 | 'label': 'Padding', 158 | 'help': 'Padding to avoid spectrum artifacts', 159 | 'default': 13, 160 | 'range': (0, 999999) 161 | } 162 | }, 163 | 'outputs': { 164 | 'image': { 165 | 'type': 'Image', 166 | 'label': 'Richardson-Lucy' 167 | }, 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /sdeconv/deconv/self_supervised_nn.py: -------------------------------------------------------------------------------- 1 | """This module implements self supervised deconvolution with Spitfire regularisation""" 2 | from pathlib import Path 3 | from timeit import default_timer as timer 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from .interface_nn import NNModule 9 | from ._datasets import SelfSupervisedPatchDataset 10 | from ._datasets import SelfSupervisedDataset 11 | from .spitfire import hv_loss 12 | 13 | 14 | class DeconSpitfireLoss(torch.nn.Module): 15 | """MSE LOSS with a (de)convolution filter and Spitfire regularisation 16 | 17 | :param psf_file: File containing the PSF for deconvolution 18 | :return: Loss tensor 19 | """ 20 | def __init__(self, 21 | psf: torch.Tensor, 22 | regularization: float = 1e-3, 23 | weighting: float = 0.6 24 | ): 25 | super().__init__() 26 | self.__psf = psf 27 | self.regularization = regularization 28 | self.weighting = weighting 29 | 30 | if self.__psf.ndim > 2: 31 | raise ValueError('DeconMSE PSF must be a gray scaled 2D image') 32 | 33 | self.__psf = self.__psf.view((1, 1, *self.__psf.shape)) 34 | print('psf shape=', self.__psf.shape) 35 | self.__conv_op = torch.nn.Conv2d(1, 1, 36 | kernel_size=self.__psf.shape[2], 37 | stride=1, 38 | padding=int((self.__psf.shape[2] - 1) / 2), 39 | bias=False) 40 | with torch.no_grad(): 41 | self.__conv_op.weight = torch.nn.Parameter(self.__psf, requires_grad=False) 42 | self.__conv_op.requires_grad_(False) 43 | 44 | def forward(self, input_image: 
torch.Tensor, target: torch.Tensor): 45 | """Deconvolution L2 data-term 46 | 47 | Compute the L2 error between the original image (input) and the 48 | convoluted reconstructed image (target) 49 | 50 | :param input_image: Tensor of shape BCYX containing the original blurry image 51 | :param target: Tensor of shape BCYX containing the estimated deblurred image 52 | """ 53 | conv_img = self.__conv_op(input_image) 54 | mse = torch.nn.MSELoss() 55 | return self.regularization*mse(target, conv_img) + \ 56 | (1-self.regularization)*hv_loss(input_image, weighting=self.weighting) 57 | 58 | 59 | class SelfSupervisedNNDeconv(NNModule): 60 | """Deconvolution using a neural network trained using the Spitfire loss""" 61 | def fit(self, 62 | train_directory: Path, 63 | val_directory: Path, 64 | n_channel_in: int = 1, 65 | n_channels_layer: list[int] = (32, 64, 128), 66 | patch_size: int = 32, 67 | n_epoch: int = 25, 68 | learning_rate: float = 1e-3, 69 | out_dir: Path = None, 70 | weight: float = 0.9, 71 | reg: float = 0.95, 72 | psf: torch.Tensor = None 73 | ): 74 | """Train a model on a dataset 75 | 76 | :param train_directory: Directory containing the images used 77 | for training. One file per image, 78 | :param val_directory: Directory containing the images used for validation of 79 | the training. 
One file per image, 80 | :param psf: Point spread function for deconvolution, 81 | :param n_channel_in: Number of channels in the input images 82 | :param n_channels_layer: Number of channels for each hidden layers of the model, 83 | :param patch_size: Size of square patches used for training the model, 84 | :param n_epoch: Number of epochs, 85 | :param learning_rate: Adam optimizer learning rate 86 | """ 87 | self._init_model(n_channel_in, n_channel_in, n_channels_layer) 88 | self._out_dir = out_dir 89 | self._loss_fn = DeconSpitfireLoss(psf.to(self.device()), reg, weight) 90 | self._optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate) 91 | train_dataset = SelfSupervisedPatchDataset(train_directory, 92 | patch_size=patch_size, 93 | stride=int(patch_size/2)) 94 | val_dataset = SelfSupervisedDataset(val_directory) 95 | self._train_data_loader = DataLoader(train_dataset, 96 | batch_size=300, 97 | shuffle=True, 98 | drop_last=True, 99 | num_workers=0) 100 | self._val_data_loader = DataLoader(val_dataset, 101 | batch_size=1, 102 | shuffle=False, 103 | drop_last=False, 104 | num_workers=0) 105 | self._train_loop(n_epoch) 106 | 107 | def _train_step(self): 108 | """Runs one step of training""" 109 | size = len(self._train_data_loader.dataset) 110 | self._model.train() 111 | step_loss = 0 112 | count_step = 0 113 | tic = timer() 114 | for batch, (x, _) in enumerate(self._train_data_loader): 115 | count_step += 1 116 | 117 | x = x.to(self.device()) 118 | 119 | # Compute prediction error 120 | prediction = self._model(x) 121 | loss = self._loss_fn(prediction, x) 122 | step_loss += loss 123 | 124 | # Backpropagation 125 | self._optimizer.zero_grad() 126 | loss.backward() 127 | self._optimizer.step() 128 | 129 | # count time 130 | toc = timer() 131 | full_time = toc - tic 132 | total_batch = int(size / len(x)) 133 | remains = full_time * (total_batch - (batch+1)) / (batch+1) 134 | 135 | self._after_train_batch({'loss': loss, 136 | 'id_batch': batch+1, 137 
| 'total_batch': total_batch, 138 | 'remain_time': int(remains+0.5), 139 | 'full_time': int(full_time+0.5) 140 | }) 141 | if count_step > 0: 142 | step_loss /= count_step 143 | self._current_loss = step_loss 144 | return {'train_loss': step_loss} 145 | -------------------------------------------------------------------------------- /sdeconv/deconv/spitfire.py: -------------------------------------------------------------------------------- 1 | """Implements the Spitfire deconvolution algorithms for 2D and 3D images""" 2 | import torch 3 | 4 | from ..core import SSettings 5 | from .interface import SDeconvFilter 6 | from ._utils import pad_2d, pad_3d, unpad_3d, np_to_torch 7 | 8 | 9 | def hv_loss(img: torch.Tensor, weighting: float = 0.5) -> torch.Tensor: 10 | """Sparse Hessian regularization term 11 | 12 | :param img: Tensor of shape BCYX containing the estimated image 13 | :param weighting: Sparse weighting parameter in [0, 1]. 0 sparse, and 1 not sparse 14 | :return: the loss value in torch.Tensor 15 | """ 16 | dxx2 = torch.square(-img[:, :, 2:, 1:-1] + 2 * img[:, :, 1:-1, 1:-1] - img[:, :, :-2, 1:-1]) 17 | dyy2 = torch.square(-img[:, :, 1:-1, 2:] + 2 * img[:, :, 1:-1, 1:-1] - img[:, :, 1:-1, :-2]) 18 | dxy2 = torch.square(img[:, :, 2:, 2:] - img[:, :, 2:, 1:-1] - img[:, :, 1:-1, 2:] + 19 | img[:, :, 1:-1, 1:-1]) 20 | h_v = torch.sqrt(weighting * weighting * (dxx2 + dyy2 + 2 * dxy2) + 21 | (1 - weighting) * (1 - weighting) * torch.square(img[:, :, 1:-1, 1:-1])) 22 | return torch.mean(h_v) 23 | 24 | 25 | def hv_loss_3d(img: torch.Tensor, 26 | delta: float = 1, 27 | weighting: float = 0.5 28 | ) -> torch.Tensor: 29 | """Sparse Hessian regularization term 30 | 31 | :param img: Tensor of shape BCZYX containing the estimated image 32 | :param delta: Resolution delta between XY and Z 33 | :param weighting: Sparse weighting parameter in [0, 1]. 
0 sparse, and 1 not sparse 34 | :return: the loss value in torch.Tensor 35 | """ 36 | img_ = img[:, :, 1:-1, 1:-1, 1:-1] 37 | d11 = -img[:, :, 1:-1, 1:-1, 2:] + 2*img_ - img[:, :, 1:-1, 1:-1, :-2] 38 | d22 = -img[:, :, 1:-1, 2:, 1:-1] + 2*img_ - img[:, :, 1:-1, :-2, 1:-1] 39 | d33 = delta*delta*(-img[:, :, 2:, 1:-1, 1:-1] + 2*img_ - img[:, :, :-2, 1:-1, 1:-1]) 40 | d12_d21 = img[:, :, 1:-1, 2:, 2:] - img[:, :, 1:-1, 1:-1, 2:] - img[:, :, 1:-1, 2:, 1:-1] + img_ 41 | d13_d31 = delta*(img[:, :, 2:, 1:-1, 2:] - img[:, :, 1:-1, 1:-1, 2:] 42 | - img[:, :, 2:, 1:-1, 1:-1] + img_) 43 | d23_d32 = delta*(img[:, :, 2:, 2:, 1:-1] - img[:, :, 1:-1, 2:, 1:-1] 44 | - img[:, :, 2:, 1:-1, 1:-1] + img_) 45 | 46 | h_v = torch.square(weighting*d11) + torch.square(weighting*d22) + torch.square( 47 | weighting*d33) + 2 * torch.square(weighting*d12_d21) + 2 * torch.square( 48 | weighting*d13_d31) + 2 * torch.square(weighting*d23_d32) + torch.square((1-weighting)*img_) 49 | 50 | return torch.mean(torch.sqrt(h_v)) 51 | 52 | 53 | def dataterm_deconv(blurry_image: torch.Tensor, 54 | deblurred_image: torch.Tensor, 55 | psf: torch.Tensor 56 | ) -> torch.Tensor: 57 | """Deconvolution L2 data-term 58 | 59 | Compute the L2 error between the original image and the convoluted reconstructed image 60 | 61 | :param blurry_image: Tensor of shape BCYX containing the original blurry image 62 | :param deblurred_image: Tensor of shape BCYX containing the estimated deblurred image 63 | :param psf: Tensor containing the point spread function 64 | :return: the loss value in torch.Tensor 65 | """ 66 | conv_op = torch.nn.Conv2d(1, 1, kernel_size=psf.shape[2], 67 | stride=1, 68 | padding=int((psf.shape[2] - 1) / 2), 69 | bias=False) 70 | with torch.no_grad(): 71 | conv_op.weight = torch.nn.Parameter(psf) 72 | mse = torch.nn.MSELoss() 73 | return mse(blurry_image, conv_op(deblurred_image)) 74 | 75 | 76 | class DataTermDeconv3D(torch.autograd.Function): 77 | """Deconvolution L2 data term 78 | 79 | This class 
manually implement the 3D data term backward 80 | """ 81 | @staticmethod 82 | def forward(ctx, 83 | deblurred_image: torch.Tensor, 84 | blurry_image: torch.Tensor, 85 | fft_blurry_image: torch.Tensor, 86 | fft_psf: torch.Tensor, 87 | adjoint_otf: torch.Tensor 88 | ) -> torch.Tensor: 89 | """Pytorch forward method 90 | 91 | :param deblurred_image: Candidate deblurred image 92 | :param blurry_image: Original image 93 | :param fft_blurry_image: FFT of the original image 94 | :param fft_psf: FFT of the Point Spread Function 95 | :param adjoint_otf: Fourier adjoint of PSF 96 | :return: The loss value 97 | """ 98 | fft_deblurred_image = torch.fft.fftn(deblurred_image) 99 | ctx.save_for_backward(deblurred_image, fft_deblurred_image, fft_blurry_image, fft_psf, 100 | adjoint_otf) 101 | 102 | conv_deblured_image = torch.real(torch.fft.ifftn(fft_deblurred_image * fft_psf)) 103 | mse = torch.nn.MSELoss() 104 | return mse(blurry_image, conv_deblured_image) 105 | 106 | @staticmethod 107 | def backward(ctx, grad_output): 108 | """Pytorch backward method""" 109 | deblurred_image, fft_deblurred_image, fft_blurry_image, \ 110 | fft_psf, adjoint_otf = ctx.saved_tensors 111 | 112 | real_tmp = fft_psf.real * fft_deblurred_image.real - \ 113 | fft_psf.imag * fft_deblurred_image.imag - fft_blurry_image.real 114 | imag_tmp = fft_psf.real * fft_deblurred_image.imag + \ 115 | fft_psf.imag * fft_deblurred_image.real - fft_blurry_image.imag 116 | 117 | residue_image_real = adjoint_otf.real * real_tmp - adjoint_otf.imag * imag_tmp 118 | residue_image_imag = adjoint_otf.real * imag_tmp + adjoint_otf.imag * real_tmp 119 | 120 | grad_ = torch.real(torch.fft.ifftn( 121 | torch.complex(residue_image_real, 122 | residue_image_imag))) / torch.numel(deblurred_image) 123 | return grad_output * grad_, None, None, None, None 124 | 125 | 126 | class Spitfire(SDeconvFilter): 127 | """Variational deconvolution using the Spitfire algorithm 128 | 129 | :param psf: Point spread function 130 | :param weight: 
model weight between hessian and sparsity. Value is in ]0, 1[ 131 | :param delta: For 3D images resolution delta between xy and z 132 | :param reg: Regularization weight. Value is in [0, 1] 133 | :param gradient_step: Gradient descent step 134 | :param precision: Stop criterion. Stop gradient descent when the 135 | loss decrease less than precision 136 | :param pad: Image padding to avoid Fourier artefacts 137 | """ 138 | def __init__(self, 139 | psf: torch.Tensor, 140 | weight: float = 0.6, 141 | delta: float = 1, 142 | reg: float = 0.995, 143 | gradient_step: float = 0.01, 144 | precision: float = 1e-7, 145 | pad: int | tuple[int, int] | tuple[int, int, int] = 0): 146 | super().__init__() 147 | self.psf = psf.to(SSettings.instance().device) 148 | self.weight = weight 149 | self.reg = reg 150 | self.precision = precision 151 | self.delta = delta 152 | self.pad = pad 153 | self.niter_ = 0 154 | self.max_iter_ = 2500 155 | self.gradient_step_ = gradient_step 156 | self.loss_ = None 157 | 158 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 159 | """Run the Spitfire deconvolution 160 | 161 | :param image: Blurry image for a single channel time point [(Z) Y X] 162 | :return: deblurred image [(Z) Y X] 163 | """ 164 | if image.ndim == 2: 165 | return self.run_2d(image) 166 | if image.ndim == 3: 167 | return self.run_3d(image) 168 | raise ValueError("Spitfire can process only 2D or 3D tensors") 169 | 170 | def run_2d(self, image: torch.Tensor) -> torch.Tensor: 171 | """Implements Spitfire for 2D images 172 | 173 | :param image: Blurry 2D image tensor 174 | :return: Deblurred image (2D torch.Tensor) 175 | """ 176 | self.progress(0) 177 | mini = torch.min(image) + 1e-5 178 | maxi = torch.max(image) 179 | image = (image-mini)/(maxi-mini) 180 | 181 | image_pad, psf_pad, padding = pad_2d(image, self.psf/torch.sum(self.psf), self.pad) 182 | psf_pad = self.psf/torch.sum(self.psf) 183 | 184 | image_pad = image_pad.view(1, 1, image_pad.shape[0], image_pad.shape[1]) 185 
| psf_pad = psf_pad.view(1, 1, psf_pad.shape[0], psf_pad.shape[1]) 186 | deconv_image = image_pad.detach().clone() 187 | deconv_image.requires_grad = True 188 | optimizer = torch.optim.Adam([deconv_image], lr=self.gradient_step_) 189 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 190 | previous_loss = 9e12 191 | count_eq = 0 192 | self.niter_ = 0 193 | loss = 0 194 | 195 | for i in range(self.max_iter_): 196 | self.progress(int(100*i/self.max_iter_)) 197 | self.niter_ += 1 198 | optimizer.zero_grad() 199 | loss = self.reg * dataterm_deconv(image_pad, deconv_image, psf_pad) + \ 200 | (1-self.reg) * hv_loss(deconv_image, self.weight) 201 | #print('iter:', self.niter_, ' loss:', loss.item()) 202 | if abs(loss - previous_loss) < self.precision: 203 | count_eq += 1 204 | else: 205 | previous_loss = loss 206 | count_eq = 0 207 | if count_eq > 5: 208 | break 209 | loss.backward() 210 | optimizer.step() 211 | scheduler.step() 212 | self.loss_ = loss 213 | self.progress(100) 214 | deconv_image = (maxi-mini)*deconv_image + mini 215 | deconv_image = deconv_image.view(deconv_image.shape[2], deconv_image.shape[3]) 216 | if image_pad.shape[2] != image.shape[0] and image_pad.shape[3] != image.shape[1]: 217 | return deconv_image[padding[0]: -padding[0], padding[1]: -padding[1]] 218 | return deconv_image 219 | 220 | @staticmethod 221 | def otf_3d(psf: torch.Tensor) -> torch.Tensor: 222 | """Calculate the OTF of a PSF 223 | 224 | :param psf: 3D Point Spread Function 225 | :return: the OTF in a torch.Tensor 3D 226 | """ 227 | psf_roll = torch.roll(psf, int(-psf.shape[0] / 2), dims=0) 228 | psf_roll = torch.roll(psf_roll, int(-psf.shape[1] / 2), dims=1) 229 | psf_roll = torch.roll(psf_roll, int(-psf.shape[2] / 2), dims=2) 230 | psf_roll.view(1, psf.shape[0], psf.shape[1], psf.shape[2]) 231 | fft_psf = torch.fft.fftn(psf_roll) 232 | return fft_psf 233 | 234 | @staticmethod 235 | def adjoint_otf(psf: torch.Tensor) -> torch.Tensor: 236 | """Calculate 
the adjoint OTF of a PSF 237 | 238 | :param psf: 3D Point Spread Function 239 | :return: the OTF in a torch.Tensor 3D 240 | """ 241 | adjoint_psf = torch.flip(psf, [0, 1, 2]) 242 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[0] - 1) % 2, dims=0) 243 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[1] - 1) % 2, dims=1) 244 | adjoint_psf = torch.roll(adjoint_psf, -int(psf.shape[2] - 1) % 2, dims=2) 245 | 246 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[0] / 2), dims=0) 247 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[1] / 2), dims=1) 248 | adjoint_psf = torch.roll(adjoint_psf, int(-psf.shape[2] / 2), dims=2) 249 | return torch.fft.fftn(adjoint_psf) 250 | 251 | def run_3d(self, image: torch.Tensor) -> torch.Tensor: 252 | """Implements Spitfire for 3D images 253 | 254 | :param image: Blurry 2D image tensor 255 | :return: Deblurred image (2D torch.Tensor) 256 | """ 257 | self.progress(0) 258 | mini = torch.min(image) + 1e-5 259 | maxi = torch.max(image) 260 | image = (image-mini)/(maxi-mini) 261 | image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad) 262 | 263 | deconv_image = image_pad.detach().clone() 264 | image_pad = image_pad.view(1, 1, image_pad.shape[0], image_pad.shape[1], image_pad.shape[2]) 265 | deconv_image = deconv_image.view(1, 1, deconv_image.shape[0], 266 | deconv_image.shape[1], deconv_image.shape[2]) 267 | deconv_image.requires_grad = True 268 | optimizer = torch.optim.Adam([deconv_image], lr=self.gradient_step_) 269 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 270 | previous_loss = 9e12 271 | count_eq = 0 272 | self.niter_ = 0 273 | 274 | fft_psf = Spitfire.otf_3d(psf_pad) 275 | adjoint_otf = Spitfire.adjoint_otf(psf_pad) 276 | fft_image = torch.fft.fftn(image_pad) 277 | dataterm_ = DataTermDeconv3D.apply 278 | loss = 0 279 | for i in range(self.max_iter_): 280 | self.progress(int(100*i/self.max_iter_)) 281 | self.niter_ += 1 282 | 
optimizer.zero_grad() 283 | loss = self.reg * dataterm_(deconv_image, image_pad, 284 | fft_image, fft_psf, adjoint_otf) + ( 285 | 1 - self.reg) * hv_loss_3d(deconv_image, self.delta, self.weight) 286 | print('iter:', self.niter_, ' loss:', loss.item()) 287 | if loss > previous_loss: 288 | break 289 | if abs(loss - previous_loss) < self.precision: 290 | count_eq += 1 291 | else: 292 | previous_loss = loss 293 | count_eq = 0 294 | if count_eq > 5: 295 | break 296 | loss.backward() 297 | optimizer.step() 298 | scheduler.step() 299 | self.loss_ = loss 300 | self.progress(100) 301 | deconv_image = deconv_image.view(image_pad.shape[2], 302 | image_pad.shape[3], 303 | image_pad.shape[4]) 304 | deconv_image = (maxi-mini)*deconv_image + mini 305 | if image_pad.shape[2] != image.shape[0] and image_pad.shape[3] != image.shape[1] and \ 306 | image_pad.shape[4] != image.shape[2]: 307 | return unpad_3d(deconv_image, padding) 308 | return deconv_image 309 | 310 | 311 | def spitfire(image: torch.Tensor, 312 | psf: torch.Tensor, 313 | weight: float = 0.6, 314 | delta: float = 1, 315 | reg: float = 0.995, 316 | gradient_step: float = 0.01, 317 | precision: float = 1e-7, 318 | pad: int | tuple[int, int] | tuple[int, int, int] = 13, 319 | observers=None 320 | ) -> torch.Tensor: 321 | """Convenient function to call Spitfire with numpy 322 | 323 | :param image: Image to deblur 324 | :param psf: Point spread function 325 | :param weight: model weight between hessian and sparsity. Value is in ]0, 1[ 326 | :param delta: For 3D images resolution delta between xy and z 327 | :param reg: Regularization weight. Value is in [0, 1] 328 | :param gradient_step: Gradient descent step 329 | :param precision: Stop criterion. 
Stop gradient descent when the loss decrease less than 330 | precision 331 | :param pad: Image padding to avoid Fourier artefacts 332 | :return: the deblurred image 333 | """ 334 | if observers is None: 335 | observers = [] 336 | 337 | psf_ = np_to_torch(psf) 338 | filter_ = Spitfire(psf_, weight, delta, reg, gradient_step, precision, pad) 339 | for observer in observers: 340 | filter_.add_observer(observer) 341 | return filter_(np_to_torch(image)) 342 | 343 | 344 | metadata = { 345 | 'name': 'Spitfire', 346 | 'label': 'Spitfire', 347 | 'fnc': spitfire, 348 | 'inputs': { 349 | 'image': { 350 | 'type': 'Image', 351 | 'label': 'Image', 352 | 'help': 'Input image' 353 | }, 354 | 'psf': { 355 | 'type': 'Image', 356 | 'label': 'PSF', 357 | 'help': 'Point Spread Function' 358 | }, 359 | 'weight': { 360 | 'type': 'float', 361 | 'label': 'weight', 362 | 'help': 'Model weight between hessian and sparsity. Value is in ]0, 1[', 363 | 'default': 0.6, 364 | 'range': (0, 1) 365 | }, 366 | 'reg': { 367 | 'type': 'float', 368 | 'label': 'Regularization', 369 | 'help': 'Regularization weight. Value is in [0, 1]', 370 | 'default': 0.995, 371 | 'range': (0, 1) 372 | }, 373 | 'delta': { 374 | 'type': 'float', 375 | 'label': 'Delta', 376 | 'help': 'For 3D images resolution delta between xy and z', 377 | 'default': 1, 378 | 'advanced': True, 379 | }, 380 | 'gradient_step': { 381 | 'type': 'float', 382 | 'label': 'Gradient Step', 383 | 'help': 'Step for ADAM gradient descente optimization', 384 | 'default': 0.01, 385 | 'advanced': True, 386 | }, 387 | 'precision': { 388 | 'type': 'float', 389 | 'label': 'Precision', 390 | 'help': 'Stop criterion. 
Stop gradient descent when the loss ' 391 | 'decrease less than precision', 392 | 'default': 1e-7, 393 | 'advanced': True, 394 | }, 395 | 'pad': { 396 | 'type': 'int', 397 | 'label': 'Padding', 398 | 'help': 'Padding to avoid spectrum artifacts', 399 | 'default': 13, 400 | 'range': (0, 999999), 401 | 'advanced': True 402 | } 403 | }, 404 | 'outputs': { 405 | 'image': { 406 | 'type': 'Image', 407 | 'label': 'Spitfire' 408 | }, 409 | } 410 | } 411 | -------------------------------------------------------------------------------- /sdeconv/deconv/wiener.py: -------------------------------------------------------------------------------- 1 | """Implements the Wiener deconvolution for 2D and 3D images""" 2 | import torch 3 | from sdeconv.core import SSettings 4 | from .interface import SDeconvFilter 5 | from ._utils import pad_2d, pad_3d, unpad_3d, np_to_torch 6 | 7 | 8 | def laplacian_2d(shape: tuple[int, int]) -> torch.Tensor: 9 | """Define the 2D laplacian matrix 10 | 11 | :param shape: 2D image shape 12 | :return: torch.Tensor with size defined by shape and laplacian coefficients at it center 13 | """ 14 | image = torch.zeros(shape).to(SSettings.instance().device) 15 | 16 | x_c = int(shape[0] / 2) 17 | y_c = int(shape[1] / 2) 18 | 19 | image[x_c, y_c] = 4 20 | image[x_c, y_c - 1] = -1 21 | image[x_c, y_c + 1] = -1 22 | image[x_c - 1, y_c] = -1 23 | image[x_c + 1, y_c] = -1 24 | return image 25 | 26 | 27 | def laplacian_3d(shape: tuple[int, int]) -> torch.Tensor: 28 | """Define the 3D laplacian matrix 29 | 30 | :param shape: 3D image shape 31 | :return: torch.Tensor with size defined by shape and laplacian coefficients at it center 32 | """ 33 | image = torch.zeros(shape).to(SSettings.instance().device) 34 | 35 | x_c = int(shape[2] / 2) 36 | y_c = int(shape[1] / 2) 37 | z_c = int(shape[0] / 2) 38 | 39 | image[z_c, y_c, x_c] = 6 40 | image[z_c - 1, y_c, x_c] = -1 41 | image[z_c + 1, y_c, x_c] = -1 42 | image[z_c, y_c - 1, x_c] = -1 43 | image[z_c, y_c + 1, x_c] = -1 44 
| image[z_c, y_c, x_c - 1] = -1 45 | image[z_c, y_c, x_c - 1] = -1 46 | return image 47 | 48 | 49 | class SWiener(SDeconvFilter): 50 | """Apply a Wiener deconvolution 51 | 52 | :param psf: Point spread function 53 | :param beta: Regularisation weight 54 | :param pad: Padding in each dimension 55 | """ 56 | def __init__(self, 57 | psf: torch.Tensor, 58 | beta: float = 1e-5, 59 | pad: int | tuple[int, int] | tuple[int, int, int] = 0): 60 | super().__init__() 61 | self.psf = psf 62 | self.beta = beta 63 | self.pad = pad 64 | 65 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 66 | """Run the Wiener deconvolution 67 | 68 | :param image: Blurry image for a single channel time point [(Z) Y X] 69 | :return: deblurred image [(Z) Y X] 70 | """ 71 | if image.ndim == 2: 72 | return self._wiener_2d(image) 73 | if image.ndim == 3: 74 | return self._wiener_3d(image) 75 | raise ValueError('Wiener can only deblur 2D or 3D tensors') 76 | 77 | def _wiener_2d(self, image: torch.Tensor) -> torch.Tensor: 78 | """Compute the 2D wiener deconvolution 79 | 80 | :param image: 2D image tensor 81 | :return: torch.Tensor of the 2D deblurred image 82 | """ 83 | image_pad, psf_pad, padding = pad_2d(image, self.psf / torch.sum(self.psf), self.pad) 84 | 85 | fft_source = torch.fft.fft2(image_pad) 86 | psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 2), dims=0) 87 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1) 88 | fft_psf = torch.fft.fft2(psf_roll) 89 | fft_laplacian = torch.fft.fft2(laplacian_2d(image_pad.shape)) 90 | 91 | den = fft_psf * torch.conj(fft_psf) + self.beta * fft_laplacian * torch.conj(fft_laplacian) 92 | out_image = torch.real(torch.fft.ifftn((fft_source * torch.conj(fft_psf)) / den)) 93 | if image_pad.shape != image.shape: 94 | return out_image[padding[0]: -padding[0], padding[1]: -padding[1]] 95 | return out_image 96 | 97 | def _wiener_3d(self, image: torch.Tensor) -> torch.Tensor: 98 | """Compute the 3D wiener deconvolution 99 | 100 | 
:param image: 2D image tensor 101 | :return: torch.Tensor of the 2D deblurred image 102 | """ 103 | image_pad, psf_pad, padding = pad_3d(image, self.psf / torch.sum(self.psf), self.pad) 104 | 105 | fft_source = torch.fft.fftn(image_pad) 106 | psf_roll = torch.roll(psf_pad, int(-psf_pad.shape[0] / 2), dims=0) 107 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[1] / 2), dims=1) 108 | psf_roll = torch.roll(psf_roll, int(-psf_pad.shape[2] / 2), dims=2) 109 | 110 | fft_psf = torch.fft.fftn(psf_roll) 111 | fft_laplacian = torch.fft.fftn(laplacian_3d(image_pad.shape)) 112 | 113 | den = fft_psf * torch.conj(fft_psf) + self.beta * fft_laplacian * torch.conj(fft_laplacian) 114 | out_image = torch.real(torch.fft.ifftn((fft_source * torch.conj(fft_psf)) / den)) 115 | if image_pad.shape != image.shape: 116 | return unpad_3d(out_image, padding) 117 | return out_image 118 | 119 | 120 | def swiener(image: torch.Tensor, 121 | psf: torch.Tensor, 122 | beta: float = 1e-5, 123 | pad: int | tuple[int, int] | tuple[int, int, int] = 0 124 | ): 125 | """Convenient function to call the SWiener on numpy array 126 | 127 | :param image: Image to deblur, 128 | :param psf: Point spread function, 129 | :param beta: Regularisation weight, 130 | :param pad: Padding in each dimension, 131 | :return: the deblurred image 132 | """ 133 | psf_ = np_to_torch(psf) 134 | filter_ = SWiener(psf_, beta, pad) 135 | return filter_(np_to_torch(image)) 136 | 137 | 138 | metadata = { 139 | 'name': 'SWiener', 140 | 'label': 'Wiener', 141 | 'fnc': swiener, 142 | 'inputs': { 143 | 'image': { 144 | 'type': 'Image', 145 | 'label': 'Image', 146 | 'help': 'Input image' 147 | }, 148 | 'psf': { 149 | 'type': 'Image', 150 | 'label': 'PSF', 151 | 'help': 'Point Spread Function' 152 | }, 153 | 'beta': { 154 | 'type': 'float', 155 | 'label': 'Beta', 156 | 'help': 'Regularisation parameter', 157 | 'default': 1e-5, 158 | 'range': (0, 999999) 159 | }, 160 | 'pad': { 161 | 'type': 'int', 162 | 'label': 'Padding', 163 | 
'help': 'Padding to avoid spectrum artifacts', 164 | 'default': 13, 165 | 'range': (0, 999999) 166 | } 167 | }, 168 | 'outputs': { 169 | 'image': { 170 | 'type': 'Image', 171 | 'label': 'Wiener' 172 | }, 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /sdeconv/psfs/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that implements Point Spread Function generators""" 2 | from .interface import SPSFGenerator 3 | from .gaussian import SPSFGaussian, spsf_gaussian 4 | from .gibson_lanni import SPSFGibsonLanni, spsf_gibson_lanni 5 | from .lorentz import SPSFLorentz, spsf_lorentz 6 | 7 | __all__ = ['SPSFGenerator', 8 | 'SPSFGaussian', 9 | 'spsf_gaussian', 10 | 'SPSFGibsonLanni', 11 | 'spsf_gibson_lanni', 12 | 'SPSFLorentz', 13 | 'spsf_lorentz'] 14 | -------------------------------------------------------------------------------- /sdeconv/psfs/gaussian.py: -------------------------------------------------------------------------------- 1 | """Implements the Gaussian Point Spread Function generator""" 2 | import math 3 | import torch 4 | from sdeconv.core import SSettings 5 | from .interface import SPSFGenerator 6 | 7 | 8 | class SPSFGaussian(SPSFGenerator): 9 | """Generate a Gaussian PSF 10 | 11 | :param sigma: width of the PSF. [Z, Y, X] in 3D, [Y, X] in 2D 12 | :param shape: Shape of the PSF support image. 
[Z, Y, X] in 3D, [Y, X] in 2D 13 | """ 14 | def __init__(self, 15 | sigma: tuple[float, float] | tuple[float, float, float], 16 | shape: tuple[int, int] | tuple[int, int, int]): 17 | super().__init__() 18 | self.sigma = sigma 19 | self.shape = shape 20 | self.psf_ = None 21 | 22 | @staticmethod 23 | def _normalize_inputs(sigma: tuple[float, float] | tuple[float, float, float], 24 | shape: tuple[int, int, int] | tuple[int, int, int] 25 | ) -> tuple: 26 | """Remove batch dimention if it exists 27 | 28 | :param sigma: Width of the PSF 29 | :param shape: Shape of the PSF 30 | :return: The modified sigma and shape 31 | """ 32 | if len(shape) == 3 and shape[0] == 1: 33 | return sigma[1:], shape[1:] 34 | return sigma, shape 35 | 36 | def __call__(self) -> torch.Tensor: 37 | """Calculate the PSF image 38 | 39 | :return: The PSF image in a Tensor 40 | """ 41 | self.sigma, self.shape = SPSFGaussian._normalize_inputs(self.sigma, self.shape) 42 | if len(self.shape) == 2: 43 | self.psf_ = torch.zeros((self.shape[0], self.shape[1])).to(SSettings.instance().device) 44 | x_0 = math.floor(self.shape[0] / 2) 45 | y_0 = math.floor(self.shape[1] / 2) 46 | # print('center= (', x0, ', ', y0, ')') 47 | sigma_x2 = 0.5 / (self.sigma[0] * self.sigma[0]) 48 | sigma_y2 = 0.5 / (self.sigma[1] * self.sigma[1]) 49 | 50 | xx_, yy_ = torch.meshgrid(torch.arange(0, self.shape[0]), 51 | torch.arange(0, self.shape[1]), 52 | indexing='ij') 53 | self.psf_ = torch.exp(- torch.pow(xx_ - x_0, 2) * sigma_x2 54 | - torch.pow(yy_ - y_0, 2) * sigma_y2) 55 | self.psf_ = self.psf_ / torch.sum(self.psf_) 56 | elif len(self.shape) == 3: 57 | self.psf_ = torch.zeros(self.shape).to(SSettings.instance().device) 58 | x_0 = math.floor(self.shape[2] / 2) 59 | y_0 = math.floor(self.shape[1] / 2) 60 | z_0 = math.floor(self.shape[0] / 2) 61 | sigma_x2 = 0.5 / (self.sigma[2] * self.sigma[2]) 62 | sigma_y2 = 0.5 / (self.sigma[1] * self.sigma[1]) 63 | sigma_z2 = 0.5 / (self.sigma[0] * self.sigma[0]) 64 | 65 | zzz, yyy, xxx = 
torch.meshgrid(torch.arange(0, self.shape[0]), 66 | torch.arange(0, self.shape[1]), 67 | torch.arange(0, self.shape[2]), 68 | indexing='ij') 69 | self.psf_ = torch.exp(- torch.pow(xxx - x_0, 2) * sigma_x2 70 | - torch.pow(yyy - y_0, 2) * sigma_y2 71 | - torch.pow(zzz - z_0, 2) * sigma_z2) 72 | 73 | self.psf_ = self.psf_ / torch.sum(self.psf_) 74 | else: 75 | raise ValueError('PSFGaussian: can generate only 2D or 3D PSFs') 76 | return self.psf_ 77 | 78 | 79 | def spsf_gaussian(sigma: tuple[float, float] | tuple[float, float, float], 80 | shape: tuple[int, int] | tuple[int, int, int] 81 | ) -> torch.Tensor: 82 | """Function to generate a Gaussian PSF 83 | 84 | :param sigma: width of the PSF. [Z, Y, X] in 3D, [Y, X] in 2D, 85 | :param shape: Shape of the PSF support image. [Z, Y, X] in 3D, [Y, X] in 2D, 86 | :return: The PSF image 87 | """ 88 | filter_ = SPSFGaussian(sigma, shape) 89 | return filter_() 90 | 91 | 92 | metadata = { 93 | 'name': 'SPSFGaussian', 94 | 'label': 'Gaussian PSF', 95 | 'fnc': spsf_gaussian, 96 | 'inputs': { 97 | 'sigma': { 98 | 'type': 'zyx_float', 99 | 'label': 'Sigma', 100 | 'help': 'Gaussian standard deviation in each direction', 101 | 'default': [0, 1.5, 1.5] 102 | }, 103 | 'shape': { 104 | 'type': 'zyx_int', 105 | 'label': 'Size', 106 | 'help': 'PSF image shape', 107 | 'default': [1, 128, 128] 108 | } 109 | }, 110 | 'outputs': { 111 | 'image': { 112 | 'type': 'Image', 113 | 'label': 'PSF Gaussian' 114 | }, 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /sdeconv/psfs/gibson_lanni.py: -------------------------------------------------------------------------------- 1 | """Implementation of Gibson Lanni Point Spread Function model 2 | 3 | This implementation is an adaptation of 4 | https://kmdouglass.github.io/posts/implementing-a-fast-gibson-lanni-psf-solver-in-python/ 5 | 6 | """ 7 | from math import sqrt 8 | import numpy as np 9 | import scipy.special 10 | from scipy.interpolate import 
interp1d 11 | import torch 12 | 13 | 14 | from sdeconv.core import SSettings 15 | from .interface import SPSFGenerator 16 | 17 | 18 | class SPSFGibsonLanni(SPSFGenerator): 19 | """Generate a Gibson-Lanni PSF 20 | 21 | :param shape: Size of the PSF array in each dimension [(Z), Y, X], 22 | :param NA: Numerical aperture, 23 | :param wavelength: Wavelength in microns, 24 | :param M: Magnification, 25 | :param ns: Specimen refractive index (RI), 26 | :param ng0: Coverslip RI design value, 27 | :param ng: Coverslip RI experimental value, 28 | :param ni0: Immersion medium RI design value, 29 | :param ni: Immersion medium RI experimental value, 30 | :param ti0: microns, working distance (immersion medium thickness) design value, 31 | :param tg0: microns, coverslip thickness design value, 32 | :param tg: microns, coverslip thickness experimental value, 33 | :param res_lateral: Lateral resolution in microns, 34 | :param res_axial: Axial resolution in microns, 35 | :param pZ: microns, particle distance from coverslip 36 | :param use_square: If true, calculate the square of the Gibson-Lanni model to simulate a 37 | pinhole. 
It then gives a PSF for a confocal image 38 | """ 39 | def __init__(self, 40 | shape: tuple[int, int] | tuple[int, int, int], 41 | NA: float = 1.4, 42 | wavelength: float = 0.610, 43 | M: float = 100, 44 | ns: float = 1.33, 45 | ng0: float = 1.5, 46 | ng: float = 1.5, 47 | ni0: float = 1.5, 48 | ni: float = 1.5, 49 | ti0: float = 150, 50 | tg0: float = 170, 51 | tg: float = 170, 52 | res_lateral: float = 0.1, 53 | res_axial: float = 0.25, 54 | pZ: float = 0, 55 | use_square: bool = False): 56 | super().__init__() 57 | self.shape = shape 58 | 59 | # Microscope parameters 60 | self.NA = NA 61 | self.wavelength = wavelength 62 | self.M = M 63 | self.ns = ns 64 | self.ng0 = ng0 65 | self.ng = ng 66 | self.ni0 = ni0 67 | self.ni = ni 68 | self.ti0 = ti0 69 | self.tg0 = tg0 70 | self.tg = tg 71 | self.res_lateral = res_lateral 72 | self.res_axial = res_axial 73 | self.pZ = pZ 74 | self.use_square = use_square 75 | # output 76 | self.psf_ = None 77 | 78 | def __call__(self) -> torch.Tensor: 79 | """Calculate the PSF 80 | 81 | :return: The PSF image as a Tensor 82 | """ 83 | # Precision control 84 | num_basis = 100 # Number of rescaled Bessels that approximate the phase function 85 | num_samples = 1000 # Number of pupil samples along radial direction 86 | oversampling = 2 # Defines the sampling ratio on the image space grid for computations 87 | 88 | size_x = self.shape[2] 89 | size_y = self.shape[1] 90 | size_z = self.shape[0] 91 | min_wavelength = 0.436 # microns 92 | scaling_factor = ( 93 | self.NA * (3 * np.arange(1, num_basis + 1) - 2) * min_wavelength / self.wavelength) 94 | 95 | # Place the origin at the center of the final PSF array 96 | x0 = (size_x - 1) / 2 97 | y0 = (size_y - 1) / 2 98 | # Find the maximum possible radius coordinate of the PSF array by finding the distance 99 | # from the center of the array to a corner 100 | max_radius = round(sqrt((size_x - x0) * (size_x - x0) + (size_y - y0) * (size_y - y0))) + 1 101 | # Radial coordinates, image space 102 | 
r = self.res_lateral * np.arange(0, oversampling * max_radius) / oversampling 103 | # Radial coordinates, pupil space 104 | a = min([self.NA, self.ns, self.ni, self.ni0, self.ng, self.ng0]) / self.NA 105 | rho = np.linspace(0, a, num_samples) 106 | # Stage displacements away from best focus 107 | z = self.res_axial * np.arange(-size_z / 2, size_z / 2) + self.res_axial / 2 108 | 109 | # Define the wavefront aberration 110 | OPDs = self.pZ * np.sqrt(self.ns * self.ns - self.NA * self.NA * rho * rho) 111 | OPDi = (z.reshape(-1, 1) + self.ti0) * np.sqrt(self.ni * self.ni - self.NA * self.NA * rho * rho) - self.ti0 * np.sqrt( 112 | self.ni0 * self.ni0 - self.NA * self.NA * rho * rho) # OPD in the immersion medium 113 | OPDg = self.tg * np.sqrt(self.ng * self.ng - self.NA * self.NA * rho * rho) - self.tg0 * np.sqrt( 114 | self.ng0 * self.ng0 - self.NA * self.NA * rho * rho) # OPD in the coverslip 115 | W = 2 * np.pi / self.wavelength * (OPDs + OPDi + OPDg) 116 | 117 | # Sample the phase 118 | # Shape is (number of z samples by number of rho samples) 119 | phase = np.cos(W) + 1j * np.sin(W) 120 | 121 | # Define the basis of Bessel functions 122 | # Shape is (number of basis functions by number of rho samples) 123 | J = scipy.special.jv(0, scaling_factor.reshape(-1, 1) * rho) 124 | 125 | # Compute the approximation to the sampled pupil phase by finding the least squares 126 | # solution to the complex coefficients of the Fourier-Bessel expansion. 127 | # Shape of C is (number of basis functions by number of z samples). 128 | # Note the matrix transposes to get the dimensions correct. 129 | C, residuals, _, _ = np.linalg.lstsq(J.T, phase.T, rcond=None) 130 | 131 | # compute the PSF 132 | b = 2 * np. 
pi * r.reshape(-1, 1) * self.NA / self.wavelength 133 | 134 | # Convenience functions for J0 and J1 Bessel functions 135 | J0 = lambda x: scipy.special.jv(0, x) 136 | J1 = lambda x: scipy.special.jv(1, x) 137 | 138 | # See equation 5 in Li, Xue, and Blu 139 | denom = scaling_factor * scaling_factor - b * b 140 | R = scaling_factor * J1(scaling_factor * a) * J0(b * a) * a - b * J0(scaling_factor * a) * J1(b * a) * a 141 | R /= denom 142 | 143 | # The transpose places the axial direction along the first dimension of the array, i.e. rows 144 | # This is only for convenience. 145 | PSF_rz = (np.abs(R.dot(C)) ** 2).T 146 | 147 | # Normalize to the maximum value 148 | PSF_rz /= np.max(PSF_rz) 149 | 150 | # cartesian PSF 151 | # Create the fleshed-out xy grid of radial distances from the center 152 | xy = np.mgrid[0:size_y, 0:size_x] 153 | r_pixel = np.sqrt((xy[1] - x0) * (xy[1] - x0) + (xy[0] - y0) * (xy[0] - y0)) * self.res_lateral 154 | 155 | self.psf_ = np.zeros((size_z, size_y, size_x)) 156 | 157 | for z_index in range(size_z): 158 | # Interpolate the radial PSF function 159 | PSF_interp = interp1d(r, PSF_rz[z_index, :]) 160 | 161 | # Evaluate the PSF at each value of r_pixel 162 | self.psf_[z_index, :, :] = PSF_interp(r_pixel.ravel()).reshape(size_y, size_x) 163 | 164 | if self.use_square: 165 | self.psf_ = np.square(self.psf_) 166 | 167 | return torch.from_numpy(self.psf_).to(SSettings.instance().device) 168 | 169 | 170 | def spsf_gibson_lanni(shape: tuple[int, int] | tuple[int, int, int], 171 | NA: float = 1.4, 172 | wavelength: float = 0.610, 173 | M: float = 100, 174 | ns: float = 1.33, 175 | ng0: float = 1.5, 176 | ng: float = 1.5, 177 | ni0: float = 1.5, 178 | ni: float = 1.5, 179 | ti0: float = 150, 180 | tg0: float = 170, 181 | tg: float = 170, 182 | res_lateral: float = 0.1, 183 | res_axial: float = 0.25, 184 | pZ: float = 0, 185 | use_square: bool = False 186 | ) -> torch.Tensor: 187 | """Function to generate a Gibson-Lanni PSF 188 | 189 | :param shape: 
Size of the PSF array in each dimension [(Z), Y, X], 190 | :param NA: Numerical aperture, 191 | :param wavelength: Wavelength in microns, 192 | :param M: Magnification, 193 | :param ns: Specimen refractive index (RI), 194 | :param ng0: Coverslip RI design value, 195 | :param ng: Coverslip RI experimental value, 196 | :param ni0: Immersion medium RI design value, 197 | :param ni: Immersion medium RI experimental value, 198 | :param ti0: microns, working distance (immersion medium thickness) design value, 199 | :param tg0: microns, coverslip thickness design value, 200 | :param tg: microns, coverslip thickness experimental value, 201 | :param res_lateral: Lateral resolution in microns, 202 | :param res_axial: Axial resolution in microns, 203 | :param pZ: microns, particle distance from coverslip 204 | :param use_square: If true, calculate the square of the Gibson-Lanni model to simulate a 205 | pinhole. It then gives a PSF for a confocal image 206 | """ 207 | filter_ = SPSFGibsonLanni(shape, NA, wavelength, M, ns, 208 | ng0, ng, ni0, ni, ti0, tg0, tg, 209 | res_lateral, res_axial, pZ, use_square) 210 | return filter_() 211 | 212 | 213 | metadata = { 214 | 'name': 'SPSFGibsonLanni', 215 | 'label': 'Gibson Lanni PSF', 216 | 'fnc': spsf_gibson_lanni, 217 | 'inputs': { 218 | 'shape': { 219 | 'type': 'zyx_int', 220 | 'label': 'Size', 221 | 'help': 'Regularisation parameter', 222 | 'default': [11, 128, 128] 223 | }, 224 | 'NA': { 225 | 'type': 'float', 226 | 'label': 'Numerical aperture', 227 | 'help': 'Numerical aperture', 228 | 'default': 1.4 229 | }, 230 | 'wavelength': { 231 | 'type': 'float', 232 | 'label': 'Wavelength', 233 | 'help': 'Wavelength', 234 | 'default': 0.610 235 | }, 236 | 'M': { 237 | 'type': 'float', 238 | 'label': 'Magnification', 239 | 'help': 'Magnification', 240 | 'default': 100 241 | }, 242 | 'ns': { 243 | 'type': 'float', 244 | 'label': 'ns', 245 | 'help': 'Specimen refractive index (RI)', 246 | 'default': 1.33 247 | }, 248 | 'ng0': { 249 | 
'type': 'float', 250 | 'label': 'ng0', 251 | 'help': 'Coverslip RI design value', 252 | 'default': 1.5 253 | }, 254 | 'ng': { 255 | 'type': 'float', 256 | 'label': 'ng', 257 | 'help': 'coverslip RI experimental value', 258 | 'default': 1.5 259 | }, 260 | 'ni0': { 261 | 'type': 'float', 262 | 'label': 'ni0', 263 | 'help': 'Immersion medium RI design value', 264 | 'default': 1.5 265 | }, 266 | 'ni': { 267 | 'type': 'float', 268 | 'label': 'ni0', 269 | 'help': 'Immersion medium RI experimental value', 270 | 'default': 1.5 271 | }, 272 | 'ti0': { 273 | 'type': 'float', 274 | 'label': 'ti0', 275 | 'help': 'microns, working distance (immersion medium thickness) design value', 276 | 'default': 150 277 | }, 278 | 'tg0': { 279 | 'type': 'float', 280 | 'label': 'tg0', 281 | 'help': 'microns, coverslip thickness design value', 282 | 'default': 170 283 | }, 284 | 'tg': { 285 | 'type': 'float', 286 | 'label': 'tg', 287 | 'help': 'microns, coverslip thickness experimental value', 288 | 'default': 170 289 | }, 290 | 'res_lateral': { 291 | 'type': 'float', 292 | 'label': 'Lateral resolution', 293 | 'help': 'Lateral resolution in microns', 294 | 'default': 0.1 295 | }, 296 | 'res_axial': { 297 | 'type': 'float', 298 | 'label': 'Axial resolution', 299 | 'help': 'Axial resolution in microns', 300 | 'default': 0.25 301 | }, 302 | 'pZ': { 303 | 'type': 'float', 304 | 'label': 'Particle position', 305 | 'help': 'Particle distance from coverslip in microns', 306 | 'default': 0 307 | }, 308 | 'use_square': { 309 | 'type': 'bool', 310 | 'label': 'Confocal', 311 | 'help': 'Check for confocal PSF, uncheck for widefield', 312 | 'default': True 313 | } 314 | }, 315 | 'outputs': { 316 | 'image': { 317 | 'type': 'Image', 318 | 'label': 'PSF Gibson-Lanni' 319 | }, 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /sdeconv/psfs/interface.py: -------------------------------------------------------------------------------- 1 | """Interface for a 
psf generator""" 2 | import torch 3 | from sdeconv.core import SObservable 4 | 5 | 6 | class SPSFGenerator(SObservable): 7 | """Interface for a psf generator""" 8 | def __init__(self): 9 | super().__init__() 10 | self.type = 'SPSFGenerator' 11 | 12 | def __call__(self) -> torch.Tensor: 13 | """Generate the PSF 14 | 15 | return: PSF image [(Z) Y X] 16 | """ 17 | raise NotImplementedError('SPSFGenerator is an interface. Please implement the' 18 | ' __call__ method') 19 | -------------------------------------------------------------------------------- /sdeconv/psfs/lorentz.py: -------------------------------------------------------------------------------- 1 | """Implements the Lorentz Point Spread Function generator""" 2 | import math 3 | import torch 4 | from ..core import SSettings 5 | from .interface import SPSFGenerator 6 | 7 | 8 | class SPSFLorentz(SPSFGenerator): 9 | """Generate a Lorentz PSF 10 | 11 | :param gamma: Width of the Lorentz in each dimension [(Z), Y, X] 12 | :param shape: Size of the PSF image in each dimension [(Z), Y, X] 13 | """ 14 | def __init__(self, 15 | gamma: tuple[float, float] | tuple[float, float, float], 16 | shape: tuple[int, int, int] | tuple[int, int, int]): 17 | super().__init__() 18 | self.gamma = gamma 19 | self.shape = shape 20 | self.psf_ = None 21 | 22 | @staticmethod 23 | def _normalize_inputs(gamma: tuple[float, float] | tuple[float, float, float], 24 | shape: tuple[int, int, int] | tuple[int, int, int] 25 | ) -> tuple: 26 | """Remove batch dimention if it exists 27 | 28 | :param sigma: Width of the PSF 29 | :param shape: Shape of the PSF 30 | :return: The modified sigma and shape 31 | """ 32 | if len(shape) == 3 and shape[0] == 1: 33 | return gamma[1:], shape[1:] 34 | return gamma, shape 35 | 36 | def __call__(self) -> torch.Tensor: 37 | """Calculate the PSF image 38 | 39 | :return: The PSF image as a Tensor 40 | """ 41 | self.gamma, self.shape = SPSFLorentz._normalize_inputs(self.gamma, self.shape) 42 | if len(self.shape) 
== 2: 43 | self.psf_ = torch.zeros((self.shape[0], self.shape[1])).to(SSettings.instance().device) 44 | x_0 = math.floor(self.shape[0] / 2) 45 | y_0 = math.floor(self.shape[1] / 2) 46 | # print('center= (', x0, ', ', y0, ')') 47 | xx_, yy_ = torch.meshgrid(torch.arange(0, self.shape[0]), 48 | torch.arange(0, self.shape[1]), 49 | indexing='ij') 50 | self.psf_ = 1 / (1 + torch.pow((xx_ - x_0)/(0.5*self.gamma[0]), 2) + 51 | torch.pow((yy_ - y_0)/(0.5*self.gamma[1]), 2)) 52 | self.psf_ = self.psf_ / torch.sum(self.psf_) 53 | elif len(self.shape) == 3: 54 | self.psf_ = torch.zeros(self.shape).to(SSettings.instance().device) 55 | x_0 = math.floor(self.shape[2] / 2) 56 | y_0 = math.floor(self.shape[1] / 2) 57 | z_0 = math.floor(self.shape[0] / 2) 58 | zzz, yyy, xxx = torch.meshgrid(torch.arange(0, self.shape[0]), 59 | torch.arange(0, self.shape[1]), 60 | torch.arange(0, self.shape[2]), 61 | indexing='ij') 62 | self.psf_ = 1 / (1 + torch.pow((xxx - x_0)/(0.5*self.gamma[2]), 2) + 63 | torch.pow((yyy - y_0)/(0.5*self.gamma[1]), 2) + 64 | torch.pow((zzz - z_0)/(0.5*self.gamma[0]), 2)) 65 | self.psf_ = self.psf_ / torch.sum(self.psf_) 66 | else: 67 | raise ValueError('SPSFLorentz: can generate only 2D or 3D PSFs') 68 | return self.psf_ 69 | 70 | 71 | def spsf_lorentz(gamma: tuple[float, float] | tuple[float, float, float], 72 | shape: tuple[int, int, int] | tuple[int, int, int] 73 | ) -> torch.Tensor: 74 | """Function to generate a Lorentz PSF 75 | 76 | :param gamma: Width of the Lorentz in each dimension [(Z), Y, X], 77 | :param shape: Size of the PSF image in each dimension [(Z), Y, X], 78 | :return: The PSF mage 79 | """ 80 | filter_ = SPSFLorentz(gamma, shape) 81 | return filter_() 82 | 83 | 84 | metadata = { 85 | 'name': 'SPSFLorentz', 86 | 'label': 'Lorentz PSF', 87 | 'fnc': spsf_lorentz, 88 | 'inputs': { 89 | 'gamma': { 90 | 'type': 'zyx_float', 91 | 'label': 'Gamma', 92 | 'help': 'PSF width in each direction', 93 | 'default': [0, 1.5, 1.5] 94 | }, 95 | 'shape': { 96 | 
'type': 'zyx_int', 97 | 'label': 'Size', 98 | 'help': 'PSF image shape', 99 | 'default': [1, 128, 128] 100 | } 101 | }, 102 | 'outputs': { 103 | 'image': { 104 | 'type': 'Image', 105 | 'label': 'PSF Lorentz' 106 | }, 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = sdeconv 3 | version = 1.0.4 4 | author = Sylvain Prigent 5 | author_email = meriadec.prigent@gmail.com 6 | url = https://github.com/sylvainprigent/sdeconv 7 | license = BSD 3-Clause 8 | description = Implementation of 2D and 3D scientific image deconvolution 9 | long_description = file: README.md 10 | long_description_content_type = text/markdown 11 | classifiers = 12 | Development Status :: 2 - Pre-Alpha 13 | Intended Audience :: Developers 14 | Topic :: Software Development :: Testing 15 | Programming Language :: Python 16 | Programming Language :: Python :: 3 17 | Programming Language :: Python :: 3.10 18 | Programming Language :: Python :: 3.11 19 | Operating System :: OS Independent 20 | License :: OSI Approved :: BSD License 21 | 22 | [options] 23 | packages = find: 24 | python_requires = >=3.7 25 | 26 | # add your package requirements here 27 | install_requires = 28 | scipy>=1.8.1 29 | numpy>=1.22.4 30 | torch>=1.11.0 31 | torchvision>=0.12.0 32 | scikit-image>=0.19.2 33 | 34 | [options.package_data] 35 | * = */*.tif 36 | 37 | [options.entry_points] 38 | console_scripts = 39 | spsf = sdeconv.cli.spsf:main 40 | sdeconv = sdeconv.cli.sdeconv:main 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/__init__.py -------------------------------------------------------------------------------- /tests/deconv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/__init__.py -------------------------------------------------------------------------------- /tests/deconv/celegans_richardson_lucy.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_richardson_lucy.tif -------------------------------------------------------------------------------- /tests/deconv/celegans_spitfire.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_spitfire.tif -------------------------------------------------------------------------------- /tests/deconv/celegans_wiener.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/celegans_wiener.tif -------------------------------------------------------------------------------- /tests/deconv/pollen_richardson_lucy.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_richardson_lucy.tif -------------------------------------------------------------------------------- /tests/deconv/pollen_spitfire.tif: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_spitfire.tif
--------------------------------------------------------------------------------
/tests/deconv/pollen_wiener.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/deconv/pollen_wiener.tif
--------------------------------------------------------------------------------
/tests/deconv/test_richardson_lucy.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Richardson-Lucy deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 | 
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import SRichardsonLucy
8 | from sdeconv.psfs import SPSFGaussian
9 | 
10 | 
11 | def test_richardson_lucy_2d():
12 |     """Unit testing the 2D Richardson-Lucy implementation"""
13 |     root_dir = os.path.dirname(os.path.abspath(__file__))
14 |     image = celegans()
15 | 
16 |     psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
17 |     psf = psf_generator()
18 | 
19 |     filter_ = SRichardsonLucy(psf, niter=30, pad=13)
20 |     out_image = filter_(image)
21 | 
22 |     # imsave(os.path.join(root_dir, 'celegans_richardson_lucy.tif'),
23 |     #        out_image.detach().cpu().numpy())
24 |     ref_image = imread(os.path.join(root_dir, 'celegans_richardson_lucy.tif'))
25 | 
26 |     np.testing.assert_almost_equal(out_image.detach().cpu().numpy(), ref_image, decimal=1)
27 | 
28 | 
29 | def test_richardson_lucy_3d():
30 |     """Unit testing the 3D Richardson-Lucy implementation"""
31 |     root_dir = os.path.dirname(os.path.abspath(__file__))
32 |     image = pollen_poison_noise_blurred()
33 |     psf = pollen_psf()
34 | 
35 |     filter_ = SRichardsonLucy(psf, niter=30, pad=(16, 64, 64))
36 |     out_image = filter_(image)
37 | 
38 |     # imsave(os.path.join(root_dir, 'pollen_richardson_lucy.tif'),
39 |     #        out_image.detach().cpu().numpy())
40 |     ref_image = imread(os.path.join(root_dir, 'pollen_richardson_lucy.tif'))
41 | 
42 |     np.testing.assert_almost_equal(out_image[20, ...].detach().cpu().numpy(), ref_image[20, ...],
43 |                                    decimal=1)
44 | 
--------------------------------------------------------------------------------
/tests/deconv/test_spitfire.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Spitfire deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 | 
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import Spitfire
8 | from sdeconv.psfs import SPSFGaussian
9 | 
10 | 
11 | def test_spitfire_2d():
12 |     """Unit testing Spitfire 2D deconvolution"""
13 |     root_dir = os.path.dirname(os.path.abspath(__file__))
14 |     image = celegans()
15 | 
16 |     psf_generator = SPSFGaussian((1.5, 1.5), (15, 15))
17 |     psf = psf_generator()
18 | 
19 |     filter_ = Spitfire(psf, weight=0.6, reg=0.995, gradient_step=0.01, precision=1e-7, pad=13)
20 |     out_image = filter_(image)
21 | 
22 |     # imsave(os.path.join(root_dir, 'celegans_spitfire.tif'), out_image.detach().cpu().numpy())
23 |     ref_image = imread(os.path.join(root_dir, 'celegans_spitfire.tif'))
24 | 
25 |     np.testing.assert_almost_equal(out_image.detach().cpu().numpy()/10, ref_image/10, decimal=0)
26 | 
27 | 
28 | def test_spitfire_3d():
29 |     """Unit testing Spitfire 3D deconvolution"""
30 |     root_dir = os.path.dirname(os.path.abspath(__file__))
31 |     image = pollen_poison_noise_blurred()
32 |     psf = pollen_psf()
33 | 
34 |     filter_ = Spitfire(psf, weight=0.6, reg=0.99995, gradient_step=0.01, precision=1e-7)
35 |     out_image = filter_(image)
36 | 
37 |     # imsave(os.path.join(root_dir, 'pollen_spitfire.tif'), out_image.detach().cpu().numpy())
38 |     ref_image = imread(os.path.join(root_dir, 'pollen_spitfire.tif'))
39 | 
40 |     np.testing.assert_almost_equal(out_image.detach().cpu().numpy(), ref_image, decimal=0)
41 | 
--------------------------------------------------------------------------------
/tests/deconv/test_wiener.py:
--------------------------------------------------------------------------------
1 | """Unit testing the Wiener deconvolution implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 | 
6 | from sdeconv.data import celegans, pollen_poison_noise_blurred, pollen_psf
7 | from sdeconv.deconv import SWiener
8 | from sdeconv.psfs import SPSFGaussian
9 | 
10 | 
11 | def test_wiener_2d():
12 |     """Unit testing wiener 2D deconvolution"""
13 |     root_dir = os.path.dirname(os.path.abspath(__file__))
14 |     image = celegans()
15 | 
16 |     psf_generator = SPSFGaussian((1.5, 1.5), (13, 13))
17 |     psf = psf_generator()
18 | 
19 |     filter_ = SWiener(psf, beta=0.005, pad=13)
20 |     out_image = filter_(image)
21 | 
22 |     # imsave(os.path.join(root_dir, 'celegans_wiener.tif'), out_image.detach().cpu().numpy())
23 |     ref_image = imread(os.path.join(root_dir, 'celegans_wiener.tif'))
24 | 
25 |     np.testing.assert_almost_equal(out_image.detach().cpu().numpy(), ref_image, decimal=1)
26 | 
27 | 
28 | def test_wiener_3d():
29 |     """Unit testing wiener 3D deconvolution"""
30 |     root_dir = os.path.dirname(os.path.abspath(__file__))
31 |     image = pollen_poison_noise_blurred()
32 |     psf = pollen_psf()
33 | 
34 |     filter_ = SWiener(psf, beta=0.0005, pad=(16, 64, 64))
35 |     out_image = filter_(image)
36 | 
37 |     # imsave(os.path.join(root_dir, 'pollen_wiener.tif'), out_image.detach().cpu().numpy())
38 |     ref_image = imread(os.path.join(root_dir, 'pollen_wiener.tif'))
39 | 
40 |     np.testing.assert_almost_equal(out_image.detach().cpu().numpy(), ref_image, decimal=1)
41 | 
--------------------------------------------------------------------------------
/tests/psfs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/__init__.py
--------------------------------------------------------------------------------
/tests/psfs/gaussian2d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gaussian2d.tif
--------------------------------------------------------------------------------
/tests/psfs/gaussian3d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gaussian3d.tif
--------------------------------------------------------------------------------
/tests/psfs/gibsonlanni.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/gibsonlanni.tif
--------------------------------------------------------------------------------
/tests/psfs/lorentz2d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/lorentz2d.tif
--------------------------------------------------------------------------------
/tests/psfs/lorentz3d.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sylvainprigent/sdeconv/9b774221eb94fea23ae7a5ea6ad84b0b3632fce9/tests/psfs/lorentz3d.tif
--------------------------------------------------------------------------------
/tests/psfs/test_gausian.py:
--------------------------------------------------------------------------------
1 | """Unit testing for Gaussian PSFs"""
2 | 
3 | import os
4 | import numpy as np
5 | from skimage.io import imread
6 | 
7 | from sdeconv.psfs.gaussian import SPSFGaussian
8 | 
9 | 
10 | def test_psf_gaussian_2d():
11 |     """Unit testing the 2D Gaussian PSF generator"""
12 | 
13 |     psf_generator = SPSFGaussian((1.5, 1.5), (15, 15))
14 |     psf = psf_generator()
15 | 
16 |     root_dir = os.path.dirname(os.path.abspath(__file__))
17 |     # imsave(os.path.join(root_dir, 'gaussian2d.tif'), psf.detach().numpy())
18 |     ref_psf = imread(os.path.join(root_dir, 'gaussian2d.tif'))
19 | 
20 |     np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
21 | 
22 | 
23 | def test_psf_gaussian_3d():
24 |     """Unit testing the 3D Gaussian PSF generator"""
25 | 
26 |     psf_generator = SPSFGaussian((0.5, 1.5, 1.5), (11, 15, 15))
27 |     psf = psf_generator()
28 | 
29 |     root_dir = os.path.dirname(os.path.abspath(__file__))
30 |     # imsave(os.path.join(root_dir, 'gaussian3d.tif'), psf.detach().numpy())
31 |     ref_psf = imread(os.path.join(root_dir, 'gaussian3d.tif'))
32 | 
33 |     np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
34 | 
--------------------------------------------------------------------------------
/tests/psfs/test_gibsonlanni.py:
--------------------------------------------------------------------------------
1 | """Module to test the Gibson-Lanni PSF implementation"""
2 | import os
3 | import numpy as np
4 | from skimage.io import imread
5 | from sdeconv.psfs.gibson_lanni import SPSFGibsonLanni
6 | 
7 | 
8 | def test_gibson_lanni():
9 |     """Unit testing the 3D Gibson-Lanni PSF generator against a stored reference PSF"""
10 | 
11 |     shape = (18, 128, 128)
12 |     NA = 1.4
13 |     wavelength = 0.610
14 |     M = 100
15 |     ns = 1.33
16 |     ng0 = 1.5
17 |     ng = 1.5
18 |     ni0 = 1.5
19 |     ni = 1.5
20 |     ti0 = 150
21 |     tg0 = 170
22 |     tg = 170
23 |     res_lateral = 0.1
24 |     res_axial = 0.25
25 |     pZ = 0
26 |     use_square = False
27 | 
28 |     psf_generator = SPSFGibsonLanni(shape, NA, wavelength, M, ns,
29 |                                     ng0, ng, ni0, ni, ti0, tg0, tg,
30 |                                     res_lateral, res_axial, pZ, use_square)
31 |     psf = psf_generator()
32 | 
33 |     root_dir = os.path.dirname(os.path.abspath(__file__))
34 |     # imsave(os.path.join(root_dir, 'gibsonlanni.tif'), psf.detach().numpy())
35 |     ref_psf = imread(os.path.join(root_dir, 'gibsonlanni.tif'))
36 | 
37 |     np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
38 | 
--------------------------------------------------------------------------------
/tests/psfs/test_lorentz.py:
--------------------------------------------------------------------------------
1 | """Unit testing for Lorentz PSFs"""
2 | 
3 | import os
4 | import numpy as np
5 | from skimage.io import imread
6 | 
7 | from sdeconv.psfs.lorentz import SPSFLorentz
8 | 
9 | 
10 | def test_psf_lorentz_2d():
11 |     """Unit testing the 2D Lorentz PSF generator"""
12 | 
13 |     psf_generator = SPSFLorentz((1.5, 1.5), (15, 15))
14 |     psf = psf_generator()
15 | 
16 |     root_dir = os.path.dirname(os.path.abspath(__file__))
17 |     # imsave(os.path.join(root_dir, 'lorentz2d.tif'), psf.detach().numpy())
18 |     ref_psf = imread(os.path.join(root_dir, 'lorentz2d.tif'))
19 | 
20 |     np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
21 | 
22 | 
23 | def test_psf_lorentz_3d():
24 |     """Unit testing the 3D Lorentz PSF generator"""
25 | 
26 |     psf_generator = SPSFLorentz((0.5, 1.5, 1.5), (11, 15, 15))
27 |     psf = psf_generator()
28 | 
29 |     root_dir = os.path.dirname(os.path.abspath(__file__))
30 |     # imsave(os.path.join(root_dir, 'lorentz3d.tif'), psf.detach().numpy())
31 |     ref_psf = imread(os.path.join(root_dir, 'lorentz3d.tif'))
32 | 
33 |     np.testing.assert_almost_equal(psf.detach().cpu().numpy(), ref_psf, decimal=5)
34 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/
2 | [tox]
3 | envlist = py{310, 311}-{linux,macos,windows}
4 | 
5 | [gh-actions]
6 | python =
7 |     3.10: py310
8 |     3.11: py311
9 | 
10 | [gh-actions:env]
11 | PLATFORM =
12 |     ubuntu-latest: linux
13 |     macos-latest: macos
14 |     windows-latest: windows
15 | 
16 | [testenv]
17 | platform =
18 |     macos: darwin
19 |     linux: linux
20 |     windows: win32
21 | passenv =
22 |     CI
23 |     GITHUB_ACTIONS
24 |     DISPLAY
25 |     XAUTHORITY
26 |     NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
27 |     PYVISTA_OFF_SCREEN
28 | deps =
29 |     pytest # https://docs.pytest.org/en/latest/contents.html
30 |     pytest-cov # https://pytest-cov.readthedocs.io/en/latest/
31 |     pytest-xvfb ; sys_platform == 'linux'
32 |     numpy
33 |     torch
34 |     torchvision
35 |     scikit-image
36 | commands = pytest -v --color=yes --cov=sdeconv --cov-report=xml
--------------------------------------------------------------------------------