├── .gitignore ├── LICENSE ├── README.md ├── assets └── header.gif ├── deepd3 ├── __init__.py ├── core │ ├── __init__.py │ ├── analysis.py │ ├── dendrite.py │ ├── distance.py │ ├── export.py │ └── spines.py ├── inference │ ├── __init__.py │ ├── batch.py │ └── gui.py ├── model │ ├── __init__.py │ ├── builder.py │ └── utils.py └── training │ ├── __init__.py │ ├── generator.py │ └── stream.py ├── docs ├── Makefile ├── make.bat └── source │ ├── api │ ├── core.rst │ ├── inference.rst │ ├── model.rst │ └── training.rst │ ├── conf.py │ ├── index.rst │ └── userguide │ ├── inference.rst │ └── train.rst ├── examples └── Training DeepD3 model.ipynb ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | dev/ 6 | .vscode/ 7 | examples/Untitled* 8 | examples/*.h5 9 | examples/*.png 10 | examples/*.d3data 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | deepd3/model/simple_unet.py 137 | examples/Training DeepD3 model UNet.ipynb 138 | 139 | # Wandb files 140 | wandb 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DeepD3 project website](https://img.shields.io/website-up-down-green-red/https/naereen.github.io.svg)](https://deepd3.forschung.fau.de/) 2 | [![Documentation Status](https://readthedocs.org/projects/deepd3/badge/?version=latest)](https://deepd3.readthedocs.io/en/latest/?badge=latest) 3 | 4 | 5 | # DeepD3 6 | 7 | We provide DeepD3, a framework for the **d**etection of **d**endritic spines and **d**endrites. 8 | 9 | With DeepD3, you are able to 10 | 11 | * train a deep neural network for dendritic spine and dendrite segmentation 12 | * use pre-trained DeepD3 networks for inference 13 | * build 2D and 3D ROIs 14 | * export results to your favourite biomedical image analysis platform 15 | * use command-line or graphical user interfaces 16 | 17 | ## How to install and run DeepD3 18 | 19 | DeepD3 is written in Python. 
We recommend Python 3.7 or a more recent version.
Tiny (8F) DeepD3 model trained on 94 nm (fixed) or a blend of resolutions (free)
Most functions can be accessed using a batch command script located in `deepd3/inference/batch.py`.
AND Bonhoeffer, Tobias AND Kist, Andreas M.}, 111 | journal = {PLOS Computational Biology}, 112 | publisher = {Public Library of Science}, 113 | title = {DeepD3, an open framework for automated quantification of dendritic spines}, 114 | year = {2024}, 115 | month = {02}, 116 | volume = {20}, 117 | url = {https://doi.org/10.1371/journal.pcbi.1011774}, 118 | pages = {1-19}, 119 | number = {2}, 120 | } 121 | 122 | -------------------------------------------------------------------------------- /assets/header.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankilab/DeepD3/94de9d4697f00e82097c8775b924bf7ba4e624a7/assets/header.gif -------------------------------------------------------------------------------- /deepd3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankilab/DeepD3/94de9d4697f00e82097c8775b924bf7ba4e624a7/deepd3/__init__.py -------------------------------------------------------------------------------- /deepd3/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankilab/DeepD3/94de9d4697f00e82097c8775b924bf7ba4e624a7/deepd3/core/__init__.py -------------------------------------------------------------------------------- /deepd3/core/analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PyQt5.QtCore import QObject, pyqtSignal 3 | import imageio as io 4 | import flammkuchen as fl 5 | import os 6 | import cv2 7 | from skimage.measure import find_contours, moments 8 | from skimage.feature import peak_local_max 9 | from skimage.segmentation import watershed 10 | from skimage.draw import disk 11 | from tqdm import tqdm 12 | import cc3d 13 | from numba import njit 14 | from deepd3.core.dendrite import sphere 15 | from scipy.ndimage import grey_closing, 
binary_dilation, label as labelImage, distance_transform_edt 16 | 17 | @njit 18 | def centroid3D(im): 19 | """Computes centroid from a 3D binary image 20 | 21 | Args: 22 | im (numpy.ndarray): binary image 23 | 24 | Returns: 25 | tuple: centroid coordinates z,y,x 26 | """ 27 | z = 0 28 | y = 0 29 | x = 0 30 | n = 0 31 | 32 | for i in range(im.shape[0]): 33 | for j in range(im.shape[1]): 34 | for k in range(im.shape[2]): 35 | if im[i,j,k]: 36 | z += i 37 | y += j 38 | x += k 39 | n += 1 40 | 41 | if n == 0: 42 | return 0, 0, 0 43 | else: 44 | return z/n, y/n, x/n 45 | 46 | @njit(cache=True) 47 | def centroids3D_from_labels(labels): 48 | """Computes the centroid for each label in an 3D stack containing image labels. 49 | 0 is background, 1...N are foreground labels. 50 | This function uses image moments to compute the centroid. 51 | 52 | Args: 53 | labels (numpy.ndarray): ROI labeled image (0...N) 54 | 55 | Returns: 56 | tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): 57 | Returns first-order moments, zero-order moments and covered planes 58 | """ 59 | N = np.max(labels) 60 | # First-order moments for z, y, and x 61 | cs = np.zeros((N+1, 3), dtype=np.float32) 62 | # Zero-order moment (area/mass of label) 63 | px = np.zeros(N+1, dtype=np.float32) 64 | # Coverage of planes (ROIs x z-planes) 65 | planes = np.zeros((N+1, labels.shape[0]), dtype=np.int32) 66 | 67 | # Fast for-loops using numba 68 | for i in range(labels.shape[0]): 69 | for j in range(labels.shape[1]): 70 | for k in range(labels.shape[2]): 71 | l = labels[i,j,k] 72 | cs[l, 0] += i # M 1,0,0 73 | cs[l, 1] += j # M 0,1,0 74 | cs[l, 2] += k # M 0,0,1 75 | px[l] += 1 # M 0,0,0 76 | planes[l,i] = 1 # plane coverage 77 | 78 | return cs, px, planes 79 | 80 | @njit(cache=True) 81 | def minMaxProbability(labels, prediction): 82 | """Computes the minimum and maximum probabilty of a prediction map given a label map 83 | 84 | Args: 85 | labels (numpy.ndarray): labels 86 | prediction (numpy.ndarray): prediction map 
with probabilities 0 ... 1 87 | 88 | Returns: 89 | numpy.ndarray: for each label the minimum and maximum probability 90 | """ 91 | probs = np.ones((labels.max(), 2)) 92 | probs[:, 1] = 0 93 | 94 | for i in range(labels.shape[0]): 95 | for j in range(labels.shape[1]): 96 | for k in range(labels.shape[2]): 97 | l = labels[i,j,k] 98 | p = prediction[i,j,k] 99 | 100 | if probs[l, 0] > p: 101 | probs[l, 0] = p 102 | 103 | if probs[l, 1] < p: 104 | probs[l, 1] = p 105 | 106 | return probs 107 | 108 | 109 | @njit(cache=True) 110 | def cleanLabels(labels, rois_to_delete): 111 | """Cleans labels from label stack. Set labels in rois_to_delete to background. 112 | 113 | Args: 114 | labels ([type]): [description] 115 | rois_to_delete ([type]): [description] 116 | 117 | Returns: 118 | [type]: [description] 119 | """ 120 | clean_labels = np.zeros(labels.shape, dtype=np.int32) 121 | 122 | # Iterate through stack 123 | for i in range(labels.shape[0]): 124 | for j in range(labels.shape[1]): 125 | for k in range(labels.shape[2]): 126 | l = labels[i,j,k] 127 | 128 | # If it is background label, go on 129 | if l == 0: 130 | continue 131 | 132 | # Not background label, check if label should remain 133 | if (rois_to_delete==l).sum() > 0: 134 | clean_labels[i,j,k] = 0 135 | 136 | else: 137 | clean_labels[i,j,k] = l 138 | 139 | return clean_labels 140 | 141 | @njit(cache=True) 142 | def reid(labels): 143 | """Relabel an existing label map to ensure continuous label ids 144 | 145 | Args: 146 | labels (numpy.ndarray): original label map 147 | 148 | Returns: 149 | numpy.ndarray: re-computed label map 150 | """ 151 | ls = np.unique(labels) 152 | 153 | new_labels = np.zeros(labels.shape, dtype=np.int32) 154 | 155 | for i in range(labels.shape[0]): 156 | for j in range(labels.shape[1]): 157 | for k in range(labels.shape[2]): 158 | l = labels[i,j,k] 159 | 160 | if l == 0: 161 | continue 162 | 163 | else: 164 | new_labels[i,j,k] = np.argmax(ls==l) 165 | 166 | return new_labels 167 | 168 | 
@njit(cache=True)
def getROIsizes(labels):
    """Get the ROI size for each label with one single stack pass.

    Args:
        labels (numpy.ndarray): label map (0 is background, 1...N are ROIs)

    Returns:
        numpy.ndarray: voxel count per label, indexed by label id
            (index 0 holds the background count)
    """
    # BUGFIX: allocate max+1 entries so the highest label id is indexable
    # (numba does not bounds-check, so the old size silently corrupted
    # memory), and use int64 instead of uint16, which overflowed for
    # ROIs larger than 65535 voxels (e.g. long dendrite segments).
    roi_sizes = np.zeros(np.max(labels) + 1, dtype=np.int64)

    for i in range(labels.shape[0]):
        for j in range(labels.shape[1]):
            for k in range(labels.shape[2]):
                l = labels[i,j,k]
                roi_sizes[l] += 1

    return roi_sizes

##########################
# 3D connected components
##########################

# @njit(cache=True)
def _get_sorted_seeds(stack, threshold=0.8):
    """Sort seeds according to their highest prediction value.

    Args:
        stack (numpy.ndarray): The stack with the predictions
        threshold (float, optional): The threshold for being a seed pixel. Defaults to 0.8.

    Returns:
        numpy.ndarray: seed coordinates sorted by descending prediction value
    """
    coords = np.nonzero(stack >= threshold)
    intensities = stack[coords]
    # Highest peak first
    idx_maxsort = np.argsort(-intensities)
    coords = np.transpose(coords)[idx_maxsort]
    return coords

@njit(cache=True)
def _neighbours(x, y, z):
    """Generates 26-connected neighbours.

    Args:
        x (int): x-value
        y (int): y-value
        z (int): z-value

    Returns:
        list: neighbour indices of a given point (x,y,z)
    """
    look = []

    for i in range(x-1, x+2):
        for j in range(y-1, y+2):
            for k in range(z-1, z+2):
                # Exclude the point itself
                if not (i == x and j == y and k == z):
                    look.append((i,j,k))

    return look


@njit
def _distance_to_seed(seed, pos, delta_xy=1, delta_z=1):
    """Computes the euclidean distance between a seed pixel and the current position.

    Args:
        seed (tuple): seed pixel coordinates (x,y,z)
        pos (tuple): current position coordinates (x,y,z)
        delta_xy (float, optional): pixel pitch in xy. Defaults to 1.
        delta_z (float, optional): pixel pitch in z. Defaults to 1.

    Returns:
        float: euclidean distance between seed and current position
    """
    a = (seed[0] * delta_xy - pos[0] * delta_xy)**2
    b = (seed[1] * delta_xy - pos[1] * delta_xy)**2
    c = (seed[2] * delta_z  - pos[2] * delta_z)**2

    return np.sqrt(a+b+c)

@njit
def connected_components_3d(prediction, seeds, delta, threshold, distance, dimensions):
    """Computes connected components in 3D using various constraints.
    Each ROI is grown from a seed pixel. From there, in a 26-neighbour fashion more
    pixels are added iteratively. Each additional pixel needs to fulfill the following requirements:

    * The new pixel's intensity needs to be in a given range relative to the seed intensity (`delta`)
    * The new pixel's intensity needs to be above a given `threshold`
    * The new pixel's position needs to be in the vicinity (`distance`) of the seed pixel

    Each pixel can only be assigned to one ROI once.

    Args:
        prediction (numpy.ndarray): prediction from deep neural network
        seeds (numpy.ndarray): seed pixels
        delta (float): difference to seed pixel intensity
        threshold (float): threshold for pixel intensity
        distance (int or float): maximum euclidean distance in microns to seed pixel
        dimensions (tuple(float, float)): xy and z pitch in microns

    Returns:
        tuple(labels, N): the labelled stack and the number of found ROIs
    """
    # Initialize everything as background
    im = np.zeros(prediction.shape, dtype=np.uint16)
    L = 1 # Start with label 1

    delta_xy = dimensions[0]
    delta_z = dimensions[1]

    # Iterate through seed pixels
    for i in range(seeds.shape[0]):
        # Retrieve location and seed intensity
        x0, y0, z0 = seeds[i]
        t = prediction[x0, y0, z0]

        # Seed pixel has been assigned to a label already, skip
        if im[x0, y0, z0]:
            continue

        # Start with the floodfilling
        neighbours = [(x0, y0, z0)]

        while len(neighbours):
            # Look at next pixel
            x, y, z = neighbours.pop()

            # Current pixel not in stack
            if x >= im.shape[0] or x < 0:
                continue

            if y >= im.shape[1] or y < 0:
                continue

            if z >= im.shape[2] or z < 0:
                continue

            # Intensity at given point
            p0 = prediction[x, y, z]

            # A good pixel should be
            # - similar to the seed px (delta)
            # - intensity above a given threshold
            # - in label image it is still a background px
            # - distance to seed is lower than distance
            if abs(p0 - t) <= delta * t and \
                p0 > threshold and \
                im[x, y, z] == 0 and \
                _distance_to_seed((x0, y0, z0), (x,y,z), delta_xy, delta_z) < distance:

                # Assign pixel current label
                im[x,y,z] = L

                # Look at neighbours
                neighbours.extend(_neighbours(x,y,z))

        # Finished with this label
        L += 1

    return im, L-1


class Stack(QObject):
    # Emitted after each processed tile/plane: (current index, total)
    tileSignal = pyqtSignal(int, int)

    def __init__(self, fn, pred_fn=None, dimensions=None):
        """Stack of images together with its (optional) pre-computed prediction.

        Args:
            fn (str): Path to file to be opened via imageio
            pred_fn (str, optional): Path to prediction file. Defaults to None.
            dimensions (dict, optional): XY and Z pitch in microns.
                Defaults to None, which means dict(xy=0.094, z=0.5).
        """
        super().__init__()

        # BUGFIX: avoid a shared mutable default argument
        if dimensions is None:
            dimensions = dict(xy=0.094, z=0.5)

        self.stack = np.asarray(io.mimread(fn, memtest=False))

        # If stack is only an image, create dummy dimension
        if len(self.stack.shape) == 2:
            self.stack = self.stack[None]

        # Prediction and preview should be pre-allocated
        self.prediction = np.zeros(self.stack.shape+(3,), dtype=np.float32)
        self.preview = np.zeros(self.stack.shape+(3,), dtype=np.float32)
        self.segmented = False
        self.dimensions = dimensions

        # If predictions are already existing
        if pred_fn is not None:
            if os.path.exists(pred_fn):
                pred = fl.load(pred_fn)

                # Older files use the singular key 'dendrite'
                dendrite_key = 'dendrites' if 'dendrites' in pred.keys() else 'dendrite'

                # Reload and assign dendrite and spine keys
                self.prediction[..., 0] = pred[dendrite_key]
                self.prediction[..., 1] = pred['spines']
                self.prediction[..., 2] = pred[dendrite_key]
                self.segmented = True

    def __getitem__(self, sl):
        """Get item slice.

        Args:
            sl (int): stack index

        Returns:
            numpy.ndarray: image plane(s)
        """
        return self.stack[sl]

    def cleanSpines(self, dendrite_threshold=0.7, dendrite_dilation_iterations=12, preview=False):
        """Cleaning spines in 2D: keeps only spine probability close to the dendrite.

        Args:
            dendrite_threshold (float, optional): Dendrite threshold for segmentation. Defaults to 0.7.
            dendrite_dilation_iterations (int, optional): Iterations to enlarge dendrite. Defaults to 12.
            preview (bool, optional): Enable preview option (not overwriting predictions). Defaults to False.

        Returns:
            numpy.ndarray: cleaned spines stack
        """
        if preview:
            d = self.preview[..., 0].copy()
            s = self.preview[..., 1].copy()
        else:
            d = self.prediction[..., 0].copy()
            s = self.prediction[..., 1].copy()

        # Enlarge the thresholded dendrite mask ...
        bd = binary_dilation(np.asarray(d) > dendrite_threshold,
                             iterations=dendrite_dilation_iterations)

        # ... and zero all spine probability outside of it
        s[~bd] = 0

        return s

386 | preview (bool, optional): Enable preview option (not overwriting predictions). Defaults to False. 387 | 388 | Returns: 389 | numpy.ndarray: cleaned spines stack 390 | """ 391 | if preview: 392 | d = self.preview[..., 0].copy() 393 | s = self.preview[..., 1].copy() 394 | else: 395 | d = self.prediction[..., 0].copy() 396 | s = self.prediction[..., 1].copy() 397 | 398 | bd = binary_dilation(np.asarray(d) > dendrite_threshold, 399 | iterations=dendrite_dilation_iterations) 400 | 401 | s[~bd] = 0 402 | 403 | return s 404 | 405 | def cleanDendrite3D(self, dendrite_threshold=0.7, min_dendrite_size=100, preview=False): 406 | """Cleaning dendrites in 3D 407 | 408 | Args: 409 | dendrite_threshold (float, optional): Dendrite semantic segmentation threshold. Defaults to 0.7. 410 | min_dendrite_size (int, optional): Minimum dendrite size in px in 3D. Defaults to 100. 411 | preview (bool, optional): Enable preview option. Defaults to False. 412 | 413 | Returns: 414 | numpy.ndarray: Cleaned dendrite 415 | """ 416 | 417 | if preview: 418 | d = self.preview[..., 0].copy() 419 | 420 | else: 421 | d = self.prediction[..., 0].copy() 422 | 423 | # Clean noisy px 424 | d[d < dendrite_threshold] = 0 425 | 426 | # Create labels for all dendritic elements 427 | labels, N = cc3d.connected_components(d > dendrite_threshold, return_N=True) 428 | 429 | # Compute ROI sizes 430 | roi_sizes = getROIsizes(labels) 431 | 432 | # Remove dendrite segments 433 | new_labels = cleanLabels(labels, np.where(roi_sizes < min_dendrite_size)[0]) 434 | 435 | # clean data 436 | d_clean = d * (new_labels > 0).astype(np.float32) 437 | 438 | return d_clean 439 | 440 | def closing(self, iterations=1, preview=False): 441 | """Closing operation on dendrite prediction map 442 | 443 | Args: 444 | iterations (int, optional): Iterations of closing operation. Defaults to 1. 445 | preview (bool, optional): Enables preview mode. Defaults to False. 
446 | 447 | Returns: 448 | numpy.ndarray: cleaned dendrite map 449 | """ 450 | if preview: 451 | d = self.preview[..., 0].copy() 452 | 453 | else: 454 | d = self.prediction[..., 0].copy() 455 | 456 | d_clean = d 457 | 458 | for _ in range(iterations): 459 | d_clean = grey_closing(d_clean, size=(3,3,3)) 460 | 461 | return d_clean 462 | 463 | def cleanDendrite(self, dendrite_threshold=0.7, min_dendrite_size=100): 464 | """Cleaning dendrite 465 | 466 | Args: 467 | dendrite_threshold (float, optional): Dendrite probability threshold. Defaults to 0.7. 468 | min_dendrite_size (int, optional): Minimum dendrite size. Defaults to 100. 469 | 470 | Returns: 471 | numpy.ndarray: Cleaned dendrite prediction map 472 | """ 473 | clean = np.zeros_like(self.prediction[..., 0]) 474 | 475 | # Iterate across planes... 476 | for z in tqdm(range(self.prediction.shape[0])): 477 | # Retrieve dendrite prediction 478 | d = self.prediction[z, ..., 0] .copy() 479 | 480 | # Threshold dendrite 481 | d_thresholded = (d > dendrite_threshold).astype(np.uint8) * 255 482 | d_clean = np.zeros_like(d) 483 | 484 | # Find elements 485 | no, labels, _, _ = cv2.connectedComponentsWithStats(d_thresholded) 486 | 487 | for l in range(1, no): 488 | if (labels==l).sum() > min_dendrite_size: 489 | d_clean[labels==l] = d[labels==l] 490 | 491 | clean[z] = d_clean 492 | 493 | self.tileSignal.emit(z, self.prediction.shape[0]) 494 | 495 | return clean 496 | 497 | def predictInset(self, model_fn, tile_size=128, inset_size=96, pad_op=np.mean, zmin=None, zmax=None, clean_dendrite=True, dendrite_threshold=0.7): 498 | """Predict inset 499 | 500 | Args: 501 | model_fn (str): path to Tensorflow/Keras model 502 | tile_size (int, optional): Size of full tile. Defaults to 128. 503 | inset_size (int, optional): Size of tile inset (probability map to be kept). Defaults to 96. 504 | pad_op (_type_, optional): Padding operation. Defaults to np.mean. 505 | zmin (_type_, optional): Z-index minimum. Defaults to None. 
506 | zmax (_type_, optional): Z-index maxmimum. Defaults to None. 507 | clean_dendrite (bool, optional): Cleaning dendrite. Defaults to True. 508 | dendrite_threshold (float, optional): Dendrite probability threshold. Defaults to 0.7. 509 | 510 | Returns: 511 | bool: operation was successful 512 | """ 513 | from tensorflow.keras.models import load_model 514 | model = load_model(model_fn, compile=False) 515 | 516 | # Check for z-range, if nothing is provided, assume the whole stack 517 | zmin = zmin if zmin else 0 518 | zmax = zmax if zmax else self.stack.shape[0] 519 | 520 | # Compute rim offset between tile_size and inset_size 521 | off = (tile_size-inset_size)//2 522 | 523 | # Compute the image size that is needed 524 | h = int(np.ceil(self.stack.shape[1] / inset_size) * inset_size) 525 | w = int(np.ceil(self.stack.shape[2] / inset_size) * inset_size) 526 | 527 | # Values for padding 528 | cv = pad_op(self.stack) 529 | 530 | # Padding 531 | stack_zp = np.pad(self.stack, # Pad the stack 532 | ((0, 0), # no pad in z 533 | (off, h-self.stack.shape[1]+off), # pad in y 534 | (off, w-self.stack.shape[2]+off)), # pad in x 535 | constant_values=cv) 536 | 537 | predictions = np.zeros(stack_zp.shape + (3,), dtype=np.float32) 538 | 539 | steps_y = np.arange(0, h, inset_size).astype(np.int) 540 | steps_x = np.arange(0, w, inset_size).astype(np.int) 541 | 542 | # Iterate over tiles, y 543 | for i in tqdm(steps_y): 544 | # Iterate over tiles, x 545 | for j in steps_x: 546 | # Take stack tile (column in z) at the respective tile position 547 | tile = stack_zp[zmin:zmax, i:i+tile_size, j:j+tile_size] #.copy() 548 | # Prepare tile for network inference 549 | tile = (tile.astype(np.float32)-tile.min()) / (tile.max()-tile.min()) * 2 - 1 550 | 551 | # Predict dendrite (pd) and spines (ps), add 1 pseudo-ch 552 | pd, ps = model.predict(tile[..., None]) 553 | 554 | # Save inset at tile position in prediction stack across z 555 | if tile_size != inset_size: 556 | predictions[zmin:zmax, 
i+off:i+tile_size-off, j+off:j+tile_size-off, 0] = pd.squeeze()[:, off:-off, off:-off] 557 | predictions[zmin:zmax, i+off:i+tile_size-off, j+off:j+tile_size-off, 1] = ps.squeeze()[:, off:-off, off:-off] 558 | predictions[zmin:zmax, i+off:i+tile_size-off, j+off:j+tile_size-off, 2] = pd.squeeze()[:, off:-off, off:-off] 559 | 560 | else: 561 | predictions[zmin:zmax, i:i+tile_size, j:j+tile_size, 0] = pd.squeeze() 562 | predictions[zmin:zmax, i:i+tile_size, j:j+tile_size, 1] = ps.squeeze() 563 | predictions[zmin:zmax, i:i+tile_size, j:j+tile_size, 2] = pd.squeeze() 564 | 565 | self.tileSignal.emit(np.argmax(steps_y==i), steps_y.size) 566 | 567 | self.prediction = predictions[:, off:off+self.stack.shape[1], off:off+self.stack.shape[2]] 568 | 569 | self.segmented = True 570 | return True 571 | 572 | def predictWholeImage(self, model_fn): 573 | """Predict whole image, plane by plane 574 | 575 | Args: 576 | model_fn (str): path to Tensorflow/Keras model file 577 | 578 | Returns: 579 | bool: operation was successful 580 | """ 581 | from tensorflow.keras.models import load_model 582 | 583 | model = load_model(model_fn, compile=False) 584 | 585 | # Iterate over tiles, y 586 | for z in tqdm(range(self.stack.shape[0])): 587 | # Iterate over tiles, x 588 | plane = self.stack[z].copy() 589 | plane = (plane - plane.min()) / (plane.max()-plane.min()) * 2 - 1 590 | 591 | h, w = plane.shape[0], plane.shape[1] 592 | 593 | # if height or widht is not divisible by 32 (issue with neural networks) 594 | if h % 32 or w % 32: 595 | plane = np.pad(plane, # Pad the plane 596 | ((0, 32-h%32 if h%32 else 0), # pad in y 597 | (0, 32-w%32 if w%32 else 0)), # pad in x 598 | mode='reflect') # already normalized 599 | 600 | # Predict dendrite (pd) and spines (ps), add 1 pseudo-ch 601 | pd, ps = model.predict(plane[None, ..., None]) 602 | 603 | # Dendrite and Spine prediction, crop back 604 | d = pd.squeeze()[:h, :w] 605 | s = ps.squeeze()[:h, :w] 606 | 607 | self.prediction[z, ..., 0] = d 608 | 
self.prediction[z, ..., 1] = s 609 | self.prediction[z, ..., 2] = d 610 | 611 | self.tileSignal.emit(z, self.stack.shape[0]) 612 | 613 | self.segmented = True 614 | return True 615 | 616 | def predictFourFold(self, model_fn, tile_size=128, inset_size=96, pad_op=np.mean, zmin=None, zmax=None): 617 | """Similar to `predictInset` (single tile prediction), but with four-way correction 618 | 619 | Args: 620 | model_fn (str): path to Tensorflow/Keras model 621 | tile_size (int, optional): Size of full tile. Defaults to 128. 622 | inset_size (int, optional): Size of tile inset (probability map to be kept). Defaults to 96. 623 | pad_op (_type_, optional): Padding operation. Defaults to np.mean. 624 | zmin (_type_, optional): Z-index minimum. Defaults to None. 625 | zmax (_type_, optional): Z-index maxmimum. Defaults to None. 626 | 627 | Returns: 628 | bool: operation was successful 629 | """ 630 | 631 | from tensorflow.keras.models import load_model 632 | model = load_model(model_fn, compile=False) 633 | 634 | # Check for z-range, if nothing is provided, assume the whole stack 635 | zmin = zmin if zmin else 0 636 | zmax = zmax if zmax else self.stack.shape[0] 637 | 638 | # Create empty array for z-range and 3 channels (could be two, but then color is easy going) 639 | off = (tile_size-inset_size)//2 640 | 641 | # Zero pad image to ensure that full stack is analyzed 642 | cv = pad_op(self.stack) 643 | stack_zp = np.pad(self.stack, ((0, 0), (tile_size, tile_size), (tile_size, tile_size)), constant_values=cv) 644 | 645 | # all four predictions to be stored 646 | predictions = np.zeros((4,)+stack_zp.shape + (3,), dtype=np.float32) 647 | 648 | # Predict stack 4 times with different offsets to ensure 649 | # that the prediction is properly done at the edges 650 | for it, (start_y, start_x) in enumerate([(0, 0), (tile_size//2, 0), (0, tile_size//2), (tile_size//2, tile_size//2)]): 651 | steps_y = np.arange(start_y, self.stack.shape[1], inset_size).astype(np.int) 652 | steps_x = 
np.arange(start_x, self.stack.shape[2], inset_size).astype(np.int) 653 | 654 | # Iterate over tiles, y 655 | for i in tqdm(steps_y): 656 | # Iterate over tiles, x 657 | for j in steps_x: 658 | # Take stack tile (column in z) at the respective tile position 659 | tile = stack_zp[zmin:zmax, i:i+tile_size, j:j+tile_size] #.copy() 660 | # Prepare tile for network inference 661 | tile = (tile.astype(np.float32)-tile.min()) / (tile.max()-tile.min()) * 2 - 1 662 | 663 | # Predict dendrite (pd) and spines (ps), add 1 pseudo-ch 664 | pd, ps = model.predict(tile[..., None]) 665 | 666 | # Save inset at tile position in prediction stack across z 667 | if tile_size != inset_size: 668 | predictions[it, zmin:zmax, i+off:i+tile_size-off, j+off:j+tile_size-off, 0] = pd.squeeze()[:, off:-off, off:-off] 669 | predictions[it, zmin:zmax, i+off:i+tile_size-off, j+off:j+tile_size-off, 1] = ps.squeeze()[:, off:-off, off:-off] 670 | predictions[it, zmin:zmax, i+off:i+tile_size-off, j+off:j+tile_size-off, 2] = pd.squeeze()[:, off:-off, off:-off] 671 | 672 | else: 673 | predictions[it, zmin:zmax, i:i+tile_size, j:j+tile_size, 0] = pd.squeeze() 674 | predictions[it, zmin:zmax, i:i+tile_size, j:j+tile_size, 1] = ps.squeeze() 675 | predictions[it, zmin:zmax, i:i+tile_size, j:j+tile_size, 2] = pd.squeeze() 676 | 677 | self.tileSignal.emit(np.argmax(steps_y==i), steps_y.size) 678 | 679 | self.prediction = predictions.mean(0)[:, tile_size:-tile_size, tile_size:-tile_size] 680 | 681 | self.segmented = True 682 | return True 683 | 684 | 685 | class ROI3D_Creator(QObject): 686 | zSignal = pyqtSignal(int, int) 687 | log = pyqtSignal(str) 688 | 689 | def __init__(self, dendrite_prediction, spine_prediction, mode='floodfill', areaThreshold=0.25, 690 | peakThreshold=0.8, seedDelta=0.1, distanceToSeed=10, dimensions=dict(xy=0.094, z=0.5)): 691 | """3D ROI Creator. 692 | 693 | Given the arguments, 3D ROIs are built dynamically from dendrite and spine prediction. 
694 | 695 | Args: 696 | dendrite_prediction (numpy.ndarray): dendrite prediction probability stack 697 | spine_prediction (numpy.ndarray): spine prediction probability stack 698 | mode (str, optional): Mode for building 3D rois (floodfill or connected components). Defaults to 'floodfill'. 699 | areaThreshold (float, optional): Area threshold for floodfilling and connected components. Defaults to 0.25. 700 | peakThreshold (float, optional): Peak threshold for finding seed points. Defaults to 0.8. 701 | seedDelta (float, optional): Difference to seed in terms of relative probability. Defaults to 0.1. 702 | distanceToSeed (int, optional): Distance to seed px in micrometer. Defaults to 10. 703 | dimensions (dict, optional): Dimensions in xy and z in in micrometer. Defaults to dict(xy=0.094, z=0.5). 704 | """ 705 | 706 | super().__init__() 707 | self.dendrite_prediction = dendrite_prediction 708 | self.spine_prediction = spine_prediction 709 | self.mode = mode # floodfill or thresholded 710 | self.areaThreshold = areaThreshold 711 | self.peakThreshold = peakThreshold 712 | self.seedDelta = seedDelta 713 | self.distanceToSeed = distanceToSeed 714 | self.dimensions = dimensions 715 | self.roi_map = np.zeros_like(spine_prediction, dtype=np.int32) 716 | self.rois = {} 717 | 718 | self.computeContours = True 719 | 720 | def create(self, minPx, maxPx, minPlanes, applyWatershed=False, dilation=0): 721 | """Create 3D ROIs 722 | 723 | Args: 724 | minPx (int): only retain 3D ROIs containing at least `minPx` pixels 725 | maxPx (int): only retain 3D ROIs containing at most `maxPx` pixels 726 | minPlanes (int): only retain 3D ROIs spanning at least `minPlanes` planes 727 | applyWatershed (bool, optional): Apply watershed algorithm to divide ROIs. Defaults to False. 728 | dilation (int, optional): Dilate dendrite probability map. Defaults to 0. 
729 | 730 | Returns: 731 | int: number of retained ROIs 732 | """ 733 | ROI_id = 0 734 | 735 | # Find raw labels 736 | self.log.emit("Create labels") 737 | 738 | if self.mode == 'floodfill': 739 | # Find all potential seed pixels that are at least at peakThreshold 740 | seeds = _get_sorted_seeds(self.spine_prediction, self.peakThreshold) 741 | 742 | # Generate all labels using custom 3D connected components 743 | labels, N = connected_components_3d(self.spine_prediction, 744 | seeds, 745 | self.seedDelta, 746 | self.areaThreshold, 747 | self.distanceToSeed, 748 | (self.dimensions['xy'], self.dimensions['z'])) 749 | 750 | else: 751 | thresholded_im = self.spine_prediction > self.areaThreshold 752 | 753 | if dilation > 0: 754 | thresholded_im = binary_dilation(thresholded_im, np.ones((3,3,3)), iterations=dilation) 755 | 756 | labels, N = cc3d.connected_components(thresholded_im, return_N=True) 757 | 758 | if applyWatershed: 759 | D = distance_transform_edt(labels > 0) 760 | # Seed generation 761 | localMax = peak_local_max(D, indices=False, min_distance=0, footprint=np.ones((3,3,3)), exclude_border=1) 762 | 763 | markers = labelImage(localMax, structure=np.ones((3,3,3)))[0] 764 | 765 | ### data # seeds # eliminate noise 766 | labels = watershed(-D, markers, mask=labels > 0) 767 | 768 | self.log.emit("Compute meta data") 769 | # Compute raw centroids, check size and plane span of ROIs 770 | cs, px, planes = centroids3D_from_labels(labels) 771 | planes = planes.sum(1) 772 | 773 | 774 | self.log.emit("Clean ROIs") 775 | # Find ROIs that do not match the criteria 776 | criteria_mismatch = (planes < minPlanes) | (px < minPx) | (px > maxPx) 777 | rois_to_delete = np.where(criteria_mismatch)[0] 778 | 779 | self.log.emit(f"Removing {len(rois_to_delete)} ROIs...") 780 | 781 | self.log.emit("Clean labels") 782 | # Clean all labels 783 | labels = cleanLabels(labels, rois_to_delete) 784 | labels = reid(labels) 785 | 786 | # Re-compute information 787 | cs, px, planes = 
centroids3D_from_labels(labels) 788 | planes = planes.sum(1) 789 | 790 | # Compute the centroids again after cleaning the labels 791 | centroids = cs/px[:, None] 792 | 793 | self.roi_centroids = centroids 794 | 795 | if self.computeContours: 796 | 797 | self.log.emit("Compute contours and create plane-wise slices") 798 | for z in tqdm(range(self.spine_prediction.shape[0])): 799 | self.rois[z] = [] 800 | labels_plane = labels[z] 801 | 802 | # Compute contours... 803 | for ROI_id in np.unique(labels_plane): 804 | if ROI_id == 0: 805 | continue 806 | 807 | c = find_contours(labels_plane==ROI_id) 808 | self.rois[z].append({ 809 | 'ROI_id': ROI_id, 810 | 'contour': c[0], 811 | 'centroid': np.asarray(centroids[ROI_id]) 812 | }) 813 | 814 | # Talk to progress bar 815 | self.zSignal.emit(z, self.spine_prediction.shape[0]) 816 | 817 | self.roi_map = labels 818 | self.log.emit("Done.") 819 | return np.max(labels) 820 | 821 | 822 | class ROI2D_Creator(QObject): 823 | zSignal = pyqtSignal(int, int) 824 | 825 | def __init__(self, dendrite_prediction, spine_prediction, threshold): 826 | """2D ROI Creator. 827 | 828 | Creates 2D ROIs dependent on dendrite and spine prediction, as well as threshold 829 | 830 | Args: 831 | dendrite_prediction (_type_): _description_ 832 | spine_prediction (_type_): _description_ 833 | threshold (_type_): _description_ 834 | """ 835 | super().__init__() 836 | self.dendrite_prediction = dendrite_prediction 837 | self.spine_prediction = spine_prediction 838 | self.threshold = threshold 839 | self.roi_map = np.zeros_like(spine_prediction, dtype=np.int32) 840 | self.rois = {} 841 | 842 | def create(self, applyWatershed=False, maskSize=3, minDistance=3): 843 | """Creates 2D ROIs 844 | 845 | Args: 846 | applyWatershed (bool, optional): Apply Watershed algorithm. Defaults to False. 847 | maskSize (int, optional): Size of the distance transform mask. Defaults to 3. 848 | minDistance (int, optional): Minimum distance between ROIs in Watershed algorithm. 
Defaults to 3. 849 | 850 | Returns: 851 | int: ROIs found 852 | """ 853 | ROIs_found = 0 854 | ROI_id = 0 855 | 856 | for z in tqdm(range(self.spine_prediction.shape[0])): 857 | self.rois[z] = [] 858 | 859 | im = (self.spine_prediction[z] > self.threshold).astype(np.uint8) * 255 860 | 861 | if applyWatershed: 862 | D = cv2.distanceTransform(im, cv2.DIST_L2, maskSize) 863 | Ma = peak_local_max(D, indices=False, footprint=np.ones((3,3)), min_distance=minDistance, labels=im) 864 | foreground_labels = cv2.connectedComponents(Ma.astype(np.uint8)*255)[1] 865 | 866 | labels = watershed(-D, foreground_labels, mask=im) 867 | 868 | centroids = [] 869 | 870 | uq = np.unique(labels) 871 | no = len(uq) 872 | 873 | for ix in uq: 874 | M = moments(labels==ix) 875 | cy, cx = M[1,0]/M[0,0], M[0,1]/M[0,0] 876 | centroids.append([cx,cy]) 877 | 878 | self.roi_map[z] = labels 879 | 880 | else: 881 | # find individual ROIs 882 | no, labels, stats, centroids = cv2.connectedComponentsWithStats(im) 883 | self.roi_map[z] = labels 884 | 885 | # Compute contours... 886 | for roi in range(1, no): 887 | c = find_contours(labels==roi) 888 | self.rois[z].append({ 889 | 'ROI_id': ROI_id, 890 | 'contour': c[0], 891 | 'centroid': np.asarray(centroids[roi]) 892 | }) 893 | ROI_id += 1 894 | 895 | # Talk to progress bar 896 | self.zSignal.emit(z, self.spine_prediction.shape[0]) 897 | 898 | ROIs_found += no-1 899 | 900 | return ROIs_found 901 | 902 | def clean(self, maxD, minS, dendrite_threshold=0.7): 903 | """Cleanes ROIs 904 | 905 | Args: 906 | maxD (int): maximum distance to dendrite in px 907 | minS (int): minimum size of ROIs in px 908 | dendrite_threshold (float, optional): _description_. Defaults to 0.7. 
909 | 910 | Returns: 911 | tuple: old ROI count, new ROI count 912 | """ 913 | clean_rois = {} 914 | clean_roi_map = np.zeros_like(self.roi_map) 915 | 916 | old_rois_count = 0 917 | new_rois_count = 0 918 | 919 | ### Iterate over z 920 | for z, rois in tqdm(self.rois.items()): 921 | clean_rois[z] = [] 922 | ### Iterate over ROIs, check if (centroid +- maxD).sum() > 0 923 | for i, roi in enumerate(rois): 924 | old_rois_count += 1 925 | cur_clean_roi = 0 926 | 927 | # Get ROI coordinates 928 | x, y = roi['centroid'] 929 | y = int(y) 930 | x = int(x) 931 | z = int(z) 932 | maxD = int(maxD) 933 | 934 | # Retrieve circular area around the ROI center of mass 935 | rr, cc = disk((y,x), maxD, shape=self.dendrite_prediction.shape[1:3]) 936 | area = self.dendrite_prediction[z, rr, cc] 937 | 938 | # Test if any dendritic pixel are inside of this circle 939 | # and if the ROI size exceeds a given threshold 940 | dendrite_proximity = (area > dendrite_threshold).sum() 941 | roi_size = (self.roi_map[z] == i+1).sum() 942 | 943 | if dendrite_proximity and roi_size >= minS: 944 | new_rois_count += 1 945 | cur_clean_roi += 1 946 | clean_rois[z].append(roi) 947 | clean_roi_map[z][self.roi_map[z] == i+1] = cur_clean_roi 948 | 949 | # Talk to progress bar 950 | self.zSignal.emit(int(z), len(self.rois.keys())) 951 | 952 | self.rois = clean_rois 953 | self.roi_map = clean_roi_map 954 | 955 | return old_rois_count, new_rois_count 956 | 957 | 958 | 959 | 960 | -------------------------------------------------------------------------------- /deepd3/core/dendrite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import imageio as io 4 | from tqdm import tqdm 5 | import pathlib 6 | import numpy as np 7 | from skimage.draw import disk 8 | from PyQt5.QtCore import QObject, pyqtSignal 9 | 10 | 11 | def line_w_sphere(s, p0, p1, r0, r1, color=1, spacing=[1, 1, 1]): 12 | """Draw a line with width in 3D space 13 | 
def sphere(s, p0, d, spacing=(1, 1, 1), color=255, debug=False):
    """Draw a 3D sphere with given diameter d at point p0 in given color.

    Args:
        s (numpy.ndarray): numpy 3D stack
        p0 (tuple): x, y, z tuple
        d (float): diameter in 1 spacing unit
        spacing (tuple, optional): x, y, z spacing; x and y spacing must be equal. Defaults to (1, 1, 1).
        color (int, optional): Draw color, e.g. 255 for np.uint8 stack. Defaults to 255.
        debug (bool, optional): if True prints plane related information. Defaults to False.
    """
    assert spacing[0] == spacing[1], "x and y spacing must be the same!"

    # Convert to pixels
    d_xy = d / spacing[0]
    r = d_xy / 2

    # Initialize center
    x, y, z = p0

    # Iterate over planes where the sphere is visible
    for plane in range(z - int(d / spacing[2] / 2), z + int(d / spacing[2] / 2) + 1):
        # Circle radius of the sphere's cross-section in this plane (px)
        radius = np.sqrt((d / 2) ** 2 - ((plane - z) * spacing[2]) ** 2) / spacing[0]

        if debug:
            print(plane, (plane - z) * spacing[2], "µm to center")
            print(radius * spacing[0], "µm, ", np.round(radius, 3), "px\n")

        # If sphere is to be drawn
        if radius > 0:
            # Draw a circle on a diameter x diameter grid w/ given radius
            rr, cc = disk((d_xy//2, d_xy//2), radius, shape=(d_xy, d_xy))

            # Skip planes outside the stack
            if plane < 0 or plane >= s.shape[0]:
                continue

            # Skip circles that would reach beyond the stack borders
            if (rr + x - d_xy//2).max() >= s.shape[1] or (cc + y - d_xy//2).max() >= s.shape[2]:
                continue

            # Go to plane in stack and move circle to right position, acts in-place
            # different to previous shift with circle somehow...
            s[plane, (rr + x - d_xy//2).astype(int), (cc + y - d_xy//2).astype(int)] = color

def xyzr(swc, i):
    """Returns xyz coordinates and radius as tuple from swc pandas DataFrame and loc i,
    actually it is y, x and z.

    Args:
        swc (pandas.DataFrame): List of traced coordinates
        i (int): current location

    Returns:
        tuple: (y, x, z) coordinates and r, all as integers
    """
    return (int(swc.loc[i].y), int(swc.loc[i].x), int(swc.loc[i].z)), int(swc.loc[i].r)

class DendriteSWC(QObject):
    # Emitted per converted node: (current node, total nodes)
    node = pyqtSignal(int, int)

    def __init__(self, spacing=(1, 1, 1)):
        """Converting a neuron trace saved as swc file back to a clean tif stack.

        Args:
            spacing (tuple, optional): Spacing in 3D (x, y, z). Defaults to (1, 1, 1).
        """
        super().__init__()

        # swc trace table and reference stack, populated by open()
        self.swc = None
        self.ref = None
        self.spacing = spacing

    def open(self, swc_fn, ref_fn):
        """Open and read the swc and the stack file.

        Args:
            swc_fn (str): The file path to the swc file
            ref_fn (str): The file path to the stack file
        """

        print('Check for comments in SWC file...')

        skiprows = 0

        # Count leading comment lines so read_csv can skip them
        with open(swc_fn) as fp:
            while True:
                line = fp.readline()
                if line.startswith("#"):
                    skiprows += 1
                    print(line.strip())

                else:
                    break
        print(f' --> will skip {skiprows} rows.')

        print('Load SWC file...')

        self.swc = pd.read_csv(swc_fn,
            sep=' ',
            header=None,
            skiprows=skiprows,
            index_col=0,
            names=('idx','kind','x','y','z','r','parent'))

        print('Load ref file...')
        self.ref = np.asarray(io.mimread(ref_fn, memtest=False))

        print(self.ref.shape)

        self.ref_fn = ref_fn
        self.swc_fn = swc_fn


    def convert(self, target_fn=None):
        """Convert swc file to tif stack.

        Args:
            target_fn (str, optional): Target path. Defaults to None,
                in which case the output names are derived from the reference stack path.

        Returns:
            str: save path of the written tif stack
        """
        print('Create stack...')
        self.stack = np.zeros(self.ref.shape, dtype=np.uint8)

        print('Binarize SWC...')
        self._binarize_swc_w_spheres()

        if target_fn is None:
            # Derive the output names from the reference file
            target = self.ref_fn
            ext = pathlib.Path(target).suffix

            # Strip the extension off the end instead of str.replace, which
            # would also hit an identical substring earlier in the path
            stem = target[:-len(ext)] if ext else target
            save_fn = stem + "_dendrite.tif"
            save_fn_max = stem + "_dendrite_max.png"

        else:
            ext = pathlib.Path(target_fn).suffix

            stem = target_fn[:-len(ext)] if ext else target_fn
            save_fn = target_fn
            save_fn_max = stem + "_max.png"

        print('Saving stack and maximum intensity projection')
        io.mimwrite(save_fn, self.stack)
        io.imwrite(save_fn_max, self.stack.max(0))

        print(f'Data saved as {save_fn} \n and \n {save_fn_max}')
        print()

        return save_fn

197 | 198 | Returns: 199 | return: save path 200 | """ 201 | print('Create stack...') 202 | self.stack = np.zeros(self.ref.shape, dtype=np.uint8) 203 | 204 | print('Binarize SWC...') 205 | self._binarize_swc_w_spheres() 206 | 207 | if target_fn is None: 208 | target = self.ref_fn if target_fn is None else target_fn 209 | path_ref = pathlib.Path(target) 210 | ext = path_ref.suffix 211 | 212 | save_fn = target.replace(ext, f"_dendrite.tif") 213 | save_fn_max = target.replace(ext, f"_dendrite_max.png") 214 | 215 | else: 216 | path_ref = pathlib.Path(target_fn) 217 | ext = path_ref.suffix 218 | 219 | save_fn = target_fn 220 | save_fn_max = target_fn.replace(ext, f"_max.png") 221 | 222 | print('Saving stack and maximum intensity projection') 223 | io.mimwrite(save_fn, self.stack) 224 | io.imwrite(save_fn_max, self.stack.max(0)) 225 | 226 | print(f'Data saved as {save_fn} \n and \n {save_fn_max}') 227 | print() 228 | 229 | return save_fn 230 | 231 | def _binarize_swc_w_spheres(self): 232 | '''Binarizes SWC file in a given 3D stack with spheres''' 233 | 234 | for i in tqdm(range(1, self.swc.shape[0])): 235 | if self.swc.loc[i].parent > 0: 236 | p0, r0 = xyzr(self.swc, self.swc.loc[i].parent) 237 | p1, r1 = xyzr(self.swc, i) 238 | 239 | try: 240 | line_w_sphere(self.stack, p0, p1, r0, r1, 255, self.spacing) 241 | except: 242 | pass 243 | 244 | self.node.emit(i, self.swc.shape[0]) 245 | 246 | if __name__ == '__main__': 247 | pass -------------------------------------------------------------------------------- /deepd3/core/distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import njit 3 | import pandas as pd 4 | 5 | @njit 6 | def _computeDistance(p1, p2, dxy=0.1, dz=0.5): 7 | """compute euclidean distance of two points in space. 
Points are in (Z, Y, X) format""" 8 | di_z = (p1[0] * dz - p2[0] * dz) ** 2 9 | di_y = (p1[1] * dxy - p2[1] * dxy) ** 2 10 | di_x = (p1[2] * dxy - p2[2] * dxy) ** 2 11 | 12 | return np.sqrt(di_z + di_y + di_x) 13 | 14 | @njit 15 | def _distanceMatrix(pt1, pt2, dxy=0.1, dz=0.5) -> np.ndarray: 16 | """Compute distance matrix of points in 3D (Z, Y, X). 17 | Works only on 3D data 18 | 19 | Args: 20 | pt1 (numpy.ndarray): Points to be matched 21 | pt2 (numpy.ndarray): Points that can be matched 22 | dxy (float, optional): Pitch in xy. Defaults to 0.1. 23 | dz (float, optional): Pitch in z. Defaults to 0.5. 24 | 25 | Returns: 26 | numpy.ndarray: distance map from pt1 and pt2 points 27 | """ 28 | dm = np.zeros((pt1.shape[0], pt2.shape[0])) 29 | 30 | for i in range(pt1.shape[0]): 31 | for j in range(pt2.shape[0]): 32 | dm[i, j] = _computeDistance(pt1[i], pt2[j], dxy=dxy, dz=dz) 33 | 34 | return dm 35 | 36 | def distanceMatrix(pt1, pt2, dxy=0.1, dz=0.5): 37 | """Compute distance matrix of points in 2D (Y, X) and 3D (Z, Y, X) 38 | 39 | Args: 40 | pt1 (numpy.ndarray): Points to be matched 41 | pt2 (numpy.ndarray): Points that can be matched 42 | dxy (float, optional): Pitch in xy. Defaults to 0.1. 43 | dz (float, optional): Pitch in z. Defaults to 0.5. 
44 | """ 45 | if pt1.shape[1] == 2: 46 | pt1 = np.insert(pt1, 0, 1, axis=1) 47 | 48 | if pt2.shape[1] == 2: 49 | pt2 = np.insert(pt2, 0, 1, axis=1) 50 | 51 | return _distanceMatrix(pt1, pt2, dxy=dxy, dz=dz) 52 | 53 | def _countOccurences(arr) -> dict: 54 | """Count occurences in array 55 | 56 | Args: 57 | arr (numpy.ndarray): Array with non-unique numbers 58 | 59 | Returns: 60 | dict: Dictionary with unique numbers as keys and their occurence as value 61 | """ 62 | d = dict() 63 | 64 | for i in arr: 65 | if i not in d.keys(): 66 | d[i] = 1 67 | 68 | else: 69 | d[i] += 1 70 | 71 | return d 72 | 73 | def findMatches(pt1, pt2, dxy=0.1, dz=0.5, threshold_distance=1.2): 74 | matched = [] 75 | unmatched = [] 76 | 77 | # Compute distance Matrix 78 | dm = distanceMatrix(pt1, pt2, dxy=dxy, dz=dz) 79 | 80 | # Find minimal distances and point ids 81 | min_di = np.min(dm, 1) 82 | min_pt = np.argmin(dm, 1) 83 | 84 | occurences = _countOccurences(min_pt) 85 | 86 | assigned_pts = [] 87 | 88 | # Iterate over pt1 points 89 | for i in range(pt1.shape[0]): 90 | 91 | # Point is close in space and was uniquely assigned 92 | if min_di[i] < threshold_distance and occurences[min_pt[i]] == 1: 93 | matched.append([ 94 | pt1[i], 95 | pt2[min_pt[i]] 96 | ]) 97 | 98 | assigned_pts.append(min_pt[i]) 99 | 100 | # Point is close in space and was assigned multiple times 101 | elif min_di[i] < threshold_distance and occurences[min_pt[i]] > 1: 102 | # The current point is the closest to the assigned point 103 | # and has not been assigned just yet (e.g. two points exactly the same distance) 104 | if min_di[min_pt == min_pt[i]].min() == min_di[i] and min_pt[i] not in assigned_pts: 105 | matched.append([ 106 | pt1[i], 107 | pt2[min_pt[i]] 108 | ]) 109 | 110 | assigned_pts.append(min_pt[i]) 111 | 112 | # Remove from list, because... 
113 | else: 114 | min_pt[i] = -1 115 | unmatched.append(pt1[i]) 116 | 117 | # Point could not be matched for some reason 118 | else: 119 | unmatched.append(pt1[i]) 120 | 121 | return matched, unmatched, assigned_pts 122 | 123 | def createMatchMap(matched, safety=1.2): 124 | to_df = [] 125 | 126 | for p1, p2 in matched: 127 | if p1.size == 2: 128 | p1 = np.insert(p1, 0, 0) 129 | 130 | if p2.size == 2: 131 | p2 = np.insert(p2, 0, 0) 132 | 133 | c = (p1+p2)/2 134 | 135 | to_df.append({ 136 | 'Z1': p1[0], 137 | 'Z2': p2[0], 138 | 'Y1': p1[1], 139 | 'Y2': p2[1], 140 | 'X1': p1[2], 141 | 'X2': p2[2], 142 | 'CZ': c[0], 143 | 'CY': c[1], 144 | 'CX': c[2], 145 | 'R': np.sqrt(np.sum((p2[1:]-p1[1:])**2)) * safety 146 | }) 147 | 148 | return pd.DataFrame(to_df) 149 | 150 | if __name__ == '__main__': 151 | import matplotlib.pyplot as plt 152 | import flammkuchen as fl 153 | 154 | ####### 155 | # Create some fake data 156 | im = np.zeros((64, 64, 3), dtype=np.int32) 157 | 158 | pt1 = [ 159 | (5, 5), 160 | (23, 24), 161 | (15, 12), 162 | (3, 40), 163 | (60, 60), 164 | (45, 10), 165 | (35, 10) 166 | ] 167 | 168 | pt2 = [ 169 | (5, 7), 170 | (25, 24), 171 | (18, 15), 172 | (49, 51), 173 | (60, 61), 174 | (61, 60), 175 | (40, 10) 176 | ] 177 | 178 | for p in pt1: 179 | im[p] += (255, 0, 255) 180 | 181 | for p in pt2: 182 | im[p] += (0, 255, 0) 183 | 184 | pt1 = np.asarray(pt1) 185 | pt2 = np.asarray(pt2) 186 | 187 | im = np.uint8(im) 188 | 189 | ##### 190 | # Compute matching and create match map 191 | 192 | matched, unmatched, assigned_pts = findMatches(pt1, pt2) 193 | 194 | print(matched) 195 | 196 | df = createMatchMap(matched) 197 | df.to_csv("matched_points_test.matched") 198 | fl.save("matched_points_test.h5", dict(stack=im[None])) 199 | 200 | ###### 201 | # Show the data 202 | plt.figure() 203 | ax = plt.subplot(111) 204 | plt.imshow(im) 205 | 206 | for i, row in df.iterrows(): 207 | c = plt.Circle((row.CX, row.CY), row.R, color='b', fill=False) 208 | ax.add_patch(c) 209 | 
210 | plt.show() 211 | -------------------------------------------------------------------------------- /deepd3/core/export.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from roifile import ImagejRoi, roiwrite 3 | from PyQt5.QtCore import QObject, pyqtSignal 4 | import pandas as pd 5 | import imageio as io 6 | import os 7 | from pathlib import Path 8 | 9 | class ExportFolder(QObject): 10 | zSignal = pyqtSignal(int, int) 11 | 12 | def __init__(self, rois): 13 | """Class to export ROIs to a folder 14 | 15 | Args: 16 | rois (dict): dictionary of ROIs, keys are z-plane, value is list with ROIs in given plane. 17 | """ 18 | super().__init__() 19 | self.rois = rois 20 | 21 | def export(self, fn, folder): 22 | """Export ROIs to folder 23 | 24 | Args: 25 | fn (str): file name 26 | folder (str): target folder 27 | """ 28 | basename = Path(fn).stem 29 | 30 | # Folder for filename 31 | target = os.path.join(folder, basename) 32 | 33 | if not os.path.exists(target): 34 | os.mkdir(target) 35 | 36 | for k, v in self.rois.items(): 37 | subtarget = os.path.join(target, str(k)) 38 | 39 | # Folder for Z 40 | if not os.path.exists(subtarget): 41 | os.mkdir(subtarget) 42 | 43 | for i in v: 44 | # Each ROI in this Z 45 | df = pd.DataFrame(i['contour']) 46 | df.columns = 'y', 'x' 47 | df.to_csv(os.path.join(subtarget, str(i['ROI_id'])+".csv"), index=False) 48 | 49 | # Tell GUI that there's some progress 50 | self.zSignal.emit(int(z), len(self.rois.keys())) 51 | 52 | 53 | class ExportImageJ(QObject): 54 | zSignal = pyqtSignal(int, int) 55 | 56 | def __init__(self, rois): 57 | """Class to export ROIs to ImageJ ROI zip file 58 | 59 | Args: 60 | rois (dict): dictionary of ROIs, keys are z-plane, value is list with ROIs in given plane. 
61 | """ 62 | super().__init__() 63 | self.rois = rois 64 | 65 | def export(self, fn): 66 | """Export ROIs to ImageJ ROI zip file 67 | 68 | Args: 69 | fn (str): path to zip file 70 | """ 71 | 72 | # Iterate over z-planes 73 | for z, v in self.rois.items(): 74 | 75 | # Iterate over available ROIs 76 | for i in v: 77 | # Create ImageJ ROI from contour, 78 | # change x and y for ImageJ logic 79 | r = ImagejRoi.frompoints(i['contour'][:, ::-1]) 80 | 81 | # ImageJ z starts with 1, correct for it 82 | r.z_position = z+1 83 | 84 | # Write to ROI to zipfile 85 | roiwrite(fn, r) 86 | 87 | # Tell GUI that there's some progress 88 | self.zSignal.emit(int(z), len(self.rois.keys())) 89 | 90 | class ExportCentroids(QObject): 91 | def __init__(self, roi_centroids) -> None: 92 | """Class to export ROI centroids to file 93 | 94 | Args: 95 | roi_centroids (dict): ROI centroids 96 | """ 97 | super().__init__() 98 | self.roi_centroids = roi_centroids 99 | 100 | def export(self, fn): 101 | """Exports ROIs to file 102 | 103 | Args: 104 | fn (str): target filename and location 105 | """ 106 | tmp = [] 107 | 108 | for i in self.roi_centroids: 109 | tmp.append(dict(Pos=i[0], Y=i[1], X=i[2])) 110 | 111 | pd.DataFrame(tmp).to_csv(fn) 112 | 113 | 114 | class ExportPredictions(QObject): 115 | def __init__(self, pred_spines, pred_dendrites): 116 | """Class to export ROIs to a folder 117 | 118 | Args: 119 | rois (dict): dictionary of ROIs, keys are z-plane, value is list with ROIs in given plane. 
120 | """ 121 | super().__init__() 122 | self.pred_spines = pred_spines 123 | self.pred_dendrites = pred_dendrites 124 | 125 | def export(self, fn, folder): 126 | """Export predictions as tif files 127 | 128 | Args: 129 | fn (str): file name 130 | folder (str): target folder 131 | """ 132 | basename = Path(fn).stem 133 | 134 | # Folder for filename 135 | target_spines = os.path.join(folder, basename+"_spines.tif") 136 | target_dendrites = os.path.join(folder, basename+"_dendrites.tif") 137 | 138 | try: 139 | io.mimwrite(target_spines, (self.pred_spines * 255).astype(np.uint8)) 140 | io.mimwrite(target_dendrites, (self.pred_dendrites * 255).astype(np.uint8)) 141 | 142 | return True, target_spines+"\n"+target_dendrites 143 | 144 | except Exception as e: 145 | return False, e 146 | 147 | 148 | class ExportROIMap(QObject): 149 | def __init__(self, roi_map, binarize=False): 150 | """Class to export ROIs to a folder 151 | 152 | Args: 153 | rois (dict): dictionary of ROIs, keys are z-plane, value is list with ROIs in given plane. 
154 | """ 155 | super().__init__() 156 | self.roi_map = roi_map 157 | self.binarize = binarize 158 | 159 | def export(self, fn): 160 | """Export predictions as tif files 161 | 162 | Args: 163 | fn (str): file name 164 | folder (str): target folder 165 | """ 166 | try: 167 | if not self.binarize: 168 | N = self.roi_map.max() 169 | 170 | if N < 2**8: 171 | dtype = np.uint8 172 | if N < 2**16: 173 | dtype = np.uint16 174 | else: 175 | dtype = np.int32 176 | 177 | exp = self.roi_map.astype(dtype) 178 | 179 | else: 180 | exp = (self.roi_map > 0).astype(np.uint8) * 255 181 | 182 | io.mimwrite(fn, exp) 183 | 184 | return True, "" 185 | 186 | except Exception as e: 187 | return False, e 188 | -------------------------------------------------------------------------------- /deepd3/core/spines.py: -------------------------------------------------------------------------------- 1 | import flammkuchen as fl 2 | import numpy as np 3 | import imageio as io 4 | import pathlib 5 | 6 | class Spines: 7 | def __init__(self): 8 | """Spine annotation data processing 9 | """ 10 | self.spines_fn = None 11 | 12 | def open(self, spines_fn : str): 13 | """Saves path to object 14 | 15 | Args: 16 | spines_fn (str): path to spines annotation file 17 | """ 18 | self.spines_fn = spines_fn 19 | 20 | def convert(self): 21 | """Loads and converts spine annotation files to TIFF stacks 22 | 23 | Returns: 24 | str: Path to saved TIFF stack 25 | """ 26 | # Mask drawn in e.g. ImageJ and saved as TIF file 27 | if self.spines_fn.endswith('tif'): 28 | stack = io.mimread(self.spines_fn) 29 | 30 | # Mask drawn using pipra and saved as mask file 31 | elif self.spines_fn.endswith('mask'): 32 | stack = fl.load(self.spines_fn)['mask'] 33 | stack = stack.astype(np.uint8).transpose(0,2,1)*255 34 | 35 | else: 36 | print(f"We don't know how to open that file... 
\n {self.spines_fn}") 37 | return 38 | 39 | path_ref = pathlib.Path(self.spines_fn) 40 | ext = path_ref.suffix 41 | 42 | save_fn = self.spines_fn.replace(ext, f"_spines.tif") 43 | save_fn_max = self.spines_fn.replace(ext, f"_spines_max.png") 44 | 45 | print('Saving stack and maximum intensity projection') 46 | io.mimsave(save_fn, stack) 47 | io.imsave(save_fn_max, stack.max(0)) 48 | 49 | return save_fn 50 | 51 | if __name__ == '__main__': 52 | s = Spines() 53 | -------------------------------------------------------------------------------- /deepd3/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankilab/DeepD3/94de9d4697f00e82097c8775b924bf7ba4e624a7/deepd3/inference/__init__.py -------------------------------------------------------------------------------- /deepd3/inference/batch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import tensorflow as tf 4 | import flammkuchen as fl 5 | import numpy as np 6 | from deepd3.core.analysis import Stack, ROI2D_Creator, ROI3D_Creator 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('stack', type=argparse.FileType('r', encoding='UTF-8'), help='File to be segmented') 10 | parser.add_argument('neuralnetwork', metavar='nn', type=argparse.FileType('r', encoding='UTF-8'), help='Deep neural network for spine and dendrite segmentation') 11 | parser.add_argument('--tile_size', metavar='ts', type=int, help='Tile size for segmentation (default: 128)', default=128) 12 | parser.add_argument('--inset_size', metavar='is', type=int, help='Inset size for segmentation (default: 96)', default=96) 13 | parser.add_argument('--average', action='store_const', const=True, help='Predict segmentation with four offsets, average. 
(default: False)', default=False) 14 | parser.add_argument('--plane', action='store_const', const=True, help='Predict segmentation in a whole plane at once. (default: False)', default=False) 15 | parser.add_argument('--clean_dendrite', action='store_const', const=True, help='Clean dendrite using connected components per plane (2D). (default: False)', default=False) 16 | parser.add_argument('--clean_dendrite_3d', action='store_const', const=True, help='Clean dendrite using connected components in three dimensions (3D). (default: True)', default=True) 17 | parser.add_argument('--min_dendrite_size', type=int, help='Minimum dendrite element size in px for cleaning (default: 100)', default=100) 18 | parser.add_argument('--dendrite_threshold', type=float, help='Dendrite probability threshold for cleaning (default: 0.7)', default=0.7) 19 | parser.add_argument('--dendrite_dilation', type=float, help='Dendrite dilation factor for spine cleaning (default: 11)', default=11) 20 | parser.add_argument('--clean_spines', action='store_const', const=True, help='Clean spines using dendrite prediction dilation. 
(default: True)', default=True) 21 | 22 | 23 | parser.add_argument('--build_rois_2d', action='store_const', const=True, help='Enable 2D ROI building') 24 | parser.add_argument('--build_rois_3d', action='store_const', const=True, help='Enable 3D ROI building') 25 | 26 | parser.add_argument('--roi_method', type=str, help='ROI building method: floodfill or connected components (default: floodfill)', default="floodfill") 27 | parser.add_argument('--roi_areaThreshold', type=float, help='ROI probability threshold for area (default: 0.25)', default=0.25) 28 | parser.add_argument('--roi_peakThreshold', type=float, help='ROI probability threshold for peak (default: 0.80)', default=0.80) 29 | parser.add_argument('--roi_seedDelta', type=float, help='Pixel similarity to seed pixel (default: 0.2)', default=0.2) 30 | parser.add_argument('--roi_distanceToSeed', type=float, help='Distance to seed pixel in px (default: 10)', default=10) 31 | 32 | parser.add_argument('--watershed', action='store_const', const=True, help='Apply watershed (default: False)', default=False) 33 | parser.add_argument('--clean_rois', action='store_const', const=True, help='Enable ROI cleaning') 34 | parser.add_argument('--min_roi_size', type=int, help='Minimum ROI size in px (default: 10)', default=10) 35 | parser.add_argument('--max_roi_size', type=int, help='Maximum ROI size in px (default: 1000)', default=1000) 36 | parser.add_argument('--min_planes', type=int, help='Minimum Planes an ROI should span (default: 1)', default=1) 37 | 38 | if __name__ == '__main__': 39 | args = parser.parse_args() 40 | 41 | # Define filenames 42 | fn = args.stack.name 43 | ext = fn.split(".")[-1] 44 | pred_fn = fn[:-len(ext)-1]+".prediction" 45 | rois_fn = fn[:-len(ext)-1]+".rois" 46 | 47 | 48 | print("Loading stack...") 49 | S = Stack(fn) 50 | 51 | if args.average: 52 | print("Predicting inset four times, average") 53 | S.predictFourFold(args.neuralnetwork.name, 54 | args.tile_size, 55 | args.inset_size) 56 | 57 | elif 
args.plane: 58 | print("Predict whole image in plane") 59 | S.predictWholeImage(args.neuralnetwork.name) 60 | 61 | else: 62 | print("Predicting inset") 63 | S.predictInset(args.neuralnetwork.name, 64 | args.tile_size, 65 | args.inset_size) 66 | 67 | 68 | if args.clean_dendrite: 69 | print("Cleaning dendrite in 2D") 70 | d = S.cleanDendrite(args.dendrite_threshold, args.min_dendrite_size) 71 | S.prediction[..., 0] = d 72 | S.prediction[..., 2] = d 73 | 74 | if args.clean_dendrite_3d: 75 | print("Cleaning dendrite in 3D") 76 | S.cleanDendrite3D(args.dendrite_threshold, args.min_dendrite_size) 77 | S.prediction[..., 0] = d 78 | S.prediction[..., 2] = d 79 | 80 | if args.clean_spines: 81 | print("Cleaning spines") 82 | s = S.cleanSpines(args.dendrite_threshold, args.dendrite_dilation) 83 | S.prediction[..., 1] = s 84 | 85 | print("Saving predictions") 86 | fl.save(pred_fn, 87 | dict(dendrites=S.prediction[...,0].astype(np.float32), 88 | spines=S.prediction[..., 1].astype(np.float32)), 89 | compression='blosc') 90 | 91 | 92 | if args.build_rois_2d: 93 | print("Building 2D ROIs...") 94 | 95 | if args.build_rois_3d: 96 | print("********************") 97 | print("Caution: You also selected 3D ROI building, ", end="") 98 | print("please disable then 2D ROI building.") 99 | print("No 3D ROIs will be built.") 100 | print("********************") 101 | 102 | r = ROI2D_Creator(S.prediction[..., 0], 103 | S.prediction[..., 1], 104 | args.roi_areaThreshold) 105 | 106 | r.create(args.watershed) 107 | 108 | if args.clean_rois: 109 | print("Cleaning 2D ROIs...") 110 | r.clean(args.max_dendrite_displacement, 111 | args.min_roi_size, 112 | args.dendrite_threshold) 113 | 114 | print("Saving 2D ROIs...") 115 | fl.save(rois_fn, 116 | dict(rois=r.rois, roi_map=r.roi_map), 117 | compression='blosc') 118 | 119 | elif args.build_rois_3d: 120 | print("Building 3D ROIs...") 121 | r = ROI3D_Creator(S.prediction[..., 0], 122 | S.prediction[..., 1], 123 | args.method, 124 | args.roi_areaThreshold, 
125 | args.roi_peakThreshold, 126 | args.roi_seedDelta, 127 | args.roi_distanceToSeed) 128 | 129 | r.create(args.min_roi_size, 130 | args.max_roi_size, 131 | args.min_planes) 132 | 133 | print("Saving ROIs...") 134 | fl.save(rois_fn, 135 | dict(rois=r.rois, roi_map=r.roi_map), 136 | compression='blosc') 137 | 138 | print("Done!") 139 | 140 | 141 | -------------------------------------------------------------------------------- /deepd3/inference/gui.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, \ 3 | QMessageBox, QFileDialog, QGridLayout, QLabel, QPushButton, \ 4 | QProgressBar, QDialog, QTableWidget, QTableWidgetItem, QHeaderView, \ 5 | QLineEdit, QAction, QGraphicsPathItem, QCheckBox, QFrame, QComboBox, \ 6 | QSlider, QGraphicsEllipseItem 7 | from PyQt5.QtGui import QKeySequence, QPainter, QPen, QPainterPath, \ 8 | QPolygonF, QIntValidator, QDoubleValidator, QColor 9 | from PyQt5.QtCore import Qt, QPointF, pyqtSignal 10 | import pyqtgraph as pg 11 | import imageio as io 12 | from deepd3.core.analysis import Stack, ROI2D_Creator, ROI3D_Creator 13 | from deepd3.core.export import ExportCentroids, ExportImageJ, \ 14 | ExportFolder, ExportPredictions, ExportROIMap 15 | from time import time 16 | from datetime import datetime 17 | import flammkuchen as fl 18 | import json 19 | import pandas as pd 20 | import sys, os 21 | from skimage.color import label2rgb 22 | from scipy.ndimage import gaussian_filter 23 | 24 | ############################################## 25 | class QHLine(QFrame): 26 | def __init__(self): 27 | super().__init__() 28 | self.setFrameShape(QFrame.HLine) 29 | self.setFrameShadow(QFrame.Sunken) 30 | 31 | ######################## 32 | ## Ask for dimensions 33 | ######################## 34 | class askDimensions(QDialog): 35 | def __init__(self, xy=0.094, z=0.5) -> None: 36 | """Window asking for dimensions of loaded stack 37 | 
38 | Args: 39 | xy (float, optional): xy dimensions in µm. Defaults to 0.094. 40 | z (float, optional): z dimensions in µm. Defaults to 0.5. 41 | """ 42 | super().__init__() 43 | 44 | self.default_xy = xy 45 | self.default_z = z 46 | 47 | self.l = QGridLayout(self) 48 | 49 | self.l.addWidget(QLabel("Enter your stack dimensions:")) 50 | self.l.addWidget(QHLine()) 51 | 52 | self.xy = QLineEdit() 53 | self.xy.setPlaceholderText("Default 0.094") 54 | self.xy.setValidator(QDoubleValidator()) 55 | self.xy.setText(str(self.default_xy)) 56 | self.l.addWidget(QLabel("XY px width in micrometer")) 57 | self.l.addWidget(self.xy) 58 | 59 | self.z = QLineEdit() 60 | self.z.setPlaceholderText("Default 0.5") 61 | self.z.setValidator(QDoubleValidator()) 62 | self.z.setText(str(self.default_z)) 63 | self.l.addWidget(QLabel("Z step in micrometer")) 64 | self.l.addWidget(self.z) 65 | 66 | self.exec_() 67 | 68 | def dimensions(self): 69 | """Returns dictionary containing xy and z dimensions in µm 70 | 71 | Returns: 72 | dict: returns xy and z dimensions 73 | """ 74 | xy = self.default_xy if self.xy.text() == "" else float(self.xy.text()) 75 | z = self.default_z if self.z.text() == "" else float(self.z.text()) 76 | 77 | return dict(xy=xy, z=z) 78 | 79 | class ROI3D(QDialog): 80 | def __init__(self, settings=None): 81 | """Dialog for 3D ROI build settings 82 | """ 83 | super().__init__() 84 | self.l = QGridLayout(self) 85 | self.go = False 86 | 87 | self.areaThreshold = QLineEdit("0.25") 88 | self.areaThreshold.setValidator(QDoubleValidator(0, 1, 2)) 89 | 90 | self.peakThreshold = QLineEdit("0.80") 91 | self.peakThreshold.setValidator(QDoubleValidator(0, 1, 2)) 92 | 93 | self.minPlanes = QLineEdit("1") 94 | self.minPlanes.setValidator(QIntValidator(1, 100)) 95 | self.minPx = QLineEdit("20") 96 | self.minPx.setValidator(QIntValidator(1, 10000)) 97 | self.maxPx = QLineEdit("1000") 98 | self.maxPx.setValidator(QIntValidator(1, 10000)) 99 | 100 | self.distanceToSeed = QLineEdit("2.00") 101 | 
self.distanceToSeed.setValidator(QDoubleValidator(0, 20., 2)) 102 | self.seedDelta = QLineEdit("0.2") 103 | self.seedDelta.setValidator(QDoubleValidator(0, 1, 2)) 104 | 105 | self.method = QComboBox() 106 | self.method.addItems(["floodfill", "connected components"]) 107 | 108 | ######## Design 109 | self.l.addWidget(QLabel("ROI building method")) 110 | self.l.addWidget(self.method) 111 | 112 | self.l.addWidget(QLabel("Area threshold (0...1)")) 113 | self.l.addWidget(self.areaThreshold) 114 | 115 | self.l.addWidget(QLabel("Peak threshold (0...1)")) 116 | self.l.addWidget(self.peakThreshold) 117 | 118 | self.l.addWidget(QHLine()) 119 | 120 | self.l.addWidget(QLabel("Minimum planes an ROI should span:")) 121 | self.l.addWidget(self.minPlanes) 122 | 123 | self.l.addWidget(QLabel("Minimum 3D ROI size [px]:")) 124 | self.l.addWidget(self.minPx) 125 | 126 | self.l.addWidget(QLabel("Maximum 3D ROI size [px]:")) 127 | self.l.addWidget(self.maxPx) 128 | 129 | self.l.addWidget(QLabel("Maximum prediction difference to seed pixel prediction")) 130 | self.l.addWidget(self.seedDelta) 131 | 132 | self.l.addWidget(QLabel("Maximum 3D euclidean distance to seed pixel [px]")) 133 | self.l.addWidget(self.distanceToSeed) 134 | 135 | self.watershed = QCheckBox("Apply watershed when building ROIs") 136 | self.l.addWidget(self.watershed) 137 | 138 | self.l.addWidget(QHLine()) 139 | 140 | if type(settings) != type(None): 141 | self.areaThreshold.setText(str(settings['areaThreshold'])) 142 | self.peakThreshold.setText(str(settings['peakThreshold'])) 143 | self.seedDelta.setText(str(settings['seedDelta'])) 144 | self.distanceToSeed.setText(str(settings['distanceToSeed'])) 145 | self.minPx.setText(str(settings['minPx'])) 146 | self.maxPx.setText(str(settings['maxPx'])) 147 | self.minPlanes.setText(str(settings['minPlanes'])) 148 | self.watershed.setChecked(settings['applyWatershed']) 149 | 150 | self.closeButton = QPushButton("Save and start building ROIs") 151 | 
self.closeButton.clicked.connect(self.close) 152 | self.l.addWidget(self.closeButton) 153 | 154 | def close(self): 155 | self.settings = { 156 | 'method': self.method.currentText(), 157 | 'areaThreshold': float(self.areaThreshold.text()), 158 | 'peakThreshold': float(self.peakThreshold.text()), 159 | 'minPx': int(self.minPx.text()), 160 | 'maxPx': int(self.maxPx.text()), 161 | 'minPlanes': int(self.minPlanes.text()), 162 | 'distanceToSeed': float(self.distanceToSeed.text()), 163 | 'seedDelta': float(self.seedDelta.text()), 164 | 'watershed': bool(self.watershed.isChecked()) 165 | } 166 | self.go = True 167 | super().close() 168 | 169 | class ROI2D(QDialog): 170 | def __init__(self): 171 | """Dialog for 2D ROI build settings 172 | """ 173 | super().__init__() 174 | self.l = QGridLayout(self) 175 | self.go = False 176 | 177 | self.threshold = QLineEdit("0.25") 178 | self.threshold.setValidator(QDoubleValidator(0, 1, 2)) 179 | 180 | self.l.addWidget(QLabel("Threshold (0...1)")) 181 | self.l.addWidget(self.threshold) 182 | 183 | 184 | self.l.addWidget(QHLine()) 185 | 186 | self.cleanROIs = QCheckBox("Clean ROIs") 187 | self.cleanROIs.setChecked(True) 188 | 189 | self.maxDistanceToDendrite = QLineEdit("30") 190 | self.maxDistanceToDendrite.setValidator(QIntValidator(0, 100)) 191 | 192 | self.minROIsize = QLineEdit("10") 193 | self.minROIsize.setValidator(QIntValidator(0, 1000)) 194 | 195 | self.cleanDendriteThreshold = QLineEdit("0.7") 196 | self.cleanDendriteThreshold.setValidator(QDoubleValidator(0, 1, 2)) 197 | 198 | self.l.addWidget(self.cleanROIs) 199 | 200 | self.l.addWidget(QLabel("Maximum Distance to dendrite [px]")) 201 | self.l.addWidget(self.maxDistanceToDendrite) 202 | 203 | self.l.addWidget(QLabel("Minimum ROI size [px]")) 204 | self.l.addWidget(self.minROIsize) 205 | 206 | self.l.addWidget(QLabel("Dendrite threshold")) 207 | self.l.addWidget(self.cleanDendriteThreshold) 208 | 209 | self.l.addWidget(QHLine()) 210 | 211 | self.applyWatershed = 
QCheckBox("Apply watershed transform") 212 | self.l.addWidget(self.applyWatershed) 213 | 214 | self.l.addWidget(QHLine()) 215 | 216 | self.closeButton = QPushButton("Save and start building ROIs") 217 | self.closeButton.clicked.connect(self.close) 218 | self.l.addWidget(self.closeButton) 219 | 220 | def close(self): 221 | try: 222 | threshold = float(self.threshold.text()) 223 | except Exception as e: 224 | QMessageBox.critical(self, "Error with threshold", e) 225 | return 226 | 227 | try: 228 | maxD = float(self.maxDistanceToDendrite.text()) 229 | except Exception as e: 230 | QMessageBox.critical(self, "Error with maximal Distance", e) 231 | return 232 | 233 | try: 234 | minS = int(self.minROIsize.text()) 235 | except Exception as e: 236 | QMessageBox.critical(self, "Error with minimal ROI size", e) 237 | return 238 | 239 | if not 0 <= threshold <= 1: 240 | QMessageBox.critical(self, "Wrong threshold", "Threshold should be between 0 and 1!") 241 | return 242 | 243 | if maxD < 0: 244 | QMessageBox.critical(self, "Wrong distance", "max Dendrite Distance should be above 0!") 245 | return 246 | 247 | self.settings = { 248 | 'threshold': threshold, 249 | 'maxDendriteDistance': maxD, 250 | 'minSize': minS, 251 | 'dendriteThreshold': float(self.cleanDendriteThreshold.text()) 252 | } 253 | self.go = True 254 | super().close() 255 | 256 | ############################################## 257 | class Segment(QDialog): 258 | def __init__(self, model_fn=None): 259 | """Dialog for segmentation settings 260 | """ 261 | super().__init__() 262 | self.l = QGridLayout(self) 263 | self.model_fn = model_fn 264 | 265 | self.go = False 266 | 267 | self.findModelButton = QPushButton("Find...") 268 | self.findModelButton.clicked.connect(self.findModel) 269 | 270 | self.selectInferenceMode = QComboBox() 271 | self.selectInferenceMode.addItems([ 272 | 'Plane inference', 273 | 'Tile inference [1x]', 274 | 'Tile inference [4x avg]' 275 | ]) 276 | 277 | self.tileSize = QLineEdit("128") 278 | 
self.tileSize.setValidator(QIntValidator(32, 128)) 279 | self.insetSize = QLineEdit("96") 280 | self.insetSize.setValidator(QIntValidator(32, 128)) 281 | # self.fourFoldSegmentation = QCheckBox("Four fold prediction with averaging") 282 | # self.wholeImageInference = QCheckBox("Whole image inference") 283 | 284 | 285 | self.paddingOperation = QComboBox() 286 | self.paddingOperation.addItems(['min', 'mean']) #("Four fold prediction with averaging") 287 | 288 | self.l.addWidget(QLabel("Select model")) 289 | 290 | if model_fn != None: 291 | model_fn_ext = model_fn.split("\\")[-1] 292 | self.l.addWidget(QLabel(f"Current model: \n{model_fn_ext}")) 293 | 294 | self.l.addWidget(self.findModelButton) 295 | 296 | self.l.addWidget(QHLine()) 297 | 298 | self.l.addWidget(QLabel("Inference mode")) 299 | self.l.addWidget(self.selectInferenceMode) 300 | 301 | self.l.addWidget(QHLine()) 302 | 303 | self.l.addWidget(QLabel("Tile size (only for tile inference)")) 304 | self.l.addWidget(self.tileSize) 305 | 306 | self.l.addWidget(QLabel("Inset size (only for tile inference)")) 307 | self.l.addWidget(self.insetSize) 308 | 309 | self.l.addWidget(QLabel("Padding operation (only for tile inference)")) 310 | self.l.addWidget(self.paddingOperation) 311 | 312 | self.l.addWidget(QHLine()) 313 | 314 | self.runOnCPU = QCheckBox("Run on CPU (uncheck if you want GPU)") 315 | self.runOnCPU.setChecked(True) 316 | self.l.addWidget(self.runOnCPU) 317 | 318 | self.closeButton = QPushButton("Save and start segmenting") 319 | self.closeButton.clicked.connect(self.close) 320 | self.l.addWidget(self.closeButton) 321 | 322 | def findModel(self): 323 | """Find TensorFlow/Keras model on file system 324 | """ 325 | model_fn = QFileDialog.getOpenFileName(caption="Find TensorFlow Keras model", 326 | filter="*.h5") 327 | 328 | if model_fn: 329 | self.model_fn = model_fn 330 | 331 | def close(self): 332 | if self.model_fn is None: 333 | QMessageBox.critical(self, 334 | "No model selected", 335 | "Please select 
an appropriate model") 336 | return 337 | 338 | self.settings = { 339 | 'infierenceMode': self.selectInferenceMode.currentText(), 340 | 'tileSize': int(self.tileSize.text()), 341 | 'insetSize': int(self.insetSize.text()), 342 | 'paddingOperation': str(self.paddingOperation.currentText()), 343 | 'runOnCPU': bool(self.runOnCPU.isChecked()) 344 | } 345 | 346 | if self.settings['insetSize'] > self.settings['tileSize']: 347 | QMessageBox.critical(self, 348 | "Settings", 349 | "Please choose a insetSize smaller or equal to tileSize") 350 | return 351 | 352 | self.go = True 353 | super().close() 354 | 355 | 356 | ############################################## 357 | class Cleaning(QDialog): 358 | previewSignal = pyqtSignal(dict) 359 | 360 | def __init__(self): 361 | """Dialog for cleaning settings 362 | """ 363 | super().__init__() 364 | self.l = QGridLayout(self) 365 | 366 | self.go = False 367 | 368 | ### Closing dendrite 369 | self.closingDendrite = QCheckBox("Connect single dendrite elements") 370 | self.closingDendrite.setChecked(True) 371 | self.closingDendrite.stateChanged.connect(self.previewCleaning) 372 | 373 | self.closingDendriteIterations = QLineEdit("3") 374 | self.closingDendriteIterations.setValidator(QIntValidator(1, 100)) 375 | self.closingDendriteIterations.textChanged.connect(self.previewCleaning) 376 | 377 | ### Dendrite 378 | self.cleanDendrite = QCheckBox("Clean dendrite in 3D") 379 | self.cleanDendrite.setChecked(True) 380 | self.cleanDendrite.stateChanged.connect(self.previewCleaning) 381 | 382 | self.cleanDendriteThreshold = QLineEdit("0.7") 383 | self.cleanDendriteThreshold.setValidator(QDoubleValidator(0, 1, 2)) 384 | self.cleanDendriteThreshold.textChanged.connect(self.previewCleaning) 385 | 386 | self.minDendriteSize = QLineEdit("100") 387 | self.minDendriteSize.setValidator(QIntValidator(1, 10000)) 388 | self.minDendriteSize.textChanged.connect(self.previewCleaning) 389 | 390 | self.l.addWidget(QHLine()) 391 | 392 | self.cleanSpines = 
QCheckBox("Clean spines using dendrite proximity") 393 | self.cleanSpines.setChecked(True) 394 | self.cleanSpines.stateChanged.connect(self.previewCleaning) 395 | 396 | self.dendriteDilation = QLineEdit("21") 397 | self.dendriteDilation.setValidator(QIntValidator(1, 100)) 398 | self.dendriteDilation.textChanged.connect(self.previewCleaning) 399 | 400 | 401 | self.preview = QCheckBox("preview") 402 | 403 | self.l.addWidget(QHLine()) 404 | 405 | self.l.addWidget(QLabel("Connecting elements")) 406 | self.l.addWidget(self.closingDendrite) 407 | self.l.addWidget(QLabel("Connection iterations")) 408 | self.l.addWidget(self.closingDendriteIterations) 409 | 410 | self.l.addWidget(QHLine()) 411 | 412 | self.l.addWidget(QLabel("Clean dendrite segmentation")) 413 | self.l.addWidget(self.cleanDendrite) 414 | self.l.addWidget(QLabel("Dendrite threshold")) 415 | self.l.addWidget(self.cleanDendriteThreshold) 416 | self.l.addWidget(QLabel("Minimum Dendrite size in px")) 417 | self.l.addWidget(self.minDendriteSize) 418 | 419 | self.l.addWidget(QHLine()) 420 | 421 | self.l.addWidget(QLabel("Clean spine prediction")) 422 | self.l.addWidget(self.cleanSpines) 423 | self.l.addWidget(QLabel("Dendrite proximity [dilation iterations]")) 424 | self.l.addWidget(self.dendriteDilation) 425 | 426 | self.l.addWidget(QHLine()) 427 | 428 | self.l.addWidget(self.preview) 429 | 430 | self.l.addWidget(QHLine()) 431 | 432 | self.closeButton = QPushButton("Save and start cleaning") 433 | self.closeButton.clicked.connect(self.close) 434 | self.l.addWidget(self.closeButton) 435 | 436 | def _settings(self): 437 | return { 438 | 'closing': self.closingDendrite.isChecked(), 439 | 'closingIterations': int(self.closingDendriteIterations.text() if self.closingDendriteIterations.text() != '' else 1), 440 | 'cleanDendrite': self.cleanDendrite.isChecked(), 441 | 'cleanSpines': self.cleanSpines.isChecked(), 442 | 'dendriteDilation': int(self.dendriteDilation.text() if self.dendriteDilation.text() != '' else 1), 
443 | 'cleanDendriteThreshold': float(self.cleanDendriteThreshold.text()), 444 | 'minDendriteSize': int(self.minDendriteSize.text()) 445 | } 446 | 447 | def previewCleaning(self): 448 | if self.preview.isChecked(): 449 | self.previewSignal.emit(self._settings()) 450 | 451 | def close(self): 452 | self.settings = self._settings() 453 | 454 | self.go = True 455 | super().close() 456 | 457 | ########### 458 | ## Hack for Double Slider 459 | ## https://stackoverflow.com/questions/42820380/use-float-for-qslider 460 | ########### 461 | class DoubleSlider(QSlider): 462 | # create our our signal that we can connect to if necessary 463 | doubleValueChanged = pyqtSignal(float) 464 | 465 | def __init__(self, decimals=2, *args, **kargs): 466 | super(DoubleSlider, self).__init__( *args, **kargs) 467 | self._multi = 10 ** decimals 468 | 469 | self.valueChanged.connect(self.emitDoubleValueChanged) 470 | 471 | def emitDoubleValueChanged(self): 472 | value = float(super(DoubleSlider, self).value())/self._multi 473 | self.doubleValueChanged.emit(value) 474 | 475 | def value(self): 476 | return float(super(DoubleSlider, self).value()) / self._multi 477 | 478 | def setMinimum(self, value): 479 | return super(DoubleSlider, self).setMinimum(value * self._multi) 480 | 481 | def setMaximum(self, value): 482 | return super(DoubleSlider, self).setMaximum(value * self._multi) 483 | 484 | def setSingleStep(self, value): 485 | return super(DoubleSlider, self).setSingleStep(value * self._multi) 486 | 487 | def singleStep(self): 488 | return float(super(DoubleSlider, self).singleStep()) / self._multi 489 | 490 | def setValue(self, value): 491 | super(DoubleSlider, self).setValue(int(value * self._multi)) 492 | 493 | ############################################## 494 | class testROI(QWidget): 495 | settings = pyqtSignal(dict) 496 | 497 | def __init__(self, stack, d, s, settings=None) -> None: 498 | """Tests ROI building in 3D 499 | 500 | Args: 501 | stack (numpy.ndarray): Part of stack for 
testing 502 | d (numpy.ndarray): Dendrite prediction, same shape as `stack` 503 | s (numpy.ndarray): Spine prediction, same shape as `stack` 504 | settings (dict, optional): Settings for ROI testing. Defaults to None. 505 | """ 506 | super().__init__() 507 | 508 | self.stack = stack 509 | self.d = d 510 | self.s = s 511 | 512 | self.l = QGridLayout(self) 513 | 514 | self.l.addWidget(QLabel("Gaussian filter")) 515 | self.gaussianFilter = DoubleSlider(orientation=Qt.Horizontal) 516 | self.gaussianFilter.setMinimum(0) 517 | self.gaussianFilter.setMaximum(3.0) 518 | self.gaussianFilter.setSingleStep(0.1) 519 | self.gaussianFilter.setValue(0.0) 520 | self.gaussianFilter.valueChanged.connect(self.do) 521 | self.l.addWidget(self.gaussianFilter) 522 | 523 | self.data = QComboBox() 524 | self.data.addItems(["spine prediction", "intensity"]) 525 | self.data.currentTextChanged.connect(self.do) 526 | self.l.addWidget(self.data) 527 | 528 | self.mode = QComboBox() 529 | self.mode.addItems(["floodfill", "connected components"]) 530 | self.mode.currentTextChanged.connect(self.do) 531 | self.l.addWidget(self.mode) 532 | 533 | 534 | self.l.addWidget(QHLine()) 535 | 536 | self.l.addWidget(QLabel("Area threshold")) 537 | self.areaThreshold = DoubleSlider(orientation=Qt.Horizontal) 538 | self.areaThreshold.setMinimum(0) 539 | self.areaThreshold.setMaximum(1.0) 540 | self.areaThreshold.setSingleStep(0.05) 541 | self.areaThreshold.setValue(0.2) 542 | self.areaThreshold.valueChanged.connect(self.do) 543 | self.l.addWidget(self.areaThreshold) 544 | 545 | self.l.addWidget(QLabel("Peak threshold")) 546 | self.peakThreshold = DoubleSlider(orientation=Qt.Horizontal) 547 | self.peakThreshold.setMinimum(0) 548 | self.peakThreshold.setMaximum(1.0) 549 | self.peakThreshold.setSingleStep(0.05) 550 | self.peakThreshold.setValue(0.8) 551 | self.peakThreshold.valueChanged.connect(self.do) 552 | self.l.addWidget(self.peakThreshold) 553 | 554 | self.l.addWidget(QLabel("Difference to seed intensity 
[%]")) 555 | self.seedDelta = DoubleSlider(orientation=Qt.Horizontal) 556 | self.seedDelta.setMinimum(0) 557 | self.seedDelta.setMaximum(1.0) 558 | self.seedDelta.setSingleStep(0.05) 559 | self.seedDelta.setValue(0.5) 560 | self.seedDelta.valueChanged.connect(self.do) 561 | self.l.addWidget(self.seedDelta) 562 | 563 | self.l.addWidget(QLabel("Distance to seed pixel in microns")) 564 | self.distanceToSeed = DoubleSlider(orientation=Qt.Horizontal) 565 | self.distanceToSeed.setMinimum(0) 566 | self.distanceToSeed.setMaximum(10) 567 | self.distanceToSeed.setSingleStep(0.05) 568 | self.distanceToSeed.setValue(2) 569 | self.distanceToSeed.valueChanged.connect(self.do) 570 | self.l.addWidget(self.distanceToSeed) 571 | 572 | self.l.addWidget(QHLine()) 573 | 574 | self.l.addWidget(QLabel("Minimum pixel in ROI")) 575 | self.minPx = QSlider(orientation=Qt.Horizontal) 576 | self.minPx.setMinimum(1) 577 | self.minPx.setMaximum(500) 578 | self.minPx.setSingleStep(5) 579 | self.minPx.setValue(10) 580 | self.minPx.valueChanged.connect(self.do) 581 | self.l.addWidget(self.minPx) 582 | 583 | self.l.addWidget(QLabel("Maximum pixel in ROI")) 584 | self.maxPx = QSlider(orientation=Qt.Horizontal) 585 | self.maxPx.setMinimum(1) 586 | self.maxPx.setMaximum(10000) 587 | self.maxPx.setSingleStep(5) 588 | self.maxPx.setValue(1000) 589 | self.maxPx.valueChanged.connect(self.do) 590 | self.l.addWidget(self.maxPx) 591 | 592 | self.l.addWidget(QLabel("Minimum planes in ROI")) 593 | self.minPlanes = QSlider(orientation=Qt.Horizontal) 594 | self.minPlanes.setMinimum(1) 595 | self.minPlanes.setMaximum(10) 596 | self.minPlanes.setSingleStep(1) 597 | self.minPlanes.setValue(3) 598 | self.minPlanes.valueChanged.connect(self.do) 599 | self.l.addWidget(self.minPlanes) 600 | 601 | self.computeContours = QCheckBox("Compute contours") 602 | self.computeContours.clicked.connect(self.do) 603 | self.l.addWidget(self.computeContours) 604 | 605 | self.applyWatershed = QCheckBox("Apply Watershed") 606 | 
self.applyWatershed.clicked.connect(self.do)
        self.l.addWidget(self.applyWatershed)

        self.saveBtn = QPushButton("Save settings")
        self.saveBtn.clicked.connect(self.saveSettings)
        self.l.addWidget(self.saveBtn)

        # Image view showing the test stack; overlay added on top of it
        self.imv = pg.ImageView()
        self.imv.setMinimumWidth(600)
        self.imv.setMinimumHeight(400)
        self.l.addWidget(self.imv, 0, 1, 20, 1)
        self.imv.setImage(self.stack.transpose(0,2,1))

        self.overlayItem = pg.ImageItem(np.zeros(self.stack.shape[1:]),
                                        compositionMode=QPainter.CompositionMode_Plus)
        self.imv.getView().addItem(self.overlayItem)

        # Restore previously saved slider/checkbox state, if provided
        if type(settings) != type(None):
            self.areaThreshold.setValue(settings['areaThreshold'])
            self.peakThreshold.setValue(settings['peakThreshold'])
            self.seedDelta.setValue(settings['seedDelta'])
            self.distanceToSeed.setValue(settings['distanceToSeed'])
            self.minPx.setValue(settings['minPx'])
            self.maxPx.setValue(settings['maxPx'])
            self.minPlanes.setValue(settings['minPlanes'])
            self.applyWatershed.setChecked(settings['applyWatershed'])

        self.do()

        # Update overlay when z location changes
        self.imv.sigTimeChanged.connect(self.changeOverlay)

    def changeOverlay(self):
        """Show the ROI color overlay for the currently displayed plane."""
        ix = self.imv.currentIndex

        self.overlayItem.setImage(self.rgb[ix].transpose(1,0,2))

    def do(self):
        """Actually generating ROIs
        """
        # Uses either intensity or spine prediction
        if self.data.currentText() == 'intensity':
            s = self.stack
        else:
            s = self.s

        # Optional smoothing before ROI creation (sigma from the slider)
        sg = gaussian_filter(s, self.gaussianFilter.value())

        self.roi3d = ROI3D_Creator(self.d, sg,
                                   mode=self.mode.currentText(),
                                   areaThreshold=self.areaThreshold.value(),
                                   peakThreshold=self.peakThreshold.value(),
                                   seedDelta=self.seedDelta.value(),
                                   distanceToSeed=self.distanceToSeed.value())

        self.roi3d.computeContours = self.computeContours.isChecked()

        self.roi3d.create(minPx=self.minPx.value(),
                          maxPx=self.maxPx.value(),
                          minPlanes=self.minPlanes.value(),
                          applyWatershed=self.applyWatershed.isChecked())

        # Colorize the label map for display
        self.rgb = label2rgb(self.roi3d.roi_map)

        self.changeOverlay()

    def saveSettings(self):
        """Emit the current ROI-building parameters via the settings signal."""
        self.settings.emit({
            'areaThreshold': self.areaThreshold.value(),
            'peakThreshold': self.peakThreshold.value(),
            'seedDelta': self.seedDelta.value(),
            'distanceToSeed': self.distanceToSeed.value(),
            'minPx': self.minPx.value(),
            'maxPx': self.maxPx.value(),
            'minPlanes': self.minPlanes.value(),
            'applyWatershed': self.applyWatershed.isChecked()
        })

##############################################
class ImageView(pg.ImageView):
    # Emitted with the clicked image coordinate on single left click
    xy = pyqtSignal(QPointF)
    # Emitted with the image coordinate on double click
    testROIbuilding = pyqtSignal(QPointF)

    def __init__(self, *args, **kwargs):
        """ImageView - interact with mouse press event
        """
        super().__init__(*args, **kwargs)

    def mousePressEvent(self, e):
        if e.button() == Qt.LeftButton:
            # Get xy coordinate in relation to scene
            xy = self.getImageItem().mapFromScene(e.pos())
            # Emit signal to tell interface where mouse click position is
            self.xy.emit(xy)

    def mouseDoubleClickEvent(self, e) -> None:
        xy = self.getImageItem().mapFromScene(e.pos())

        self.testROIbuilding.emit(xy)


class Interface(QWidget):
    def __init__(self, fn, pred_fn, rois_fn, logs_fn,
                 dimensions=dict(xy=0.094, z=0.5)):
        """Main GUI interface

        Args:
            fn (str): path to microscopy stack
            pred_fn (str): path to microscopy stack prediction file
            rois_fn (str): path to microscopy stack ROIs file
            logs_fn (str): path to microscopy stack log file
            dimensions (dict, optional): Dimensions of stack in µm. Defaults to dict(xy=0.094, z=0.5).
718 | """ 719 | super().__init__() 720 | self.l = QGridLayout(self) 721 | self.fn = fn 722 | self.pred_fn = pred_fn 723 | self.rois_fn = rois_fn 724 | self.logs_fn = logs_fn 725 | 726 | self.dimensions = dimensions 727 | 728 | self.log("##################") 729 | self.log("File opened.") 730 | 731 | self.S = Stack(self.fn, self.pred_fn, dimensions) 732 | self.selectedRow = -1 733 | 734 | self.settings = None 735 | 736 | # Load ROIs if previously generated and saved... 737 | if os.path.exists(self.rois_fn): 738 | r = fl.load(self.rois_fn) 739 | self.rois = r['rois'] 740 | self.roi_map = r['roi_map'] 741 | else: 742 | self.rois = None 743 | self.roi_map = None 744 | 745 | # ROIs currently seen 746 | self.roisOnImage = [] 747 | 748 | # Some presets 749 | self.currentIndex = 0 750 | self.showROIs = True 751 | self.showSegmentation = True 752 | self.showLabels = False 753 | self.showMaxProjection = False 754 | 755 | # Annotations 756 | self.annotations = None 757 | self.annotationItems = [] 758 | self.annotationColor = QColor(0,0,255, 127) 759 | 760 | # Create an ImageView inside the central widget 761 | self.imv = ImageView() 762 | self.imv.setImage(self.S.stack.transpose(0,2,1)) 763 | self.imv.xy.connect(self.roiSelection) 764 | self.imv.testROIbuilding.connect(self.testROIbuilding) 765 | 766 | # Prediction overlay 767 | self.overlay = np.zeros(self.S.stack.shape[1:]) 768 | self.overlayItem = pg.ImageItem(self.overlay, compositionMode=QPainter.CompositionMode_Plus) 769 | self.imv.getView().addItem(self.overlayItem) 770 | 771 | # Update overlay when z location changes 772 | self.imv.sigTimeChanged.connect(self._changeOverlay) 773 | 774 | self.l.addWidget(self.imv) 775 | 776 | ## ROI TABLE 777 | self.table = QTableWidget() 778 | self.table.setMinimumWidth(100) 779 | self.table.setMaximumWidth(250) 780 | self.table.setColumnCount(3) 781 | 782 | self.table.setHorizontalHeaderLabels(["Z", "X", "Y"]) 783 | self.table.itemSelectionChanged.connect(self.getSelection) 784 | 785 
| h = self.table.horizontalHeader() 786 | h.setSectionResizeMode(QHeaderView.Stretch) 787 | 788 | self.l.addWidget(self.table, 0, 1) 789 | 790 | ## PROGRESS BAR 791 | self.p = QProgressBar() 792 | self.p.setMinimumWidth(100) 793 | self.p.setMaximumWidth(250) 794 | self.p.setMinimum(0) 795 | self.p.setMaximum(1) 796 | 797 | self.l.addWidget(self.p, 1, 1) 798 | 799 | ## INFO LABEL BOTTOM LEFT 800 | self.info = QLabel() 801 | self.l.addWidget(self.info, 1, 0) 802 | 803 | self.t = [] 804 | 805 | def log(self, s): 806 | with open(self.logs_fn, "a+") as fp: 807 | now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 808 | fp.write(f"{now}: {s}\n") 809 | 810 | def testROIbuilding(self, xy): 811 | """Test ROI building using dedicated interface. 812 | Interface is opened at particular stack location where user double-clicked 813 | 814 | Args: 815 | xy (QPoint): XY Location of pointer during click 816 | """ 817 | x = int(xy.x()) 818 | y = int(xy.y()) 819 | z = int(self.currentIndex) 820 | 821 | # Stack size for testing, hardcoded 822 | w = 128 823 | h = 128 824 | 825 | # Open window of a subset of microscopy stack 826 | # z-stack +- 3 planes, w/2 left and right to click, h/2 top and bottom to click 827 | self.testROIwindow = testROI(self.S.stack[z-3:z+3, y-h//2:y+h//2, x-w//2:x+w//2], 828 | self.S.prediction[z-3:z+3, y-h//2:y+h//2, x-w//2:x+w//2, 0], 829 | self.S.prediction[z-3:z+3, y-h//2:y+h//2, x-w//2:x+w//2, 1], 830 | self.settings) 831 | self.testROIwindow.settings.connect(self.saveSettingsROI3D) 832 | self.testROIwindow.show() 833 | 834 | def saveSettingsROI3D(self, settings): 835 | """Save test settings to global settings 836 | 837 | Args: 838 | settings (dict): 3D ROI settings 839 | """ 840 | QMessageBox.information(self, "Saved!", "Settings were saved.") 841 | self.settings = settings 842 | 843 | def drawROIs(self): 844 | if type(self.rois) is type(None): 845 | return 846 | 847 | # Remove all items from scene beforehand 848 | for roi in self.roisOnImage: 849 | 
self.imv.getView().removeItem(roi) 850 | 851 | self.roisOnImage = [] 852 | 853 | ####### SHOW ROI 854 | if self.showROIs: 855 | for i, roi in enumerate(self.rois[self.currentIndex]): 856 | # Create a path containing the contour 857 | path = QPainterPath() 858 | path.addPolygon(QPolygonF([QPointF(*i[::-1]) for i in roi['contour']])) 859 | 860 | # Add path to a graphical item 861 | roiOnImage = QGraphicsPathItem() 862 | roiOnImage.setPath(path) 863 | # White if not selected, yellow if selected, 1px width, solid line 864 | roiOnImage.setPen(QPen(Qt.white if i != self.selectedRow else Qt.yellow, 1, Qt.SolidLine)) 865 | 866 | self.roisOnImage.append(roiOnImage) 867 | self.imv.getView().addItem(roiOnImage) 868 | 869 | def populateTable(self): 870 | """Populates ROI table 871 | """ 872 | # No ROIs available? Do nothing. 873 | if self.rois is None: 874 | return 875 | 876 | # Remove all entries 877 | self.table.setRowCount(0) 878 | 879 | # Populate table 880 | for roi in self.rois[self.currentIndex]: 881 | i = self.table.rowCount() 882 | self.table.setRowCount(i+1) 883 | 884 | if len(roi['centroid']) == 2: 885 | y, x = roi['centroid'] 886 | z = float(self.currentIndex) 887 | else: 888 | z, y, x = roi['centroid'] 889 | 890 | self.table.setItem(i, 0, QTableWidgetItem(f"{z:.2f}")) 891 | self.table.setItem(i, 1, QTableWidgetItem(f"{x:.2f}")) 892 | self.table.setItem(i, 2, QTableWidgetItem(f"{y:.2f}")) 893 | 894 | # And then draw all ROIs 895 | self.drawROIs() 896 | 897 | def roiSelection(self, xy): 898 | """Highlight selected ROI due to left click 899 | 900 | Args: 901 | xy (QPoint): Clicking location 902 | """ 903 | # Get xy mouse click coordinates 904 | x = xy.x() 905 | y = xy.y() 906 | 907 | ds = [] 908 | 909 | 910 | if self.rois is None: 911 | return 912 | 913 | # Compute euclidean distance from mouse click position 914 | # to ROI centroid 915 | for i, roi in enumerate(self.rois[self.currentIndex]): 916 | if len(roi['centroid']) == 2: 917 | d = (roi['centroid'][1]-y)**2 + 
(roi['centroid'][0]-x)**2 918 | else: 919 | d = (roi['centroid'][1]-y)**2 + (roi['centroid'][2]-x)**2 920 | 921 | ds.append(d) 922 | 923 | # Check if there are any ROIs around... 924 | if not len(ds): 925 | return 926 | 927 | # Find closest 928 | closest = np.argmin(ds) 929 | 930 | # Highlights ROI closest to click 931 | self.info.setText(f"z:{self.currentIndex}, x:{x}, y:{y}, closest ROI idx: {closest}") 932 | 933 | # Select element in table and update ROIs 934 | self.table.selectRow(closest) 935 | self.drawROIs() 936 | 937 | def getSelection(self): 938 | """Sets the current row selection in table and updates ROIs, 939 | because the selected ROI has a different color. 940 | """ 941 | if len(self.table.selectedIndexes()): 942 | self.selectedRow = self.table.selectedIndexes()[0].row() 943 | else: 944 | self.selectedRow = -1 945 | 946 | self.drawROIs() 947 | 948 | 949 | def _changeOverlay(self, z): 950 | """Hook for z-slider 951 | 952 | Args: 953 | z (int): current z index 954 | """ 955 | self.changeOverlay(z) 956 | 957 | def changeOverlay(self, z, preview=False): 958 | """When the z-slider is changed, update the overlay image (i.e. 
the prediction) 959 | 960 | Args: 961 | z (int): current z-location in stack 962 | """ 963 | 964 | # If segmentation should be shown 965 | if self.showSegmentation and self.showMaxProjection: 966 | if preview == False: 967 | self.overlayItem.setImage(self.S.predictionMaxProjection.transpose(1,0,2)) 968 | else: 969 | self.overlayItem.setImage(self.S.previewMaxProjection.transpose(1,0,2)) 970 | 971 | elif self.showSegmentation and not self.showMaxProjection: 972 | if preview == False: 973 | self.overlayItem.setImage(self.S.prediction[z].transpose(1,0,2)) 974 | else: 975 | self.overlayItem.setImage(self.S.preview[z].transpose(1,0,2)) 976 | 977 | 978 | elif self.showLabels and self.roi_map is not None: 979 | rm = np.zeros(self.roi_map.shape[1:]+(3,)) 980 | rm[self.roi_map[z] > 0] = (0, 255, 0) 981 | self.overlayItem.setImage(rm.transpose(1, 0, 2)) 982 | 983 | else: 984 | self.overlayItem.setImage(np.zeros_like(self.S.stack[z])) 985 | 986 | if type(self.annotations) != type(None): 987 | for i in self.annotationItems: 988 | self.imv.removeItem(i) 989 | 990 | self.annotationItems = [] 991 | 992 | for i, row in self.annotations.iterrows(): 993 | if 'Z' in row.keys(): 994 | pos = 'Z' 995 | else: 996 | pos = 'Pos' 997 | 998 | distance = abs(row[pos]-self.currentIndex) 999 | 1000 | if distance < 2: 1001 | size = 2-distance 1002 | e = QGraphicsEllipseItem(row['X']-size/2, row['Y']-size/2, size, size) 1003 | e.setBrush(self.annotationColor) 1004 | e.setPen(QPen(Qt.NoPen)) 1005 | 1006 | self.annotationItems.append(e) 1007 | self.imv.addItem(e) 1008 | 1009 | 1010 | if self.currentIndex != z: 1011 | # Re-populate table with ROIs from current z-location 1012 | self.populateTable() 1013 | 1014 | self.drawROIs() 1015 | 1016 | self.currentIndex = z 1017 | 1018 | def updateProgress(self, pval, pmax): 1019 | """updates progress bar 1020 | 1021 | Args: 1022 | pval (int): current value 1023 | pmax (int): target value 1024 | """ 1025 | self.t.append(time()) 1026 | 1027 | # Provide ETA 1028 
if len(self.t) > 1:
            # Mean time per step from recorded timestamps -> remaining seconds
            dt = abs(np.diff(np.asarray(self.t)).mean())
            self.p.setFormat(f"{int((pval+1)/pmax*100)}% - ETA: {int((pmax-pval+1)*dt)} s")

        # Adjust progress bar settings
        self.p.setMinimum(0)
        self.p.setMaximum(pmax)
        self.p.setValue(pval+1)

        # Check if everything is done
        if pval+1 == pmax:
            self.t = []
            self.p.setFormat("DONE!")

        # Ensure progress bar updates
        QApplication.processEvents()

    def keyPressEvent(self, e):
        """Delete the selected ROIs (across all planes) on the Delete key."""
        if e.key() == Qt.Key_Delete:
            q = QMessageBox.question(self, "Delete ROIs", "Do you want to delete the selected ROIs?")

            if q == QMessageBox.Yes:
                checked_rows = []

                # Delete rows from the back to keep order
                for i in self.table.selectedIndexes()[::-1]:
                    cur_row = i.row()

                    # Multiple selected cells map to the same row; handle once
                    if cur_row in checked_rows:
                        continue

                    checked_rows.append(cur_row)

                    # Remove ROI at given location
                    ROI_id = self.rois[self.currentIndex][cur_row]['ROI_id']

                    # Drop this ROI id from every z-plane
                    for z in self.rois.keys():
                        self.rois[z] = [j for j in self.rois[z] if j['ROI_id'] != ROI_id]
                    # self.rois[self.currentIndex].pop(cur_row)

                # Update Table and ROIs in scene
                self.populateTable()
                self.drawROIs()


##############################
##### MAIN WINDOW
##############################
class Main(QMainWindow):
    def __init__(self):
        """Main window for inference GUI
        """
        super().__init__()
        self.status = self.statusBar()
        self.menu = self.menuBar()
        # Path of the last used model; re-offered in the Segment dialog
        self.model_fn = None

        # Main top menu
        self.file = self.menu.addMenu("&File")
        self.file.addAction("Open", self.open, shortcut=QKeySequence("Ctrl+N"))
        self.file.addAction("Import annotations", self.importAnnotations, shortcut=QKeySequence("Ctrl+I"))
        self.file.addAction("Save", self.save, shortcut=QKeySequence("Ctrl+S"))
        self.file.addAction("Close", self.close)

        # Analysis menu; disabled until a file has been opened
        self.analyze = self.menu.addMenu("&Analyze")
        self.analyze.setEnabled(False)
        self.analyze.addAction("Segment dendrite and spines", self.segment)
        self.analyze.addAction("Cleaning", self.cleaning)
        self.analyze.addAction("2D ROI detection", self.roi2d)
        self.analyze.addAction("3D ROI detection", self.roi3d)
        self.analyze.addAction("Z projection", self.zprojection)
        self.analyze.addAction("Set dimensions", self.setDimensions)

        self.view = self.menu.addMenu("&View")
        self.view.setEnabled(False)

        self.showROIs = QAction("Show ROIs", self, checkable=True)
        self.showROIs.setChecked(True)
        self.showROIs.triggered.connect(self.setShowROIs)

        self.showSegmentation = QAction("Show Segmentation", self, checkable=True)
        self.showSegmentation.setChecked(True)
        self.showSegmentation.triggered.connect(self.setShowSegmentation)

        self.showMaxProjection = QAction("Show Maximum Projection", self, checkable=True)
        self.showMaxProjection.setChecked(False)
        self.showMaxProjection.triggered.connect(self.setShowMaxProjection)

        self.showLabels = QAction("Show Labels", self, checkable=True)
        self.showLabels.setChecked(False)
        self.showLabels.triggered.connect(self.setShowLabels)

        self.view.addAction(self.showSegmentation)
        self.view.addAction(self.showMaxProjection)
        self.view.addAction(self.showLabels)
        self.view.addAction(self.showROIs)

        self.exportData = self.menu.addMenu("&Export")
        self.exportData.setEnabled(False)
        self.exportData.addAction("Export predictions as tif", self.exportPredictions)
        self.exportData.addAction("Export ROIs to ImageJ", self.exportImageJ)
        self.exportData.addAction("Export ROIs to folder", self.exportToFolderStructure)
1132 | self.exportData.addAction("Export ROI map to file", self.exportRoiMap) 1133 | self.exportData.addAction("Export ROI centroids to file", self.exportRoiCentroids) 1134 | 1135 | 1136 | # Central widget 1137 | self.w = None 1138 | 1139 | # Title 1140 | self.setWindowTitle("Interface for spine and dendrite detection") 1141 | self.setGeometry(100, 100, 1200, 600) 1142 | 1143 | def setShowROIs(self): 1144 | """Toggle ROIs on central widget 1145 | """ 1146 | self.w.showROIs = self.showROIs.isChecked() 1147 | self.w.drawROIs() 1148 | 1149 | def setShowMaxProjection(self): 1150 | """Shows maximum projection of stack and prediction in central widget 1151 | """ 1152 | self.w.showMaxProjection = self.showMaxProjection.isChecked() 1153 | 1154 | if self.w.showMaxProjection: 1155 | self.w.S.predictionMaxProjection = self.w.S.prediction.max(0) 1156 | self.w.S.previewMaxProjection = self.w.S.preview.max(0) 1157 | 1158 | self.w.changeOverlay(self.w.currentIndex) 1159 | 1160 | def setShowSegmentation(self): 1161 | """Toggle the segmentation visualization on central widget 1162 | """ 1163 | self.w.showSegmentation = self.showSegmentation.isChecked() 1164 | self.w.changeOverlay(self.w.currentIndex) 1165 | 1166 | def setShowLabels(self): 1167 | """Toggles the visualization of labels on central widget 1168 | """ 1169 | self.w.showLabels = self.showLabels.isChecked() 1170 | self.w.showSegmentation = False 1171 | self.showSegmentation.setChecked(False) 1172 | self.w.changeOverlay(self.w.currentIndex) 1173 | 1174 | def open(self): 1175 | """Open a z-stack for inference. 1176 | 1177 | If a prediction and/or ROIs already exist, do load these as well. 
1178 | """ 1179 | self.fn = QFileDialog.getOpenFileName()[0] 1180 | 1181 | if self.fn: 1182 | # Create filepaths for related files 1183 | ext = self.fn.split(".")[-1] 1184 | self.pred_fn = self.fn[:-len(ext)-1]+".prediction" 1185 | self.rois_fn = self.fn[:-len(ext)-1]+".rois" 1186 | self.logs_fn = self.fn[:-len(ext)-1]+".log" 1187 | self.roi3d_settings_fn = self.fn[:-len(ext)-1]+".roi3d_settings" 1188 | 1189 | dim = askDimensions() 1190 | 1191 | # Create new instance of main widget 1192 | self.status.showMessage(self.fn) 1193 | self.w = Interface(self.fn, self.pred_fn, self.rois_fn, self.logs_fn, dim.dimensions()) 1194 | self.setCentralWidget(self.w) 1195 | 1196 | # Allow the analysis options 1197 | self.analyze.setEnabled(True) 1198 | self.view.setEnabled(True) 1199 | self.exportData.setEnabled(True) 1200 | 1201 | def importAnnotations(self): 1202 | """Import annotations to visualize those on the central widget 1203 | """ 1204 | fn = QFileDialog.getOpenFileName(filter="*.csv")[0] 1205 | 1206 | if fn: 1207 | df = pd.read_csv(fn, index_col=0) 1208 | 1209 | self.w.annotations = df 1210 | self.w.changeOverlay(self.w.currentIndex) 1211 | 1212 | def save(self): 1213 | """Save segmentation predictions and ROIs 1214 | """ 1215 | if type(self.w) == type(None): 1216 | QMessageBox.critical(self, "No file open", "Please open a file first.") 1217 | return 1218 | 1219 | text = [] 1220 | 1221 | # If segmentation is available 1222 | if self.w.S.segmented: 1223 | fl.save(self.pred_fn, dict(dendrites=self.w.S.prediction[...,0].astype(np.float32), 1224 | spines=self.w.S.prediction[..., 1].astype(np.float32)), 1225 | compression='blosc') 1226 | text.append("Predictions saved!") 1227 | 1228 | # If ROIs are available 1229 | if self.w.rois: 1230 | fl.save(self.rois_fn, dict(rois=self.w.rois, roi_map=self.w.roi_map), 1231 | compression='blosc') 1232 | 1233 | text.append("ROIs saved!") 1234 | 1235 | QMessageBox.information(self, "Saved.", "\n".join(text)) 1236 | 1237 | 1238 | def 
segment(self): 1239 | """Segment stack using user-defined settings 1240 | """ 1241 | s = Segment(self.model_fn) 1242 | s.exec_() 1243 | 1244 | if s.go: 1245 | # Only load tensorflow if needed. 1246 | # This is not neccessarily pythonic, but allows for fast GUI loading times 1247 | import tensorflow as tf 1248 | 1249 | self.w.log("Segmenting data...") 1250 | 1251 | for k, v in s.settings.items(): 1252 | self.w.log(f"{k}: {v}") 1253 | 1254 | self.w.S.tileSignal.connect(self.w.updateProgress) 1255 | self.model_fn = s.model_fn[0] 1256 | 1257 | # Decide which hardware will be utilized for inference 1258 | if s.settings['runOnCPU']: 1259 | context = '/cpu:0' 1260 | 1261 | else: 1262 | context = '/gpu:0' 1263 | 1264 | with tf.device(context): 1265 | # Predict tiles [4x average] 1266 | if s.selectInferenceMode.currentIndex() == 2: 1267 | pad_op = np.mean if s.settings['paddingOperation'] == 'mean' else np.min 1268 | 1269 | self.w.S.predictFourFold(s.model_fn[0], 1270 | tile_size=s.settings['tileSize'], 1271 | inset_size=s.settings['insetSize'], 1272 | pad_op=pad_op) 1273 | 1274 | # Predict the each plane completely in one go 1275 | elif s.selectInferenceMode.currentIndex() == 0: 1276 | self.w.S.predictWholeImage(s.model_fn[0]) 1277 | 1278 | # Predict tiles [1x] 1279 | else: 1280 | self.w.S.predictInset(s.model_fn[0], 1281 | s.settings['tileSize'], 1282 | s.settings['insetSize']) 1283 | 1284 | self.w.changeOverlay(self.w.currentIndex) 1285 | 1286 | def cleaning(self): 1287 | """Clean the prediction using user-specified settings 1288 | """ 1289 | self.c = Cleaning() 1290 | self.c.previewSignal.connect(self.previewCleaning) 1291 | self.c.exec_() 1292 | 1293 | if self.c.go: 1294 | print("Actually cleaning...") 1295 | settings = self.c._settings() 1296 | # settings = self.c.settings 1297 | 1298 | self.w.log("Cleaning stack...") 1299 | 1300 | for k, v in settings.items(): 1301 | self.w.log(f"{k}: {v}") 1302 | 1303 | if settings['closing']: 1304 | d = 
self.w.S.closing(settings['closingIterations']) 1305 | self.w.S.prediction[..., 0] = d 1306 | self.w.S.prediction[..., 2] = d 1307 | 1308 | if settings['cleanDendrite']: 1309 | d = self.w.S.cleanDendrite3D(settings['cleanDendriteThreshold'], settings['minDendriteSize']) 1310 | self.w.S.prediction[..., 0] = d 1311 | self.w.S.prediction[..., 2] = d 1312 | 1313 | if settings['cleanSpines']: 1314 | s = self.w.S.cleanSpines(settings['cleanDendriteThreshold'], settings['dendriteDilation']) 1315 | self.w.S.prediction[..., 1] = s 1316 | 1317 | if self.w.showMaxProjection: 1318 | self.w.S.predictionMaxProjection = self.w.S.prediction.max(0) 1319 | 1320 | def previewCleaning(self, settings): 1321 | """Preview cleaning settings to specify the settings 1322 | 1323 | Args: 1324 | settings (dict): cleaning settings 1325 | """ 1326 | self.w.S.preview = self.w.S.prediction.copy() 1327 | 1328 | # Clean dendrite in 3D 1329 | if settings['closing']: 1330 | d = self.w.S.closing(settings['closingIterations'], True) 1331 | self.w.S.preview[..., 0] = d 1332 | self.w.S.preview[..., 2] = d 1333 | 1334 | if settings['cleanDendrite']: 1335 | d = self.w.S.cleanDendrite3D(settings['cleanDendriteThreshold'], settings['minDendriteSize'], True) 1336 | self.w.S.preview[..., 0] = d 1337 | self.w.S.preview[..., 2] = d 1338 | 1339 | if settings['cleanSpines']: 1340 | s = self.w.S.cleanSpines(settings['cleanDendriteThreshold'], settings['dendriteDilation'], True) 1341 | self.w.S.preview[..., 1] = s 1342 | 1343 | if self.w.showMaxProjection: 1344 | self.w.S.previewMaxProjection = self.w.S.preview.max(0) 1345 | 1346 | self.w.changeOverlay(self.w.currentIndex, True) 1347 | 1348 | def roi2d(self): 1349 | """Create ROIs from segmentation 1350 | """ 1351 | if not self.w.S.segmented: 1352 | QMessageBox.critical(self, "No segmentation", "Please provide first a segmentation!") 1353 | return 1354 | 1355 | roi = ROI2D() 1356 | roi.exec_() 1357 | 1358 | if roi.go: 1359 | self.w.log("Building 2D ROIs...") 1360 | 
1361 | for k, v in roi.settings.items(): 1362 | self.w.log(f"{k}: {v}") 1363 | 1364 | r = ROI2D_Creator(self.w.S.prediction[..., 0], # dendrite pred 1365 | self.w.S.prediction[..., 1], # spines pred 1366 | roi.settings['threshold']) # Threshold for labelling map 1367 | r.zSignal.connect(self.w.updateProgress) 1368 | 1369 | # Create ROIs for each z-plane 1370 | self.w.info.setText("Building ROIs...") 1371 | r.create(roi.applyWatershed.isChecked()) 1372 | 1373 | # Clean ROIs using maximum distance to dendrite 1374 | if roi.cleanROIs.isChecked(): 1375 | self.w.info.setText("Cleaning ROIs...") 1376 | old, new = r.clean(roi.settings['maxDendriteDistance'], 1377 | roi.settings['minSize'], 1378 | roi.settings['dendriteThreshold']) 1379 | 1380 | QMessageBox.information(self, "ROIs cleaned", 1381 | f"I cleaned all ROIs for you! \nOld rois: {old}\nNew rois: {new}.") 1382 | 1383 | self.w.roi_map = r.roi_map 1384 | self.w.rois = r.rois 1385 | self.w.populateTable() 1386 | 1387 | def roi3d(self): 1388 | """Create ROIs from segmentation in 3D 1389 | """ 1390 | if not self.w.S.segmented: 1391 | QMessageBox.critical(self, "No segmentation", "Please provide first a segmentation!") 1392 | return 1393 | 1394 | roi = ROI3D(self.w.settings) 1395 | roi.exec_() 1396 | 1397 | if roi.go: 1398 | self.w.log("Building 3D ROIs...") 1399 | 1400 | for k, v in roi.settings.items(): 1401 | self.w.log(f"{k}: {v}") 1402 | 1403 | with open(self.roi3d_settings_fn, "w+") as fp: 1404 | json.dump(roi.settings, fp) 1405 | 1406 | r = ROI3D_Creator(self.w.S.prediction[..., 0], # dendrite pred 1407 | self.w.S.prediction[..., 1], # spines pred 1408 | roi.settings['method'], 1409 | roi.settings['areaThreshold'], 1410 | roi.settings['peakThreshold'], 1411 | roi.settings['seedDelta'], 1412 | roi.settings['distanceToSeed'], 1413 | dimensions=self.w.dimensions) # Threshold for labelling map 1414 | 1415 | r.zSignal.connect(self.w.updateProgress) 1416 | r.log.connect(self.w.log) 1417 | 1418 | # Create ROIs for each 
z-plane 1419 | self.w.info.setText("Building ROIs...") 1420 | r.create(roi.settings['minPx'], 1421 | roi.settings['maxPx'], 1422 | roi.settings['minPlanes'], 1423 | roi.settings['watershed']) 1424 | 1425 | # Save results 1426 | self.w.roi_map = r.roi_map 1427 | self.w.rois = r.rois 1428 | self.w.roi_centroids = r.roi_centroids 1429 | self.w.populateTable() 1430 | 1431 | def setDimensions(self): 1432 | """Set dimensions for z-stack to ensure proper functionality (e.g. distance measures) 1433 | """ 1434 | dim = askDimensions(self.w.dimensions['xy'], self.w.dimensions['z']) 1435 | self.w.dimensions = dim.dimensions() 1436 | self.w.S.dimensions = dim.dimensions() 1437 | 1438 | def zprojection(self): 1439 | """Show a maximum and summed intensity z-projection for the full stack in separate windows 1440 | """ 1441 | self.i = pg.image(self.w.S.prediction.max(0).transpose(1,0,2), 1442 | title="Maximum intensity projection") 1443 | 1444 | self.j = pg.image(self.w.S.prediction.sum(0).transpose(1,0,2), 1445 | title="Summed intensity over z") 1446 | 1447 | def exportImageJ(self): 1448 | """Export ROIs to ImageJ 1449 | """ 1450 | fn = QFileDialog.getSaveFileName(caption="Select an ROI file for saving", 1451 | filter="*.zip") 1452 | 1453 | if fn[0]: 1454 | self.w.info.setText("Exporting ROIs as ImageJ ROI zip file...") 1455 | eij = ExportImageJ(self.w.rois) 1456 | eij.zSignal.connect(self.w.updateProgress) 1457 | eij.export(fn[0]) 1458 | 1459 | QMessageBox.information(self, "Done!", f"ROIs exported to\n{fn[0]}") 1460 | 1461 | def exportRoiCentroids(self): 1462 | """Export ROI centroids as CSV file 1463 | """ 1464 | fn = QFileDialog.getSaveFileName(caption="Select a CSV file to save the ROI centroids", 1465 | filter="*.csv") 1466 | 1467 | if fn[0]: 1468 | e = ExportCentroids(self.w.roi_centroids) 1469 | e.export(fn[0]) 1470 | 1471 | QMessageBox.information(self, "Done!", f"ROI centroids exported to\n{fn[0]}") 1472 | 1473 | 1474 | def exportRoiMap(self): 1475 | """Export ROI map as 
TIFF stack 1476 | """ 1477 | fn = QFileDialog.getSaveFileName(caption="Select a tif file to save the ROI map", 1478 | filter="*.tif") 1479 | 1480 | if fn[0]: 1481 | result = QMessageBox.question(self, "Binarize?", "Should I binarize the ROIs?") 1482 | 1483 | self.w.info.setText("Exporting ROIs as ROI map to tif file...") 1484 | emap = ExportROIMap(self.w.roi_map, result == QMessageBox.Yes) 1485 | emap.export(fn[0]) 1486 | 1487 | QMessageBox.information(self, "Done!", f"ROI map exported to\n{fn[0]}") 1488 | 1489 | 1490 | def exportToFolderStructure(self): 1491 | """Export ROIs as folder structure 1492 | """ 1493 | folder = QFileDialog.getExistingDirectory(caption="Select folder to export ROIs") 1494 | 1495 | if not folder: 1496 | return 1497 | 1498 | ef = ExportFolder(self.w.rois) 1499 | ef.zSignal.connect(self.w.updateProgress) 1500 | ef.export(self.rois_fn, folder) 1501 | 1502 | def exportPredictions(self): 1503 | """Export neural network prediction as TIFF stacks 1504 | """ 1505 | folder = QFileDialog.getExistingDirectory(caption="Select folder to export predictions") 1506 | 1507 | if not folder: 1508 | return 1509 | 1510 | ep = ExportPredictions(self.w.S.prediction[..., 1], # spines pred 1511 | self.w.S.prediction[..., 0]) # dendrite pred 1512 | r, e = ep.export(self.fn, folder) 1513 | 1514 | if r: 1515 | QMessageBox.information(self, "Exported!", f"Successfully exported to \n{e}") 1516 | 1517 | else: 1518 | QMessageBox.critical(self, "Failed to export!", f"I had problems to export the data... 
\n{e}") 1519 | 1520 | def main(): 1521 | app = QApplication([]) 1522 | 1523 | m = Main() 1524 | m.show() 1525 | 1526 | sys.exit(app.exec_()) 1527 | 1528 | 1529 | if __name__ == '__main__': 1530 | main() -------------------------------------------------------------------------------- /deepd3/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import DeepD3_Model -------------------------------------------------------------------------------- /deepd3/model/builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.layers import Input, Conv2D, Concatenate, MaxPool2D, \ 4 | UpSampling2D, BatchNormalization, Activation, Add 5 | 6 | def convlayer(x, filters, activation, name, residual=None, use_batchnorm=True): 7 | """Convolutional layer with normalization and residual connection 8 | 9 | Args: 10 | x (Keras.layer): input layer 11 | filters (int): filters used in convolutional layer 12 | activation (str): Activation function 13 | name (str): Description of layer 14 | residual (Keras.layer, optional): Residual layer. Defaults to None. 15 | use_batchnorm (bool, optional): Use of batch normalization. Defaults to True. 
def identity(x, filters, name):
    """Point-wise (1x1) projection used as the shortcut path of a residual block.

    Args:
        x (Keras.layer): input layer
        filters (int): number of filters for the projection
        name (str): layer name

    Returns:
        Keras.layer: projected layer
    """
    return Conv2D(filters, 1, padding='same', use_bias=False, name=name)(x)

def decoder(x, filters, layers, to_concat, name, activation):
    """Decoder path of the network.

    Args:
        x (Keras layer): start of the decoder, normally the latent space
        filters (int): base filter multiplier
        layers (int): number of upsampling steps
        to_concat (list): encoder layers to concatenate (consumed back-to-front)
        name (str): decoder name, also used for the output layer
        activation (str): activation function used in the decoder

    Returns:
        Keras layer: single-channel sigmoid prediction map
    """
    for i in range(layers):
        # Upsample the current activation maps by a factor of 2x2
        x = UpSampling2D()(x)

        # Concatenate the respective encoder layer (deepest first)
        x = Concatenate()([x, to_concat.pop()])

        # Apply two convolutional layers per decoder stage
        n_filters = filters * 2 ** (layers - 1 - i)
        x = convlayer(x, n_filters, activation, f"{name}_dec_layer{layers-i}_conv1")
        x = convlayer(x, n_filters, activation, f"{name}_dec_layer{layers-i}_conv2")

    # Final point-wise convolution yields the prediction map
    x = Conv2D(1, 1, padding='same', name=name, activation='sigmoid')(x)

    return x

def DeepD3_Model(filters=32, input_shape=(None, None, 1), layers=4, activation="swish"):
    """DeepD3 TensorFlow/Keras model: a single shared encoder with dual decoders.

    Args:
        filters (int, optional): Base filter multiplier. Defaults to 32.
        input_shape (tuple, optional): Input image shape. Defaults to (None, None, 1),
            i.e. arbitrary height and width with a single channel.
        layers (int, optional): Network depth. Defaults to 4.
        activation (str, optional): Activation function of the convolutional layers.
            Defaults to "swish".

    Returns:
        Model: TensorFlow/Keras model with outputs [dendrites, spines]
    """
    # Encoder activations kept for the skip connections
    to_concat = []

    # Create model input
    model_input = Input(input_shape, name="input")
    x = model_input

    # Common encoder built from residual blocks
    for i in range(layers):
        residual = identity(x, filters * 2 ** i, f"enc_layer{i}_identity")
        x = convlayer(x, filters * 2 ** i, activation, f"enc_layer{i}_conv1")
        x = convlayer(x, filters * 2 ** i, activation, residual=residual, name=f"enc_layer{i}_conv2")
        to_concat.append(x)
        x = MaxPool2D()(x)

    # Latent space: filters * 2**layers channels.
    # (Previously written as 2**(i+1), relying on the leaked loop variable.)
    x = convlayer(x, filters * 2 ** layers, activation, "latent_conv")

    # Two decoders, one for dendrites and one for spines
    dendrites = decoder(x, filters, layers, to_concat.copy(), "dendrites", activation)
    spines = decoder(x, filters, layers, to_concat.copy(), "spines", activation)

    return Model(model_input, [dendrites, spines])

if __name__ == '__main__':
    # Smoke test: model creation and a single training epoch
    import numpy as np

    # Create model
    m = DeepD3_Model(8, input_shape=(48, 48, 1))

    # Random dataset of 100 images of tile size 48x48
    X = np.random.randn(48 * 48 * 100).reshape(100, 48, 48, 1)

    print(m.summary())

    # Prepare and fit for one epoch
    m.compile('sgd', ['mae', 'mse'])
    m.fit(X, [X, X], epochs=1)
from tensorflow.keras.layers import Input
from tfkerassurgeon import delete_layer, insert_layer

def changeFirstLayer(model):
    """Replace the first layer of *model* with a single-channel input of
    arbitrary height and width (shape (None, None, 1)).

    Args:
        model: Keras model whose first layer is replaced.

    Returns:
        Keras model with the new input layer.
    """
    new_input = Input(shape=(None, None, 1), name='arbitrary_input')

    # Remove the original first layer ...
    model = delete_layer(model.layers[0])
    # ... and insert the new input before the (now first) layer 0
    model = insert_layer(model.layers[0], new_input)

    return model
self.imv.setImage(self.d['stack'].transpose(0, 2, 1)) 34 | self.imv.sigTimeChanged.connect(self.plane) 35 | 36 | ### Add overlay 37 | # Prediction overlay 38 | self.overlay = np.zeros(self.d['stack'].shape[1:]+(3,), dtype=np.uint8) 39 | self.overlayItem = pg.ImageItem(self.overlay, compositionMode=QPainter.CompositionMode_Plus) 40 | self.imv.getView().addItem(self.overlayItem) 41 | 42 | self.plane() 43 | 44 | # TREE 45 | self.tree = QTreeWidget() 46 | self.tree.setHeaderLabels(['Meta data']) 47 | 48 | for k, v in self.m.items(): 49 | ki = QTreeWidgetItem([k]) 50 | child = QTreeWidgetItem([str(v)]) 51 | child.setFlags(child.flags() | Qt.ItemIsEditable) 52 | ki.addChild(child) 53 | 54 | self.tree.addTopLevelItem(ki) 55 | 56 | self.tree.expandToDepth(2) 57 | 58 | self.l.addWidget(self.imv, 0, 0) 59 | self.l.addWidget(self.tree, 0, 1) 60 | 61 | saveButton = QPushButton("Save") 62 | saveButton.clicked.connect(self.save) 63 | self.l.addWidget(saveButton, 1, 1) 64 | 65 | def save(self): 66 | """Save d3data set 67 | """ 68 | ok = QMessageBox.question(self, "Overwrite?", 69 | "Are you sure? Do you want to overwrite this dataset? 
There are no sanity checks!") 70 | 71 | if ok != QMessageBox.Yes: 72 | return 73 | 74 | new_m = dict() 75 | 76 | # Go through each item 77 | for i in range(self.tree.invisibleRootItem().childCount()): 78 | c = self.tree.invisibleRootItem().child(i) 79 | 80 | k = c.text(0) 81 | v = c.child(0).text(0) 82 | 83 | t = type(self.m[k]) 84 | 85 | if t == np.float32 or t == np.float64: 86 | new_m[k] = float(v) 87 | 88 | elif t == np.bool_: 89 | new_m[k] = True if v.lower() == "true" else False 90 | 91 | elif t == np.int64 or t == np.int32: 92 | new_m[k] = int(v) 93 | 94 | else: 95 | new_m[k] = str(v) 96 | 97 | new_m["Changed"] = datetime.now().strftime(r"%Y%m%d_%H%M%S") 98 | 99 | fl.save(self.fn, dict(data=self.d, meta=new_m), compression='blosc') 100 | 101 | QMessageBox.information(self, "Saved!", 102 | f"You saved the data successfully to \n {self.fn}") 103 | 104 | def plane(self): 105 | """Overlay current annotation plane 106 | """ 107 | idx = self.imv.currentIndex 108 | 109 | self.overlay[:, :, 0] = self.d['dendrite'][idx].astype(np.uint8) * 255 110 | self.overlay[:, :, 2] = self.d['dendrite'][idx].astype(np.uint8) * 255 111 | self.overlay[:, :, 1] = self.d['spines'][idx].astype(np.uint8) * 255 112 | 113 | self.overlayItem.setImage(self.overlay.transpose(1, 0, 2)) 114 | 115 | class Arrange(QWidget): 116 | def __init__(self) -> None: 117 | """Arrange d3data files to a common d3set. 
118 | """ 119 | super().__init__() 120 | 121 | self.setGeometry(100, 100, 400, 600) 122 | 123 | addData = QPushButton("Add data to set") 124 | addData.clicked.connect(self.addData) 125 | 126 | self.l = QGridLayout(self) 127 | self.l.addWidget(addData) 128 | 129 | self.l.addWidget(QLabel("Added stacks")) 130 | 131 | self.list = QListWidget() 132 | self.list.setMinimumHeight(200) 133 | self.l.addWidget(self.list) 134 | 135 | createSet = QPushButton("Create dataset") 136 | createSet.clicked.connect(self.createSet) 137 | 138 | self.l.addWidget(createSet) 139 | 140 | self.show() 141 | 142 | def keyPressEvent(self, a0) -> None: 143 | if a0.key() == Qt.Key_Delete: 144 | print("Del key pressed") 145 | self.removeSelection() 146 | 147 | return super().keyPressEvent(a0) 148 | 149 | def addData(self): 150 | """Add selected d3data files 151 | """ 152 | self.fns = QFileDialog.getOpenFileNames(filter='*.d3data')[0] 153 | 154 | if self.fns: 155 | for i in self.fns: 156 | self.list.addItem(i) 157 | 158 | def removeSelection(self): 159 | """Remove selected d3data sets 160 | """ 161 | listItems = self.list.selectedItems() 162 | 163 | if not listItems: 164 | return 165 | 166 | ok = QMessageBox.question(self, "Really?", 167 | "Do you want to delete the selected items?") 168 | 169 | if ok != QMessageBox.Yes: 170 | return 171 | 172 | for item in listItems: 173 | self.list.takeItem(self.list.row(item)) 174 | 175 | def createSet(self): 176 | """Create d3set from selected d3data files. 
177 | """ 178 | save_fn = QFileDialog.getSaveFileName(caption="Select dataset filename", filter="*.d3set")[0] 179 | 180 | # If the saving fn is ok 181 | if save_fn: 182 | stacks = {} 183 | dendrites = {} 184 | spines = {} 185 | meta = pd.DataFrame() 186 | 187 | # For each dataset, add to set 188 | for i in range(self.list.count()): 189 | fn = self.list.item(i).text() 190 | print(fn, "...") 191 | 192 | d = fl.load(fn) 193 | 194 | stacks[f"x{i}"] = d['data']['stack'] 195 | dendrites[f"x{i}"] = d['data']['dendrite'] 196 | spines[f"x{i}"] = d['data']['spines'] 197 | 198 | m = pd.DataFrame([d['meta']]) 199 | meta = pd.concat((meta, m), axis=0, ignore_index=True) 200 | 201 | fl.save(save_fn, dict(data=dict(stacks=stacks, dendrites=dendrites, spines=spines), 202 | meta=meta), compression='blosc') 203 | 204 | QMessageBox.information(self, "Saved!", 205 | f"New dataset saved as \n{save_fn}") 206 | 207 | 208 | ####################################### 209 | class ImageView(pg.ImageView): 210 | # Signal for first/last plane [0, 1] and current index 211 | cur_i = pyqtSignal(int, int) 212 | 213 | def __init__(self, *args, **kwargs): 214 | """Custom Image View for emitting current index w.r.t. key press event 215 | """ 216 | super().__init__(*args, **kwargs) 217 | 218 | def keyPressEvent(self, ev): 219 | # Emit first plane at current index 220 | if ev == Qt.Key_S: 221 | self.cur_i.emit(0, self.currentIndex) 222 | 223 | # Emit last plane at current index 224 | elif ev == Qt.Key_E: 225 | self.cur_i.emit(1, self.currentIndex) 226 | 227 | return super().keyPressEvent(ev) 228 | 229 | 230 | ######################## 231 | ## ASK SPACING for DENDRITE TRACINGS 232 | ######################## 233 | class askSpacing(QDialog): 234 | def __init__(self) -> None: 235 | """User interface for spacing w.r.t. 
dendrite tracing 236 | """ 237 | super().__init__() 238 | 239 | self.l = QGridLayout(self) 240 | 241 | self.x = QLineEdit() 242 | self.x.setPlaceholderText("Default: 1") 243 | self.x.setValidator(QDoubleValidator()) 244 | self.l.addWidget(QLabel("X spacing")) 245 | self.l.addWidget(self.x) 246 | 247 | self.y = QLineEdit() 248 | self.y.setPlaceholderText("Default: 1") 249 | self.y.setValidator(QDoubleValidator()) 250 | self.l.addWidget(QLabel("Y spacing")) 251 | self.l.addWidget(self.y) 252 | 253 | self.z = QLineEdit() 254 | self.z.setPlaceholderText("Default: 1") 255 | self.z.setValidator(QDoubleValidator()) 256 | self.l.addWidget(QLabel("Z spacing")) 257 | self.l.addWidget(self.z) 258 | 259 | self.exec_() 260 | 261 | def spacing(self): 262 | """ Converts spacing 263 | 264 | Returns: 265 | tuple(float, float, float): spacing in µm in x, y and z 266 | """ 267 | x = 1. if self.x.text() == "" else float(self.x.text()) 268 | y = 1. if self.y.text() == "" else float(self.y.text()) 269 | z = 1. 
if self.z.text() == "" else float(self.z.text()) 270 | 271 | return [x,y,z] 272 | 273 | 274 | ########################## 275 | ## Window for creating a dataset for training 276 | ## from annotated data 277 | ########################## 278 | class addStackWidget(QWidget): 279 | def __init__(self) -> None: 280 | """Creates d3data dataset for DeepD3 training 281 | """ 282 | super().__init__() 283 | l = QGridLayout(self) 284 | self.setGeometry(100, 100, 1100, 650) 285 | 286 | self.imv = ImageView() 287 | self.imv.setMinimumWidth(800) 288 | self.imv.setMaximumWidth(1400) 289 | self.imv.cur_i.connect(self.updateZ) 290 | 291 | self.roi = pg.RectROI((100, 100), (200, 300), pen="r") 292 | self.roi.sigRegionChangeFinished.connect(self.updateROI) 293 | self.imv.addItem(self.roi) 294 | 295 | 296 | ### 297 | self.stack = None 298 | self.dendrite = None 299 | self.spines = None 300 | 301 | ################## 302 | # Stack 303 | ################## 304 | l.addWidget(QLabel("Stack"), 0, 0, 1, 2) 305 | self.fn_stack = QLabel() 306 | l.addWidget(self.fn_stack, 1, 0, 1, 2) 307 | 308 | l.addWidget(self.imv, 0, 2, 25, 1) 309 | 310 | selectStackBtn = QPushButton("Select stack") 311 | selectStackBtn.clicked.connect(self.selectStack) 312 | l.addWidget(selectStackBtn, 2, 0, 1, 2) 313 | 314 | ################## 315 | # Dendrite tracing 316 | ################## 317 | l.addWidget(QLabel("Dendrite tracings"), 3, 0, 1, 2) 318 | self.fn_d = QLabel("") 319 | l.addWidget(self.fn_d, 4, 0, 1, 2) 320 | 321 | self.selectDendriteBtn = QPushButton("Select dendrite tracings") 322 | self.selectDendriteBtn.clicked.connect(self.selectDendrite) 323 | self.selectDendriteBtn.setEnabled(False) 324 | l.addWidget(self.selectDendriteBtn, 5, 0, 1, 2) 325 | 326 | ################## 327 | # Spines 328 | ################## 329 | l.addWidget(QLabel("Spines"), 6, 0, 1, 2) 330 | self.fn_s = QLabel("") 331 | l.addWidget(self.fn_s, 7, 0, 1, 2) 332 | 333 | self.selectSpinesBtn = QPushButton("Select spine annotations") 334 | 
self.selectSpinesBtn.clicked.connect(self.selectSpines) 335 | self.selectSpinesBtn.setEnabled(False) 336 | l.addWidget(self.selectSpinesBtn, 8, 0, 1, 2) 337 | 338 | l.addWidget(QLabel("Resolution"), 9, 0, 1, 2) 339 | 340 | self.res_xy = QLineEdit() 341 | self.res_xy.setValidator(QDoubleValidator()) 342 | self.res_xy.setPlaceholderText("XY, in microns, e.g. 0.09 for 90 nm resolution in x and y") 343 | 344 | self.res_z = QLineEdit() 345 | self.res_z.setValidator(QDoubleValidator()) 346 | self.res_z.setPlaceholderText("Z, in microns, e.g. 0.5 for 500 nm step size") 347 | 348 | l.addWidget(self.res_xy, 10, 0, 1, 2) 349 | l.addWidget(self.res_z, 11, 0, 1, 2) 350 | 351 | l.addWidget(QLabel("Determine offsets using the ROI"), 13, 0, 1, 2) 352 | 353 | self.cropToROI = QCheckBox("Crop annotation to ROI") 354 | self.cropToROI.setChecked(True) 355 | l.addWidget(self.cropToROI, 14, 0) 356 | 357 | l.addWidget(QLabel("x"), 15, 0) 358 | 359 | self.offsets_x = QLineEdit("") 360 | self.offsets_x.setValidator(QIntValidator()) 361 | l.addWidget(self.offsets_x, 15, 1) 362 | 363 | l.addWidget(QLabel("y"), 16, 0) 364 | 365 | self.offsets_y = QLineEdit("") 366 | self.offsets_y.setValidator(QIntValidator()) 367 | l.addWidget(self.offsets_y, 16, 1) 368 | 369 | l.addWidget(QLabel("w"), 17, 0) 370 | 371 | self.offsets_w = QLineEdit("") 372 | self.offsets_w.setValidator(QIntValidator()) 373 | l.addWidget(self.offsets_w, 17, 1) 374 | 375 | l.addWidget(QLabel("h"), 18, 0) 376 | 377 | self.offsets_h = QLineEdit("") 378 | self.offsets_h.setValidator(QIntValidator()) 379 | l.addWidget(self.offsets_h, 18, 1) 380 | 381 | self.zValidator = QIntValidator() 382 | self.zValidator.setRange(0, 1) 383 | 384 | l.addWidget(QLabel("z (begin), shortcut B"), 19, 0) 385 | self.offsets_z_begin = QLineEdit("") 386 | self.offsets_z_begin.setValidator(self.zValidator) 387 | l.addWidget(self.offsets_z_begin, 19, 1) 388 | 389 | l.addWidget(QLabel("z (end), shortcut E"), 20, 0) 390 | self.offsets_z_end = QLineEdit("") 
391 | self.offsets_z_end.setValidator(self.zValidator) 392 | l.addWidget(self.offsets_z_end, 20, 1) 393 | 394 | self.progressbar = QProgressBar() 395 | 396 | l.addWidget(self.progressbar, 21, 0, 1, 2) 397 | 398 | saveBtn = QPushButton("Save annotation stack") 399 | saveBtn.clicked.connect(self.save) 400 | l.addWidget(saveBtn, 22, 0, 1, 2) 401 | 402 | 403 | expand = QLabel() 404 | sizePolicy = QSizePolicy(QSizePolicy.Expanding , QSizePolicy.Expanding ) 405 | expand.setSizePolicy(sizePolicy) 406 | 407 | l.addWidget(expand, 24, 0) 408 | 409 | def updateZ(self, a, b): 410 | """Updates z-level in stack 411 | 412 | Args: 413 | a (int): z-stack begin, first plane 414 | b (int): z-stack end, last plane 415 | """ 416 | if a == 0: 417 | self.offsets_z_begin.setText(str(b)) 418 | 419 | else: 420 | self.offsets_z_end.setText(str(b)) 421 | 422 | def keyPressEvent(self, a0: QKeyEvent) -> None: 423 | """Key press event to enable shortcuts 424 | 425 | Args: 426 | a0 (QKeyEvent): Key event 427 | 428 | """ 429 | if a0.key() == Qt.Key_S or a0.key() == Qt.Key_B: 430 | self.offsets_z_begin.setText(str(self.imv.currentIndex)) 431 | 432 | elif a0.key() == Qt.Key_E: 433 | self.offsets_z_end.setText(str(self.imv.currentIndex)) 434 | 435 | return super().keyPressEvent(a0) 436 | 437 | def save(self): 438 | """Save a d3data set 439 | """ 440 | fn = QFileDialog.getSaveFileName(filter="*.d3data")[0] 441 | 442 | if not fn: 443 | return 444 | 445 | x = int(self.offsets_x.text()) 446 | y = int(self.offsets_y.text()) 447 | w = int(self.offsets_w.text()) 448 | h = int(self.offsets_h.text()) 449 | z_begin = int(self.offsets_z_begin.text()) 450 | z_end = int(self.offsets_z_end.text()) 451 | 452 | if x < 0: 453 | x = 0 454 | 455 | if y < 0: 456 | y = 0 457 | 458 | # Check for maximum size 459 | if x+w >= self.im.shape[2]: 460 | w = self.im.shape[2]-x-1 461 | 462 | if y+h >= self.im.shape[1]: 463 | h = self.im.shape[1]-y-1 464 | 465 | if z_end - z_begin < 0 or z_begin < 0 or z_end >= self.im.shape[0]: 
466 | QMessageBox.critical(self, "Z span invalid", 467 | "Please check for z_begin and z_end.") 468 | 469 | try: 470 | res_xy = float(self.res_xy.text()) 471 | res_z = float(self.res_z.text()) 472 | except Exception as e: 473 | QMessageBox.critical(self, "Something went wrong", 474 | f"{e}") 475 | return 476 | 477 | if self.cropToROI.isChecked(): 478 | stack = self.im[z_begin:z_end+1, y:y+h, x:x+w] 479 | dendrite = self.dendrite[z_begin:z_end+1, y:y+h, x:x+w] 480 | spines = self.spines[z_begin:z_end+1, y:y+h, x:x+w] 481 | 482 | else: 483 | stack = self.im[z_begin:z_end+1] 484 | dendrite = self.dendrite[z_begin:z_end+1] 485 | spines = self.spines[z_begin:z_end+1] 486 | 487 | 488 | data = { 489 | 'stack': stack, 490 | 'dendrite': dendrite > 0, 491 | 'spines': spines > 0 492 | } 493 | 494 | meta = { 495 | 'crop': self.cropToROI.isChecked(), 496 | 'X': x, 497 | 'Y': y, 498 | 'Width': w, 499 | 'Height': h, 500 | 'Depth': z_end-z_begin+1, 501 | 'Z_begin': z_begin, 502 | 'Z_end': z_end, 503 | 'Resolution_XY': res_xy, 504 | 'Resolution_Z': res_z, 505 | 'Timestamp': datetime.now().strftime(r"%Y%m%d_%H%M%S"), 506 | 'Generated_from': self.fn_stack.text() 507 | } 508 | 509 | try: 510 | fl.save(fn, dict(data=data, meta=meta), compression='blosc') 511 | QMessageBox.information(self, "Data saved", 512 | f"Your data was successfully saved:\n{fn}") 513 | 514 | except Exception as e: 515 | QMessageBox.critical(self, "Could not save data", 516 | f"{e}") 517 | 518 | 519 | def updateROI(self): 520 | """Updates the ROI chosen as dataset 521 | """ 522 | # Retrieve xy location and ROI rectangle size 523 | pos = int(self.roi.pos().x()), int(self.roi.pos().y()) # x, y 524 | size = int(self.roi.size().x()), int(self.roi.size().y()) 525 | 526 | # Update the offset fields 527 | self.offsets_x.setText(str(pos[0])) 528 | self.offsets_y.setText(str(pos[1])) 529 | self.offsets_w.setText(str(size[0])) 530 | self.offsets_h.setText(str(size[1])) 531 | 532 | 533 | def selectStack(self): 534 | 
"""Select a microscopy stack 535 | """ 536 | fn = QFileDialog.getOpenFileName(caption="Select stack", filter="*.tif")[0] 537 | 538 | if fn: 539 | print(fn) 540 | self.progressbar.setMaximum(10) 541 | self.fn_stack.setText(fn) 542 | self.im = np.asarray(io.mimread(fn, memtest=False)) 543 | self.progressbar.setValue(9) 544 | self.imv.setImage(self.im.transpose(0, 2, 1)) 545 | self.progressbar.setValue(10) 546 | 547 | ### Enable other buttons 548 | self.selectDendriteBtn.setEnabled(True) 549 | self.selectSpinesBtn.setEnabled(True) 550 | 551 | ### Add overlay 552 | # Prediction overlay 553 | self.overlay = np.zeros(self.im.shape[1:]+(3,), dtype=np.uint8) 554 | self.overlayItem = pg.ImageItem(self.overlay, compositionMode=QPainter.CompositionMode_Plus) 555 | self.imv.getView().addItem(self.overlayItem) 556 | 557 | # Update overlay when z location changes 558 | self.imv.sigTimeChanged.connect(self.changeOverlay) 559 | 560 | self.offsets_z_begin.setText("0") 561 | self.offsets_z_end.setText(str(self.im.shape[0]-1)) 562 | 563 | self.zValidator.setRange(0, self.im.shape[0]-1) 564 | 565 | def changeOverlay(self): 566 | """Show the dendrite and spine annotations as overlay in addition to the original stack 567 | """ 568 | self.overlay = np.zeros_like(self.overlay) 569 | 570 | # current z index 571 | cur_i = self.imv.currentIndex 572 | 573 | # if dendrite segmentation is available 574 | if type(self.dendrite) != type(None): 575 | self.overlay[..., 0] = self.dendrite[cur_i] 576 | self.overlay[..., 2] = self.dendrite[cur_i] 577 | 578 | # if spines segmentation is available 579 | if type(self.spines) != type(None): 580 | self.overlay[..., 1] = self.spines[cur_i] * 255 581 | 582 | # Show the image 583 | self.overlayItem.setImage(self.overlay.transpose(1, 0, 2)) 584 | 585 | ########################################### 586 | def selectDendrite(self): 587 | """Select dendrite annotation file 588 | """ 589 | fn = QFileDialog.getOpenFileName(caption="Select dendrite tracings", 
filter="*.tif; *.swc")[0] 590 | 591 | if fn: 592 | self.fn_d.setText(fn) 593 | 594 | if fn.endswith("swc"): 595 | target_fn = fn[:-4] + "_dendrite.tif" 596 | 597 | if os.path.exists(target_fn): 598 | ok = QMessageBox.question(self, 599 | "Keep it?", 600 | "We found an existing converted dendritic trace. Should I keep it?") 601 | 602 | if ok == QMessageBox.Yes: 603 | self.dendrite = np.asarray(io.mimread(target_fn, memtest=False)) 604 | return 605 | 606 | aS = askSpacing() 607 | 608 | ### Now convert dendrite 609 | d = DendriteSWC(spacing=aS.spacing()) 610 | d.node.connect(self.updateProgress) 611 | d.open(fn, self.fn_stack.text()) 612 | d.convert(target_fn) 613 | 614 | self.dendrite = np.asarray(io.mimread(target_fn, memtest=False)) 615 | 616 | else: 617 | self.dendrite = np.asarray(io.mimread(fn, memtest=False)) 618 | 619 | def updateProgress(self, a, b): 620 | """Update progress bar 621 | 622 | Args: 623 | a (int, float): maximum of progress bar 624 | b (int, float): current value of progress bar 625 | """ 626 | self.progressbar.setMaximum(b) 627 | self.progressbar.setValue(a) 628 | 629 | def selectSpines(self): 630 | """Select a spine annotation 631 | """ 632 | fn = QFileDialog.getOpenFileName(caption="Select stack", filter="*.tif, *.mask")[0] 633 | 634 | if fn: 635 | self.fn_s.setText(fn) 636 | 637 | if fn.endswith("mask"): 638 | # Load a pipra annotated mask file 639 | mask = fl.load(fn, "/mask").transpose(0, 2, 1) 640 | 641 | else: 642 | # Load a tif file 643 | mask = np.asarray(io.mimread(fn, memtest=False)) 644 | 645 | self.spines = mask 646 | 647 | 648 | class Selector(QWidget): 649 | def __init__(self): 650 | """Select a given task in the DeepD3 training pipeline 651 | """ 652 | super().__init__() 653 | 654 | l = QGridLayout(self) 655 | 656 | add = QPushButton("Create training data") 657 | add.setFixedWidth(300) 658 | add.setFixedHeight(100) 659 | add.clicked.connect(self.createTrainingData) 660 | 661 | l.addWidget(add) 662 | 663 | view = 
QPushButton("View training data") 664 | view.setFixedWidth(300) 665 | view.setFixedHeight(100) 666 | view.clicked.connect(self.viewTrainingData) 667 | 668 | l.addWidget(view) 669 | 670 | arrange = QPushButton("Arrange training data") 671 | arrange.setFixedWidth(300) 672 | arrange.setFixedHeight(100) 673 | arrange.clicked.connect(self.arrangeTrainingData) 674 | 675 | l.addWidget(arrange) 676 | 677 | def viewTrainingData(self): 678 | """View training data 679 | """ 680 | fn = QFileDialog.getOpenFileName(filter="*.d3data")[0] 681 | 682 | if fn: 683 | # If a valid filename was given 684 | self.c = Viewer(fn) 685 | self.c.show() 686 | 687 | def createTrainingData(self): 688 | """Create d3data set 689 | """ 690 | self.a = addStackWidget() 691 | self.a.show() 692 | 693 | def arrangeTrainingData(self): 694 | """Arrange training data (d3data files) in a d3set 695 | """ 696 | self.b = Arrange() 697 | self.b.show() 698 | 699 | def main(): 700 | """Main entry point to GUI 701 | """ 702 | app = QApplication([]) 703 | 704 | # Select which part of the DeepD3 training pipeline is used. 705 | s = Selector() 706 | s.show() 707 | 708 | app.exec_() 709 | 710 | if __name__ == '__main__': 711 | main() -------------------------------------------------------------------------------- /deepd3/training/stream.py: -------------------------------------------------------------------------------- 1 | # Matrix operations 2 | import numpy as np 3 | # Loading data 4 | import flammkuchen as fl 5 | # from keras.utils import Sequence 6 | from tensorflow.keras.utils import Sequence 7 | # Image manipulation 8 | import cv2 9 | # Image augmentations 10 | import albumentations as A 11 | # Shuffling images 12 | import random 13 | 14 | class DataGeneratorStream(Sequence): 15 | def __init__(self, fn, batch_size, samples_per_epoch=50000, size=(1, 128, 128), target_resolution=None, augment=True, 16 | shuffle=True, seed=42, normalize=[-1, 1], min_content=0. 
17 | ): 18 | """Data Generator that streams data dynamically for training DeepD3. 19 | 20 | Args: 21 | fn (str): The path to the training data file 22 | batch_size (int): Batch size for training deep neural networks 23 | samples_per_epoch (int, optional): Samples used in each epoch. Defaults to 50000. 24 | size (tuple, optional): Shape of a single sample. Defaults to (1, 128, 128). 25 | target_resolution (float, optional): Target resolution in microns. Defaults to None. 26 | augment (bool, optional): Enables augmenting the data. Defaults to True. 27 | shuffle (bool, optional): Enabled shuffling the data. Defaults to True. 28 | seed (int, optional): Creates pseudorandom numbers for shuffling. Defaults to 42. 29 | normalize (list, optional): Values range when normalizing data. Defaults to [-1, 1]. 30 | min_content (float, optional): Minimum content in image (annotated dendrite or spine), not considered if 0. Defaults to 0. 31 | """ 32 | 33 | # Save settings 34 | self.batch_size = batch_size 35 | self.augment = augment 36 | self.fn = fn 37 | self.shuffle = shuffle 38 | self.aug = self._get_augmenter() 39 | self.seed = seed 40 | self.normalize = normalize 41 | self.samples_per_epoch = samples_per_epoch 42 | self.size = size 43 | self.target_resolution = target_resolution 44 | self.min_content = min_content 45 | 46 | self.d = fl.load(self.fn) 47 | self.data = self.d['data'] 48 | self.meta = self.d['meta'] 49 | 50 | # Seed randomness 51 | random.seed(self.seed) 52 | np.random.seed(self.seed) 53 | 54 | self.on_epoch_end() 55 | 56 | def __len__(self): 57 | """Denotes the number of batches per epoch""" 58 | return self.samples_per_epoch // self.batch_size 59 | 60 | def __getitem__(self, index): 61 | """Generate one batch of data 62 | 63 | Parameters 64 | ---------- 65 | index : int 66 | batch index in image/label id list 67 | 68 | Returns 69 | ------- 70 | tuple 71 | Contains two numpy arrays, 72 | each of shape (batch_size, height, width, 1). 
73 | """ 74 | X = [] 75 | Y0 = [] 76 | Y1 = [] 77 | eps = 1e-5 78 | 79 | if self.shuffle is False: 80 | np.random.seed(index) 81 | 82 | # Create all pairs in a given batch 83 | for i in range(self.batch_size): 84 | # Retrieve a single sample pair 85 | image, dendrite, spines = self.getSample() 86 | 87 | # Augmenting the data 88 | if self.augment: 89 | augmented = self.aug(image=image, 90 | mask1=dendrite.astype(np.uint8), 91 | mask2=spines.astype(np.uint8)) #augment image 92 | 93 | image = augmented['image'] 94 | dendrite = augmented['mask1'] 95 | spines = augmented['mask2'] 96 | 97 | # Min/max scaling 98 | image = (image.astype(np.float32) - image.min()) / (image.max() - image.min() + eps) 99 | # Shifting and scaling 100 | image = image * (self.normalize[1]-self.normalize[0]) + self.normalize[0] 101 | 102 | X.append(image) 103 | Y0.append(dendrite.astype(np.float32) / (dendrite.max() + eps)) 104 | Y1.append(spines.astype(np.float32) / (spines.max() + eps)) # to ensure binary targets 105 | 106 | return np.asarray(X, dtype=np.float32)[..., None], (np.asarray(Y0, dtype=np.float32)[..., None], 107 | np.asarray(Y1, dtype=np.float32)[..., None]) 108 | 109 | 110 | def _get_augmenter(self): 111 | """Defines used augmentations""" 112 | aug = A.Compose([ 113 | A.RandomBrightnessContrast(p=0.25), 114 | A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT, p=0.5), 115 | A.RandomRotate90(p=0.5), 116 | A.HorizontalFlip(p=0.5), 117 | A.VerticalFlip(p=0.5), 118 | A.Blur(p=0.2), 119 | A.GaussNoise(p=0.5)], p=1, 120 | additional_targets={ 121 | 'mask1': 'mask', 122 | 'mask2': 'mask' 123 | }) 124 | return aug 125 | 126 | def getSample(self, squeeze=True): 127 | """Get a sample from the provided data 128 | 129 | Args: 130 | squeeze (bool, optional): if plane is 2D, skip 3D. Defaults to True. 
131 | 132 | Returns: 133 | list(np.ndarray, np.ndarray, np.ndarray): stack image with respective labels 134 | """ 135 | while True: 136 | r = self._getSample(squeeze) 137 | 138 | # If sample was successfully generated 139 | # and we don't care about the content 140 | if r is not None and self.min_content == 0: 141 | return r 142 | 143 | # If sample was successfully generated 144 | # and we do care about the content 145 | elif r is not None: 146 | # In either or both annotation should be at least `min_content` pixels 147 | # that are being labelled. 148 | if (r[1]).sum() > self.min_content or (r[2]).sum() > self.min_content: 149 | return r 150 | else: 151 | continue 152 | 153 | else: 154 | continue 155 | 156 | def _getSample(self, squeeze=True): 157 | """Retrieves a sample 158 | 159 | Args: 160 | squeeze (bool, optional): Squeezes return shape. Defaults to True. 161 | 162 | Returns: 163 | tuple: Tuple of stack (X), dendrite (Y0) and spines (Y1) 164 | """ 165 | # Adjust for 2 images 166 | if len(self.size) == 2: 167 | size = (1,) + self.size 168 | 169 | else: 170 | size = self.size 171 | 172 | # sample random stack 173 | r_stack = np.random.choice(len(self.meta)) 174 | 175 | target_h = size[1] 176 | target_w = size[2] 177 | 178 | 179 | if self.target_resolution is None: 180 | # Keep everything as is 181 | scaling = 1 182 | h = target_h 183 | w = target_w 184 | 185 | else: 186 | # Computing scaling factor 187 | scaling = self.target_resolution / self.meta.iloc[r_stack].Resolution_XY 188 | 189 | # Compute the height and width and random offsets 190 | h = round(scaling * target_h) 191 | w = round(scaling * target_w) 192 | 193 | # Correct for stack dimensions 194 | if self.meta.iloc[r_stack].Width-w == 0: 195 | x = 0 196 | 197 | elif self.meta.iloc[r_stack].Width-w < 0: 198 | return 199 | 200 | else: 201 | x = np.random.choice(self.meta.iloc[r_stack].Width-w) 202 | 203 | # Correct for stack dimensions 204 | if self.meta.iloc[r_stack].Height-h == 0: 205 | y = 0 206 | 207 
| elif self.meta.iloc[r_stack].Height-h < 0: 208 | return 209 | 210 | else: 211 | y = np.random.choice(self.meta.iloc[r_stack].Height-h) 212 | 213 | ## Select random plane + range 214 | r_plane = np.random.choice(self.meta.iloc[r_stack].Depth-size[0]+1) 215 | 216 | z_begin = r_plane 217 | z_end = r_plane+size[0] 218 | 219 | 220 | # Scale if neccessary to the correct dimensions 221 | tmp_stack = self.data['stacks'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w] 222 | tmp_dendrites = self.data['dendrites'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w] 223 | tmp_spines = self.data['spines'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w] 224 | 225 | # Data needs to be rescaled 226 | if scaling != 1: 227 | return_stack = [] 228 | return_dendrites = [] 229 | return_spines = [] 230 | 231 | # Do this for each plane 232 | # and ensure that OpenCV is happy 233 | for i in range(tmp_stack.shape[0]): 234 | return_stack.append(cv2.resize(tmp_stack[i], (target_h, target_w))) 235 | return_dendrites.append(cv2.resize(tmp_dendrites[i].astype(np.uint8), (target_h, target_w)).astype(bool)) 236 | return_spines.append(cv2.resize(tmp_spines[i].astype(np.uint8), (target_h, target_w)).astype(bool)) 237 | 238 | return_stack = np.asarray(return_stack) 239 | return_dendrites = np.asarray(return_dendrites) 240 | return_spines = np.asarray(return_spines) 241 | 242 | else: 243 | return_stack = tmp_stack 244 | return_dendrites = tmp_dendrites 245 | return_spines = tmp_spines 246 | 247 | if squeeze: 248 | # Return sample 249 | return return_stack.squeeze(), return_dendrites.squeeze(), return_spines.squeeze() 250 | 251 | else: 252 | return return_stack, return_dendrites, return_spines -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the 
environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/api/core.rst: -------------------------------------------------------------------------------- 1 | DeepD3 core 2 | =========== 3 | 4 | analysis 5 | -------- 6 | 7 | .. 
automodule:: deepd3.core.analysis 8 | :members: 9 | :undoc-members: 10 | :private-members: 11 | 12 | dendrite 13 | -------- 14 | 15 | .. automodule:: deepd3.core.dendrite 16 | :members: 17 | :private-members: 18 | 19 | spines 20 | -------- 21 | 22 | .. automodule:: deepd3.core.spines 23 | :members: 24 | :private-members: 25 | :undoc-members: 26 | 27 | export 28 | -------- 29 | 30 | .. automodule:: deepd3.core.export 31 | :members: 32 | :private-members: 33 | 34 | distance 35 | -------- 36 | 37 | .. automodule:: deepd3.core.distance 38 | :members: 39 | :private-members: -------------------------------------------------------------------------------- /docs/source/api/inference.rst: -------------------------------------------------------------------------------- 1 | DeepD3 inference 2 | ================ 3 | 4 | gui 5 | -------- 6 | 7 | .. automodule:: deepd3.inference.gui 8 | :members: 9 | :private-members: 10 | -------------------------------------------------------------------------------- /docs/source/api/model.rst: -------------------------------------------------------------------------------- 1 | DeepD3 model 2 | ============ 3 | 4 | builder 5 | -------- 6 | 7 | .. automodule:: deepd3.model.builder 8 | :members: 9 | :private-members: -------------------------------------------------------------------------------- /docs/source/api/training.rst: -------------------------------------------------------------------------------- 1 | DeepD3 training 2 | =============== 3 | 4 | generator 5 | --------- 6 | 7 | .. automodule:: deepd3.training.generator 8 | :members: 9 | :private-members: 10 | 11 | stream 12 | --------- 13 | 14 | .. automodule:: deepd3.training.stream 15 | :members: 16 | :private-members: 17 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | sys.path.insert(0, os.path.abspath("./")) 12 | sys.path.insert(0, os.path.abspath("../")) 13 | 14 | project = 'DeepD3' 15 | copyright = '2023, Andreas M Kist, Martin H P Fernholz' 16 | author = 'Andreas M Kist, Martin H P Fernholz' 17 | release = '2023' 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx.ext.autosummary'] 23 | 24 | templates_path = ['_templates'] 25 | exclude_patterns = [] 26 | 27 | 28 | 29 | # -- Options for HTML output ------------------------------------------------- 30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 31 | 32 | html_theme = 'sphinx_rtd_theme' 33 | html_static_path = ['_static'] 34 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DeepD3 documentation master file, created by 2 | sphinx-quickstart on Fri Jan 27 14:13:09 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to DeepD3's documentation! 7 | ================================== 8 | 9 | .. 
image:: https://deepd3.forschung.fau.de/header.gif 10 | :alt: some DeepD3 inference 11 | :target: https://deepd3.forschung.fau.de/ 12 | 13 | With DeepD3, you are able to predict the presence and absence of 14 | dendritic spines and dendrites in 3D microscopy stacks. We evaluated 15 | DeepD3 on a variety of species, resolutions, dyes and microscopy types. 16 | In this documentation, you find information how to train your own DeepD3 model, 17 | how to use DeepD3 in inference mode and how to use the API to create custom scripts. 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: User Guide 22 | :glob: 23 | 24 | userguide/* 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :caption: API 29 | :glob: 30 | 31 | api/* 32 | 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | -------------------------------------------------------------------------------- /docs/source/userguide/inference.rst: -------------------------------------------------------------------------------- 1 | DeepD3 inference 2 | ================ 3 | 4 | Open the inference mode using ``deepd3-inference``. 5 | Load your stack of choice (we currently support TIF stacks) 6 | and specify the XY and Z dimensions. Next, you can segment dendrites and 7 | dendritic spines using a DeepD3 model from `the Model Zoo `_ by clicking 8 | on ``Analyze -> Segment dendrite and spines``. Afterwards, you may clean 9 | the predictions by clicking on ``Analyze -> Cleaning``. Finally, you may 10 | build 2D or 3D ROIs using the respective functions in ``Analyze``. 11 | To test the 3D ROI building, double click in the stack to a region of interest. 12 | A window opens that allows you to play with the hyperparameters and segments 13 | 3D ROIs in real-time. 14 | 15 | All results can be exported to various file formats. 
16 | For convenience, DeepD3 saves related data in its "proprietary" 17 | hdf5 file (that you can open using any hdf5 viewer/program/library). 18 | In particular, you may export the predictions as TIF files, the ROIs to 19 | ImageJ file format or a folder, the ROI map to a TIF file, or the 20 | ROI centroids to a file. 21 | 22 | Most functions can be assessed using a batch command script 23 | located in ``deepd3/inference/batch.py``. -------------------------------------------------------------------------------- /docs/source/userguide/train.rst: -------------------------------------------------------------------------------- 1 | Training your own DeepD3 model 2 | ============================== 3 | 4 | Use ``deepd3-training`` to start the GUI for generating training sets. 5 | 6 | For each of your training set, please provide 7 | 8 | * The original stack as e.g. TIF files 9 | * The spine annotations (binary labels) as TIF or MASK files (the latter from [pipra](https://github.com/anki-xyz/pipra)) 10 | * The dendrite annotations as SWC file (only tested for SWC-files generated by [NeuTube](https://neutracing.com/download/)) 11 | 12 | Create training data 13 | ------------------------ 14 | 15 | Click on the button "Create training data". For each of your stacks, import the stack, the spine annotation and the dendrite annotation file. 16 | If you dendrite annotation is a SWC file, it will create a 3D reconstruction of the SWC file, which will be stored for later use. If you reload the SWC, it will ask you if you want to keep the 3D reconstruction. 17 | 18 | After importing all files, enter the metadata (resolution in x, y and z) and determine the region of interest using the bounding box and the sliders. 19 | Shortcuts are ``B`` for current plane is **z begin** and ``E`` for **z end**. You may enable or disable the cropping to the bounding box. If you are happy, save this region as ``d3data``-file. 
20 | 21 | View training data 22 | ------------------------ 23 | 24 | Click on the button "View training data" to re-visit 25 | any ``d3data`` files. You also are able to see and potentially 26 | manipulate the metadata associated the ``d3data`` file. 27 | 28 | Arrange training data 29 | ------------------------ 30 | 31 | For training, you need to create a ``d3set``. 32 | This is an assembly of ``d3data`` files. Click on the button 33 | "Arrange training data". Then, simply load all relevant data using 34 | the "Add data to set" button and select appropriate ``d3data`` files. 35 | Clicking on "Create dataset" allows you to save your assembly 36 | as ``d3set`` file. 37 | 38 | Actual training 39 | ------------------------ 40 | 41 | We have prepared a Jupyter notebook in the folder ``examples``. 42 | Follow the instructions to train your own deep neural network for DeepD3 use. 43 | For professionals, you also may utilize directly the files in ``model`` 44 | and ``training`` to allow highly individualized training. 45 | You only should ensure that your model allows arbitrary input and 46 | outputs two separate channels (dendrites and spines). 
-------------------------------------------------------------------------------- /examples/Training DeepD3 model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7fccdf22", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Neural network libraries\n", 11 | "import os\n", 12 | "os.environ[\"SM_FRAMEWORK\"] = \"tf.keras\"\n", 13 | "import tensorflow as tf\n", 14 | "from tensorflow.keras.optimizers import Adam\n", 15 | "import segmentation_models as sm\n", 16 | "sm.set_framework(\"tf.keras\")\n", 17 | "\n", 18 | "# Plotting\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "%matplotlib inline\n", 21 | "\n", 22 | "# DeepD3 \n", 23 | "from deepd3.model import DeepD3_Model\n", 24 | "from deepd3.training.stream import DataGeneratorStream" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "232bb972", 30 | "metadata": {}, 31 | "source": [ 32 | "## Load training data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "4b70a2b2", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "TRAINING_DATA_PATH = r\"DeepD3_Training.d3set\"\n", 43 | "VALIDATION_DATA_PATH = r\"DeepD3_Validation.d3set\"\n", 44 | "\n", 45 | "dg_training = DataGeneratorStream(TRAINING_DATA_PATH, \n", 46 | " batch_size=32, # Data processed at once, depends on your GPU\n", 47 | " target_resolution=0.094, # fixed to 94 nm, can be None for mixed resolution training\n", 48 | " min_content=50) # images need to have at least 50 segmented px\n", 49 | "\n", 50 | "dg_validation = DataGeneratorStream(VALIDATION_DATA_PATH, \n", 51 | " batch_size=32, \n", 52 | " target_resolution=0.094,\n", 53 | " min_content=50, \n", 54 | " augment=False,\n", 55 | " shuffle=False)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "6ac109e3", 61 | "metadata": {}, 62 | "source": [ 63 | "## Visualize data" 64 | ] 65 | }, 
66 | { 67 | "cell_type": "markdown", 68 | "id": "71927db0", 69 | "metadata": {}, 70 | "source": [ 71 | "Glancing on the data to verify that settings are as expected." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "a63b4a9d", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "X, Y = dg_training[0]\n", 82 | "i = 0\n", 83 | "\n", 84 | "plt.figure(figsize=(12,4))\n", 85 | "\n", 86 | "plt.subplot(131)\n", 87 | "plt.imshow(X[i].squeeze(), cmap='gray')\n", 88 | "plt.colorbar()\n", 89 | "\n", 90 | "plt.subplot(132)\n", 91 | "plt.imshow(Y[0][i].squeeze(), cmap='gray')\n", 92 | "plt.colorbar()\n", 93 | "\n", 94 | "plt.subplot(133)\n", 95 | "plt.imshow(Y[1][i].squeeze(), cmap='gray')\n", 96 | "plt.colorbar()\n", 97 | "\n", 98 | "plt.tight_layout()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "d5f94b2c", 104 | "metadata": {}, 105 | "source": [ 106 | "## Creating model and set training parameters" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "c9c0904a", 113 | "metadata": { 114 | "scrolled": false 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# Create a naive DeepD3 model with a given base filter count (e.g. 
32)\n", 119 | "m = DeepD3_Model(filters=32)\n", 120 | "\n", 121 | "# Set appropriate training settings\n", 122 | "m.compile(Adam(learning_rate=0.0005), # optimizer, good default setting, can be tuned \n", 123 | " [sm.losses.dice_loss, \"mse\"], # Dice loss for dendrite, MSE for spines\n", 124 | " metrics=['acc', sm.metrics.iou_score]) # Metrics for monitoring progress\n", 125 | "\n", 126 | "m.summary()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "816b5652", 132 | "metadata": {}, 133 | "source": [ 134 | "## Fitting model" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "0ffe2faf", 140 | "metadata": {}, 141 | "source": [ 142 | "Loading some training callbacks, such as adjusting the learning rate across time, saving training progress and intermediate models" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "ba233f13", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "19ec7aa4", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "def schedule(epoch, lr):\n", 163 | " if epoch < 15:\n", 164 | " return lr\n", 165 | " \n", 166 | " else:\n", 167 | " return lr * tf.math.exp(-0.1)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "8b454030", 173 | "metadata": {}, 174 | "source": [ 175 | "# Train your own DeepD3 model" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "id": "69b4777e", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "EPOCHS = 30\n", 186 | "\n", 187 | "# Save best model automatically during training\n", 188 | "mc = ModelCheckpoint(\"DeepD3_model.h5\",\n", 189 | " save_best_only=True)\n", 190 | " \n", 191 | "# Save metrics \n", 192 | "csv = CSVLogger(\"DeepD3_model.csv\")\n", 193 | 
"\n", 194 | "# Adjust learning rate during training to allow for better convergence\n", 195 | "lrs = LearningRateScheduler(schedule)\n", 196 | "\n", 197 | "# Actually train the network\n", 198 | "h = m.fit(dg_training, \n", 199 | " batch_size=32, \n", 200 | " epochs=EPOCHS, \n", 201 | " validation_data=dg_validation, \n", 202 | " callbacks=[mc, csv, lrs])" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "4994fd8c", 208 | "metadata": {}, 209 | "source": [ 210 | "## Save model for use in GUI or batch processing" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "id": "27df4412", 216 | "metadata": {}, 217 | "source": [ 218 | "This is for saving the neural network manually. The best model is automatically saved during training." 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "id": "9532e339", 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "m.save(\"deepd3_custom_trained_model.h5\")" 229 | ] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "Python 3.8.10 64-bit", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.8.10" 249 | }, 250 | "vscode": { 251 | "interpreter": { 252 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 253 | } 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 5 258 | } 259 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepd3 2 | tensorflow -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # read the contents of your README file 4 | from os import path 5 | this_directory = path.abspath(path.dirname(__file__)) 6 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 7 | long_description = f.read() 8 | 9 | setup( 10 | name="deepd3", 11 | long_description=long_description, 12 | long_description_content_type='text/markdown', 13 | version="0.1", 14 | author="Andreas M Kist", 15 | author_email="andreas.kist@fau.de", 16 | license="GPLv3", 17 | packages=find_packages(), 18 | install_requires=[ 19 | "pyqtgraph>=0.10.0", 20 | "numpy", 21 | "numba", 22 | "flammkuchen", 23 | "pyqt5", 24 | "scikit-image", 25 | "imageio", 26 | "imageio-ffmpeg", 27 | "opencv-python", 28 | "pandas", 29 | "tqdm", 30 | "roifile", 31 | "segmentation_models", 32 | "connected-components-3d", 33 | "albumentations" 34 | ], 35 | classifiers=[ 36 | "Development Status :: 4 - Beta", 37 | "Intended Audience :: Science/Research", 38 | "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", 39 | "Programming Language :: Python :: 3.6", 40 | "Programming Language :: Python :: 3.7", 41 | "Programming Language :: Python :: 3.8", 42 | ], 43 | keywords="spine segmentation", 44 | description="A tool for segmenting dendrites and dendritic spines.", 45 | project_urls={ 46 | "Source": "https://github.com/ankilab/deepd3", 47 | "Tracker": "https://github.com/ankilab/deepd3/issues", 48 | }, 49 | entry_points = { 50 | 'console_scripts': [ 51 | 'deepd3-inference = deepd3.inference.gui:main', 52 | 'deepd3-training = deepd3.training.generator:main' 53 | ] 54 | }, 55 | include_package_data=True, 56 | ) --------------------------------------------------------------------------------