├── .DS_Store
├── .appveyor.yml
├── .coveragerc
├── .gitignore
├── .travis.yml
├── CITATION.cff
├── LICENSE
├── README.md
├── contributing.md
├── requirements.txt
├── sample_data
│   └── readme.md
├── segmentator
│   ├── __init__.py
│   ├── __main__.py
│   ├── config.py
│   ├── config_filters.py
│   ├── config_gui.py
│   ├── cython
│   │   ├── deriche_3D.c
│   │   └── deriche_3D.pyx
│   ├── deriche_prepare.py
│   ├── filter.py
│   ├── filters_ui.py
│   ├── filters_utils.py
│   ├── future
│   │   ├── readme.md
│   │   ├── wip_arcweld_mp2rage.py
│   │   └── wip_arcweld_mprage.py
│   ├── gui_utils.py
│   ├── hist2d_counts.py
│   ├── ncut_prepare.py
│   ├── segmentator_main.py
│   ├── segmentator_ncut.py
│   ├── tests
│   │   ├── test_utils.py
│   │   ├── wip_test_arcweld.py
│   │   └── wip_test_gradient_magnitude.py
│   └── utils.py
├── setup.py
└── visuals
    ├── animation_01.gif
    ├── logo.png
    └── logo.svg
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/.DS_Store
--------------------------------------------------------------------------------
/.appveyor.yml:
--------------------------------------------------------------------------------
1 | build: false
2 |
3 | environment:
4 | matrix:
5 | - PYTHON_VERSION: 3.7
6 | MINICONDA: C:\Miniconda
7 |
8 | init:
9 | - "ECHO %PYTHON_VERSION% %MINICONDA%"
10 |
11 | install:
12 | - "set PATH=%MINICONDA%;%MINICONDA%\\Scripts;%PATH%"
13 | - conda config --set always_yes yes --set changeps1 no
14 | - conda update -q conda
15 | - conda config --add channels conda-forge
16 | - conda info -a
17 | - "conda env create -q -n segmentator python=%PYTHON_VERSION% --file requirements.txt"
18 | - activate segmentator
19 | - pip install compoda
20 | - python setup.py install
21 |
22 | test_script:
23 | - py.test
24 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = segmentator
3 |
4 | omit =
5 | *__init__*
6 | *__main__*
7 | *config*
8 | *tests/*
9 | segmentator/future/*
10 | segmentator/cython/*
11 | segmentator/sample_data/*
12 | *ui.py
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 |
55 | # Sphinx documentation
56 | docs/_build/
57 |
58 | # PyBuilder
59 | target/
60 |
61 | #Ipython Notebook
62 | .ipynb_checkpoints
63 |
64 | bin/
65 |
66 | # MacOSX Stuff
67 | # General
68 | .DS_Store
69 | .AppleDouble
70 | .LSOverride
71 |
72 | # Icon must end with two \r
73 | Icon
74 |
75 |
76 | # Thumbnails
77 | ._*
78 |
79 | # Files that might appear in the root of a volume
80 | .DocumentRevisions-V100
81 | .fseventsd
82 | .Spotlight-V100
83 | .TemporaryItems
84 | .Trashes
85 | .VolumeIcon.icns
86 | .com.apple.timemachine.donotpresent
87 |
88 | # Directories potentially created on remote AFP share
89 | .AppleDB
90 | .AppleDesktop
91 | Network Trash Folder
92 | Temporary Items
93 | .apdisk
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: false
2 | os:
3 | - linux
4 | - osx
5 | matrix:
6 | allow_failures:
7 | - os: osx
8 | language: python
9 | python:
10 | - 2.7
11 | install: # command to install dependencies
12 | - pip install -r requirements.txt
13 | - python setup.py develop
14 | script: # command to run tests
15 | - py.test --cov=./segmentator
16 | after_success:
17 | - bash <(curl -s https://codecov.io/bash)
18 | notifications:
19 | email: false
20 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.1.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: Gulban
5 | given-names: Omer Faruk
6 | orcid: https://orcid.org/0000-0001-7761-3727
7 |
8 | - family-names: Schneider
9 | given-names: Marian
10 | orcid: http://orcid.org/0000-0003-3192-5316
11 |
12 | - family-names: Marquardt
13 | given-names: Ingo
14 | orcid: http://orcid.org/0000-0001-5178-9951
15 |
16 | - family-names: Haast
17 | given-names: Roy
18 | orcid: http://orcid.org/0000-0001-8543-2467
19 |
20 | - family-names: De Martino
21 | given-names: Federico
22 | orcid: https://orcid.org/0000-0002-0352-0648
23 |
24 | title: "A scalable method to improve gray matter segmentation at ultra high field MRI"
25 | version: 1.6.0
26 | doi: https://doi.org/10.1371/journal.pone.0198335
27 | date-released: 2019-09-26
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2019 Omer Faruk Gulban and Marian Schneider
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![DOI](https://zenodo.org/badge/59303623.svg)](https://zenodo.org/badge/latestdoi/59303623)
2 |
3 | # Segmentator
4 |
5 |
6 |
7 | Segmentator is a free and open-source package for multi-dimensional data exploration and segmentation of 3D images. This application is mainly developed and tested using ultra-high field magnetic resonance imaging (MRI) brain data.
8 |
9 |
10 | The goal is to provide a tool that complements the brain tissue segmentation methods already available (to the best of our knowledge) in other software packages ([FSL](https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/), [CBS-Tools](https://www.cbs.mpg.de/institute/software/cbs-tools), [ITK-SNAP](http://www.itksnap.org/pmwiki/pmwiki.php), [Freesurfer](https://surfer.nmr.mgh.harvard.edu/), [SPM](http://www.fil.ion.ucl.ac.uk/spm/software/spm12/), [Brainvoyager](http://www.brainvoyager.com/), etc.).
11 |
12 | ### Citation:
13 | - Our paper can be accessed from __[this link.](https://doi.org/10.1371/journal.pone.0198335)__
14 | - Released versions of this package can be cited by using our __[Zenodo DOI](https://zenodo.org/badge/latestdoi/59303623).__
15 |
16 |
17 |
18 | ## Core dependencies
19 | **[Python 3.6](https://www.python.org/downloads/release/python-363/)** or **[Python 2.7](https://www.python.org/download/releases/2.7/)** (compatible with both).
20 |
21 | | Package | Tested version |
22 | |------------------------------------------------|----------------|
23 | | [matplotlib](http://matplotlib.org/) | 3.1.1 |
24 | | [NumPy](http://www.numpy.org/) | 1.22.0 |
25 | | [NiBabel](http://nipy.org/nibabel/) | 2.5.1 |
26 | | [SciPy](http://scipy.org/) | 1.3.1 |
27 | | [Compoda](https://github.com/ofgulban/compoda) | 0.3.5 |
28 |
29 | ## Installation & Quick Start
30 | - Download [the latest release](https://github.com/ofgulban/segmentator/releases) and unzip it.
31 | - Change directory in your command line:
32 | ```
33 | cd /path/to/segmentator
34 | ```
35 | - Install the requirements by running the following command:
36 | ```
37 | pip install -r requirements.txt
38 | ```
39 | - Install Segmentator:
40 | ```
41 | python setup.py install
42 | ```
43 | - Simply call segmentator with a nifti file:
44 | ```
45 | segmentator /path/to/file.nii.gz
46 | ```
47 | - Or see the help for available options:
48 | ```
49 | segmentator --help
50 | ```
51 |
52 | Check out __[our wiki](https://github.com/ofgulban/segmentator/wiki)__ for further details such as [GUI controls](https://github.com/ofgulban/segmentator/wiki/Controls), [alternative installation methods](https://github.com/ofgulban/segmentator/wiki/Installation) and more...
53 |
54 | ## Support
55 | Please use [GitHub issues](https://github.com/ofgulban/segmentator/issues) for questions, bug reports or feature requests.
56 |
57 | ## License
58 | Copyright © 2019, [Omer Faruk Gulban](https://github.com/ofgulban) and [Marian Schneider](https://github.com/MSchnei).
59 | This project is licensed under [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause).
60 |
61 | ## References
62 | This application is mainly based on the following work:
63 |
64 | * Kniss, J., Kindlmann, G., & Hansen, C. D. (2005). Multidimensional transfer functions for volume rendering. Visualization Handbook, 189–209.
65 |
66 | ## Acknowledgements
67 | Since early 2020, development and maintenance of this project have been actively supported by [Brain Innovation](https://www.brainvoyager.com/), as the main developer ([Omer Faruk Gulban](https://github.com/ofgulban)) works there.
68 |
--------------------------------------------------------------------------------
/contributing.md:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 | If you want to contribute to Segmentator and make it better, your help is very welcome.
4 |
5 | ## Opening an issue
6 |
7 | - Please post your questions, bug reports and feature requests [here](https://github.com/ofgulban/segmentator/issues).
8 | - When reporting a bug, please specify your system and a complete copy of the error message that you are getting.
9 |
10 | ## Making a pull request
11 |
12 | - Create a personal fork of the project on Github.
13 | - Clone the fork on your local machine.
14 | - Create a new branch to work on. Branch from `devel`.
15 | - Implement/fix your feature, comment your code.
16 | - Follow the code style of the project.
17 | - Add or change the documentation as needed.
18 | - Push your branch to your fork on Github.
19 | - From your fork open a pull request in the correct branch. Target `devel` branch of the original Segmentator repository.
20 |
21 | __Note:__ Please write your commit messages in the present tense. Your commit message should describe what the commit, when applied, does to the code – not what you did to the code.
22 |
23 | ___
24 | This guideline is adapted from [MarcDiethelm/contributing](https://github.com/MarcDiethelm/contributing/blob/master/README.md).
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.0
2 | scipy==1.3.1
3 | matplotlib==3.1.1
4 | nibabel==2.5.1
5 | pytest-cov==2.7.1
6 | compoda==0.3.5
7 |
--------------------------------------------------------------------------------
/sample_data/readme.md:
--------------------------------------------------------------------------------
1 | The dataset used for developing Segmentator can be found at:
2 |
3 | [https://doi.org/10.5281/zenodo.1117858](https://doi.org/10.5281/zenodo.1117858)
4 |
5 |
--------------------------------------------------------------------------------
/segmentator/__init__.py:
--------------------------------------------------------------------------------
1 | """Make the version number available."""
2 |
3 | import pkg_resources # part of setuptools
4 |
5 | __version__ = pkg_resources.require("segmentator")[0].version
6 |
--------------------------------------------------------------------------------
/segmentator/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Entry point.
3 |
4 | Mostly following this example:
5 | https://chriswarrick.com/blog/2014/09/15/python-apps-the-right-way-entry_points-and-scripts/
6 |
7 | Use config.py to hold arguments to be accessed by imported scripts.
8 | """
9 |
10 | from __future__ import print_function
11 | import argparse
12 | import segmentator.config as cfg
13 | from segmentator import __version__
14 |
15 |
16 | def main():
17 | """Command line call argument parsing."""
18 | # Instantiate argument parser object:
19 | parser = argparse.ArgumentParser()
20 |
21 | # Add arguments to namespace:
22 | parser.add_argument(
23 | 'filename', metavar='path',
24 | help="Path to input. Mostly a nifti file with image data."
25 | )
26 | parser.add_argument(
27 | "--gramag", metavar=str(cfg.gramag), required=False,
28 | default=cfg.gramag,
29 | help="'scharr', 'deriche', 'sobel', 'prewitt', 'numpy' \
30 | or path to a gradient magnitude nifti."
31 | )
32 | # used in Deriche filter gradient magnitude computation
33 | parser.add_argument(
34 | "--deriche_alpha", required=False, type=float,
35 | default=cfg.deriche_alpha, metavar=cfg.deriche_alpha,
36 | help="Used only in Deriche gradient magnitude option. Smaller alpha \
37 | values suppress more noise but can dislocate edges. Useful when there \
38 | is strong noise in the input image or the features of interest are at \
39 | a different scale compared to original image resolution."
40 | )
41 | parser.add_argument(
42 | "--scale", metavar=str(cfg.scale), required=False, type=float,
43 | default=cfg.scale,
44 | help="Determines nr of bins. Data is scaled between 0 to this number."
45 | )
46 | parser.add_argument(
47 | "--percmin", metavar=str(cfg.perc_min), required=False, type=float,
48 | default=cfg.perc_min,
49 | help="Minimum percentile used in truncation."
50 | )
51 | parser.add_argument(
52 | "--percmax", metavar=str(cfg.perc_max), required=False, type=float,
53 | default=cfg.perc_max,
54 | help="Maximum percentile used in truncation."
55 | )
56 | parser.add_argument(
57 | "--valmin", metavar=str(cfg.valmin), required=False, type=float,
58 | default=cfg.valmin,
59 | help="Minimum value, overwrites percentile."
60 | )
61 | parser.add_argument(
62 | "--valmax", metavar=str(cfg.valmax), required=False, type=float,
63 | default=cfg.valmax,
64 | help="Maximum value, overwrites percentile."
65 | )
66 | parser.add_argument(
67 | "--cbar_max", metavar=str(cfg.cbar_max), required=False, type=float,
68 | default=cfg.cbar_max,
69 | help="Maximum value (power of 10) of the colorbar slider."
70 | )
71 | parser.add_argument(
72 | "--cbar_init", metavar=str(cfg.cbar_init), required=False,
73 | type=float, default=cfg.cbar_init,
74 | help="Initial value (power of 10) of the colorbar slider. \
75 | Also used with --ncut_prepare flag."
76 | )
77 | parser.add_argument(
78 | "--ncut", metavar='path', required=False,
79 | help="Path to nyp file with ncut labels. Initiates N-cut GUI mode."
80 | )
81 | parser.add_argument(
82 | "--nogui", action='store_true',
83 | help="Only save 2D histogram image without showing GUI."
84 | )
85 | parser.add_argument(
86 | "--include_zeros", action='store_true',
87 | help="Include image zeros in histograms. Not used by default."
88 | )
89 | parser.add_argument(
90 | "--export_gramag", action='store_true',
91 | help="Export the gradient magnitude image. Not used by default."
92 | )
93 | parser.add_argument(
94 | "--force_original_precision", action='store_true',
95 | help="Do not change the data type of the input image. Can be useful \
96 | for very large images. Off by default."
97 | )
98 | parser.add_argument(
99 | "--matplotlib_backend", metavar=str(cfg.matplotlib_backend),
100 | default=cfg.matplotlib_backend, required=False,
101 | help="Change in case of issues during startup or visual glitches. \
102 | Some options are: qt5agg, qt4agg, wxagg, webagg."
103 | )
104 |
105 | # used in ncut preparation
106 | parser.add_argument(
107 | "--ncut_prepare", action='store_true',
108 | help=("------------------(utility feature)------------------ \
109 | Use this flag with the following arguments:")
110 | )
111 | parser.add_argument(
112 | "--ncut_figs", action='store_true',
113 | help="Figures are presented (useful for debugging)."
114 | )
115 | parser.add_argument(
116 | "--ncut_maxRec", required=False, type=int,
117 | default=cfg.max_rec, metavar=cfg.max_rec,
118 | help="Maximum number of recursions."
119 | )
120 | parser.add_argument(
121 | "--ncut_nrSupPix", required=False, type=int,
122 | default=cfg.nr_sup_pix, metavar=cfg.nr_sup_pix,
123 | help="Number of regions/superpixels."
124 | )
125 | parser.add_argument(
126 | "--ncut_compactness", required=False, type=float,
127 | default=cfg.compactness, metavar=cfg.compactness,
128 | help="Compactness balances intensity proximity and space \
129 | proximity of the superpixels. \
130 | Higher values give more weight to space proximity, making \
131 | superpixel shapes more square/cubic. This parameter \
132 | depends strongly on image contrast and on the shapes of \
133 | objects in the image."
134 | )
135 |
136 | # set cfg file variables to be accessed from other scripts
137 | args = parser.parse_args()
138 | # used in all
139 | cfg.filename = args.filename
140 | # used in segmentator GUI (main and ncut)
141 | cfg.gramag = args.gramag
142 | cfg.scale = args.scale
143 | cfg.perc_min = args.percmin
144 | cfg.perc_max = args.percmax
145 | cfg.valmin = args.valmin
146 | cfg.valmax = args.valmax
147 | cfg.cbar_max = args.cbar_max
148 | cfg.cbar_init = args.cbar_init
149 | if args.include_zeros:
150 | cfg.discard_zeros = False
151 | cfg.export_gramag = args.export_gramag
152 | cfg.force_original_precision = args.force_original_precision
153 | cfg.matplotlib_backend = args.matplotlib_backend
154 | # used in ncut preparation
155 | cfg.ncut_figs = args.ncut_figs
156 | cfg.max_rec = args.ncut_maxRec
157 | cfg.nr_sup_pix = args.ncut_nrSupPix
158 | cfg.compactness = args.ncut_compactness
159 | # used in ncut
160 | cfg.ncut = args.ncut
161 | # used in deriche filter
162 | cfg.deriche_alpha = args.deriche_alpha
163 |
164 | welcome_str = 'Segmentator {}'.format(__version__)
165 | welcome_decor = '=' * len(welcome_str)
166 | print('{}\n{}\n{}'.format(welcome_decor, welcome_str, welcome_decor))
167 |
168 | # Call other scripts with import method (couldn't find a better way).
169 | if args.nogui:
170 | print('No GUI option is selected. Saving 2D histogram image...')
171 | import segmentator.hist2d_counts
172 | elif args.ncut_prepare:
173 | print('Preparing N-cut file...')
174 | import segmentator.ncut_prepare
175 | elif args.ncut:
176 | print('N-cut GUI is selected.')
177 | import segmentator.segmentator_ncut
178 | else:
179 | print('Default GUI is selected.')
180 | import segmentator.segmentator_main
181 |
182 |
183 | if __name__ == "__main__":
184 | main()
185 |
--------------------------------------------------------------------------------
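A note on the pattern above: `main()` never calls the worker scripts directly. It writes the parsed arguments into the shared `segmentator.config` module and then imports the relevant script, which executes at import time and reads its parameters from that module. Below is a minimal sketch of driving the same mechanism from a Python session; this is an assumption for illustration, not a documented API, and the file path is a placeholder.

```python
# Hedged sketch: mirrors what main() does above. Assumes the package is
# installed; the path below is a placeholder, not real data.
import segmentator.config as cfg

cfg.filename = '/path/to/file.nii.gz'  # placeholder input
cfg.gramag = 'deriche'                 # one of the gramag_options in config.py
cfg.deriche_alpha = 2.0

# Importing the module runs it, picking up the values set on cfg above.
import segmentator.segmentator_main  # noqa: E402
```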
/segmentator/config.py:
--------------------------------------------------------------------------------
1 | """This file contains variables that are shared by several modules.
2 |
3 | Useful to hold command line arguments.
4 |
5 | """
6 |
7 | # Define variables used to initialise the sector mask
8 | init_centre = (0, 0)
9 | init_radius = 100
10 | init_theta = (0, 360)
11 |
12 | # Segmentator main command line variables
13 | filename = 'sample_filename_here'
14 | gramag = 'scharr'
15 | deriche_alpha = 3.0
16 | perc_min = 2.5
17 | perc_max = 97.5
18 | valmin = float('nan')
19 | valmax = float('nan')
20 | scale = 400
21 | cbar_max = 5.0
22 | cbar_init = 3.0
23 | discard_zeros = True
24 | export_gramag = False
25 | force_original_precision = False
26 |
27 | # Change in case of glitches in the host operating system
28 | matplotlib_backend = 'tkagg'
29 |
30 | # Possible gradient magnitude computation keyword options
31 | gramag_options = ['scharr', 'sobel', 'prewitt', 'numpy', 'deriche']
32 |
33 | # Used in segmentator ncut
34 | ncut = False
35 | max_rec = 8
36 | nr_sup_pix = 2500
37 | compactness = 2
38 |
--------------------------------------------------------------------------------
/segmentator/config_filters.py:
--------------------------------------------------------------------------------
1 | """This file contains variables that are used in filters."""
2 |
3 | filename = 'sample_filename_here'
4 |
5 | # filter default parameters
6 | smoothing = 'STEDI'
7 | noise_scale = 0.5
8 | feature_scale = 0.5
9 | nr_iterations = 20
10 | save_every = 10
11 | edge_thr = 0.001
12 | gamma = 1
13 | downsampling = 0
14 | no_nonpositive_mask = False
15 |
--------------------------------------------------------------------------------
/segmentator/config_gui.py:
--------------------------------------------------------------------------------
1 | """This file contains GUI related variables."""
2 |
3 | import matplotlib.pyplot as plt
4 |
5 | # Colormap for mask overlays
6 | palette = plt.cm.Reds
7 | palette.set_over('r', 1.0)
8 | palette.set_under('w', 0)
9 | palette.set_bad('m', 1.0)
10 |
11 | # Slider colors
12 | axcolor = '0.875'
13 | hovcolor = '0.975'
14 |
--------------------------------------------------------------------------------
/segmentator/cython/deriche_3D.pyx:
--------------------------------------------------------------------------------
1 | """3D Deriche filter implementation."""
2 |
3 | import numpy as np
4 | cimport numpy as np
5 |
6 | DTYPE = np.int
7 | ctypedef np.int_t DTYPE_t
8 |
9 | def deriche_3D(np.ndarray[float, mode="c", ndim=3] inputData, float alpha=1):
10 | """Reference: Monga et al. 1991."""
11 | # c definitions
12 | cdef float s, a_0, a_1, a_2, a_3 , b_1, b_2
13 | cdef np.ndarray[double, mode="c", ndim=3] S_p, S_n, R_p, R_n, T_p, T_n
14 | cdef np.ndarray[double, mode="c", ndim=3] S, R, P
15 | cdef int imax = inputData.shape[0]
16 | cdef int jmax = inputData.shape[1]
17 | cdef int kmax = inputData.shape[2]
18 | cdef int i, j, k
19 |
20 | # constants
21 | s = np.power((1 - np.exp(-alpha)), 2) / \
22 | (1 + 2 * alpha * np.exp(-alpha) - np.exp(-2*alpha))
23 | a_0 = s
24 | a_1 = s * (alpha - 1) * np.exp(-alpha)
25 | b_1 = -2 * np.exp(-alpha)
26 | b_2 = np.exp(-2 * alpha)
27 | a_2 = a_1 - s * b_1
28 | a_3 = -s * b_2
29 |
30 | # Recursive filter implementation
31 | S_p = np.zeros((imax, jmax, kmax))
32 | S_n = np.zeros((imax, jmax, kmax))
33 | for k in range(kmax):
34 | for j in range(jmax):
35 | for i in range(0, imax):
36 | if i > 2:
37 | S_p[i, j, k] = inputData[i-1, j, k] \
38 | - b_1*S_p[i-1, j, k] \
39 | - b_2*S_p[i-2, j, k]
40 | i = i*-1 % inputData.shape[0] # inverse index
41 | if i < inputData.shape[0]-2:
42 | S_n[i, j, k] = inputData[i+1, j, k] \
43 | - b_1*S_n[i+1, j, k] \
44 | - b_2*S_n[i+2, j, k]
45 |
46 | S = alpha*(S_p - S_n)
47 |
48 | R_p = np.zeros((imax, jmax, kmax))
49 | R_n = np.zeros((imax, jmax, kmax))
50 | for k in range(kmax):
51 | for i in range(imax):
52 | for j in range(0, jmax):
53 | if j > 2:
54 | R_p[i, j, k] = a_0*S[i, j, k] + a_1*S[i, j-1, k] \
55 | - b_1*R_p[i, j-1, k] - b_2*R_p[i, j-2, k]
56 | j = j*-1 % inputData.shape[1] # inverse index
57 | if j < inputData.shape[1]-2:
58 | R_n[i, j, k] = a_2*S[i, j+1, k] + a_3*S[i, j+2, k] \
59 | - b_1*R_n[i, j+1, k] - b_2*R_n[i, j+2, k]
60 |
61 | R = R_n + R_p
62 |
63 | T_p = np.zeros((imax, jmax, kmax))
64 | T_n = np.zeros((imax, jmax, kmax))
65 | for i in range(imax):
66 | for j in range(jmax):
67 | for k in range(0, kmax):
68 | if k > 2:
69 | T_p[i, j, k] = a_0*R[i, j, k] + a_1*R[i, j, k-1] \
70 | - b_1*T_p[i, j, k-1] - b_2*T_p[i, j, k-2]
71 | k = k*-1 % inputData.shape[2] # inverse index
72 | if k < inputData.shape[2]-2:
73 | T_n[i, j, k] = a_2*R[i, j, k+1] + a_3*R[i, j, k+2] \
74 | - b_1*T_n[i, j, k+1] - b_2*T_n[i, j, k+2]
75 |
76 | T = T_n + T_p
77 | return T
78 |
--------------------------------------------------------------------------------
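For intuition, the recursion above runs a causal pass (`S_p`) and an anti-causal pass (`S_n`) along one axis and combines them into a derivative estimate `S = alpha*(S_p - S_n)`; the `a_0`..`a_3` coefficients then smooth along the remaining axes. The following 1D NumPy sketch of the derivative recursion is illustrative only, with simplified boundary handling, and is not part of the package.

```python
import numpy as np

def deriche_derivative_1d(x, alpha=1.0):
    """Causal + anti-causal derivative recursion, as along axis 0 above."""
    n = len(x)
    b_1 = -2.0 * np.exp(-alpha)
    b_2 = np.exp(-2.0 * alpha)
    s_p = np.zeros(n)  # causal (left-to-right) pass
    s_n = np.zeros(n)  # anti-causal (right-to-left) pass
    for i in range(2, n):
        s_p[i] = x[i - 1] - b_1 * s_p[i - 1] - b_2 * s_p[i - 2]
    for i in range(n - 3, -1, -1):
        s_n[i] = x[i + 1] - b_1 * s_n[i + 1] - b_2 * s_n[i + 2]
    return alpha * (s_p - s_n)

# A step edge gives a response whose magnitude peaks at the edge location:
edge = np.concatenate([np.zeros(20), np.ones(20)])
response = deriche_derivative_1d(edge, alpha=2.0)
print(np.abs(response).argmax())  # around index 19-20, i.e. at the step
```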
/segmentator/deriche_prepare.py:
--------------------------------------------------------------------------------
1 | """Calculate gradient magnitude with 3D Deriche filter.
2 |
3 | You can use the --gramag flag to pass the resulting nifti files from this script.
4 | """
5 |
6 | from segmentator.deriche_3D import deriche_3D
7 | import numpy as np
8 |
9 |
10 | def Deriche_Gradient_Magnitude(image, alpha, normalize=False,
11 | return_gradients=False):
12 | """Compute Deriche gradient magnitude of a volumetric image."""
13 | # calculate gradients
14 | image = np.ascontiguousarray(image, dtype=np.float32)
15 | gra_x = deriche_3D(image, alpha=alpha)
16 | image = np.transpose(image, (2, 0, 1))
17 | image = np.ascontiguousarray(image, dtype=np.float32)
18 | gra_y = deriche_3D(image, alpha=alpha)
19 | gra_y = np.transpose(gra_y, (1, 2, 0))
20 | image = np.transpose(image, (2, 0, 1))
21 | image = np.ascontiguousarray(image, dtype=np.float32)
22 | gra_z = deriche_3D(image, alpha=alpha)
23 | gra_z = np.transpose(gra_z, (2, 0, 1))
24 |
25 | # Put the image gradients in 4D format
26 | gradients = np.array([gra_x, gra_y, gra_z])
27 | gradients = np.transpose(gradients, (1, 2, 3, 0))
28 |
29 | if return_gradients:
30 | return gradients
31 |
32 | else: # Deriche gradient magnitude
33 | gra_mag = np.sqrt(np.power(gradients[:, :, :, 0], 2.0) +
34 | np.power(gradients[:, :, :, 1], 2.0) +
35 | np.power(gradients[:, :, :, 2], 2.0))
36 | if normalize:
37 | min_ima, max_ima = np.percentile(image, [0, 100])
38 | min_der, max_der = np.percentile(gra_mag, [0, 100])
39 | range_ima, range_der = max_ima - min_ima, max_der - min_der
40 |
41 | gra_mag = gra_mag * (range_ima / range_der)
42 | return gra_mag
43 |
--------------------------------------------------------------------------------
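A minimal usage sketch of the function above, assuming the `deriche_3D` Cython extension has been built (e.g. via `python setup.py install`); the paths are placeholders. The resulting nifti can then be passed back to the GUI via `--gramag`.

```python
import numpy as np
from nibabel import load, Nifti1Image, save
from segmentator.deriche_prepare import Deriche_Gradient_Magnitude

nii = load('/path/to/file.nii.gz')  # placeholder path
ima = np.asarray(nii.get_fdata(), dtype=np.float32)

# Smaller alpha suppresses more noise but can dislocate edges (see --deriche_alpha).
gra_mag = Deriche_Gradient_Magnitude(ima, alpha=2.0)

save(Nifti1Image(gra_mag, affine=nii.affine), '/path/to/file_gramag_deriche.nii.gz')
# Then: segmentator /path/to/file.nii.gz --gramag /path/to/file_gramag_deriche.nii.gz
```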
/segmentator/filter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Diffusion based image smoothing."""
3 |
4 | from __future__ import division
5 | import os
6 | import numpy as np
7 | import segmentator.config_filters as cfg
8 | from nibabel import load, Nifti1Image, save
9 | from numpy.linalg import eigh
10 | from scipy.ndimage import gaussian_filter
11 | from time import time
12 | from segmentator.filters_utils import (
13 | self_outer_product, dot_product_matrix_vector, divergence,
14 | compute_diffusion_weights, construct_diffusion_tensors,
15 | smooth_matrix_image)
16 | from scipy.ndimage.interpolation import zoom
17 |
18 |
19 | def QC_export(image, basename, identifier, nii):
20 | """Quality control exports."""
21 | out = Nifti1Image(image, affine=nii.affine, header=nii.header)
22 | save(out, '{}_{}.nii.gz'.format(basename, identifier))
23 |
24 |
25 | # Input
26 | file_name = cfg.filename
27 |
28 | # Primary parameters
29 | MODE = cfg.smoothing
30 | NR_ITER = cfg.nr_iterations
31 | SAVE_EVERY = cfg.save_every
32 | SIGMA = cfg.noise_scale
33 | RHO = cfg.feature_scale
34 | LAMBDA = cfg.edge_thr
35 |
36 | # Secondary parameters
37 | GAMMA = cfg.gamma
38 | ALPHA = 0.001
39 | M = 4
40 |
41 | # Export parameters
42 | identifier = MODE
43 |
44 | # Load data
45 | basename = file_name.split(os.extsep, 1)[0]
46 | nii = load(file_name)
47 | vres = nii.header['pixdim'][1:4] # voxel resolution x y z
48 | norm_vres = [r/min(vres) for r in vres] # normalized voxel resolutions
49 | ima = (nii.get_fdata()).astype('float32')
50 |
51 | if cfg.downsampling > 1: # TODO: work in progress
52 | print(' Applying initial downsampling...')
53 | ima = zoom(ima, 1./cfg.downsampling)
54 | orig = np.copy(ima)
55 | else:
56 | pass
57 |
58 | if cfg.no_nonpositive_mask: # TODO: work in progress
59 | idx_msk_flat = np.ones(ima.size, dtype=bool)
60 | else: # mask out non positive voxels
61 | idx_msk_flat = ima.flatten() > 0
62 |
63 | dims = ima.shape
64 |
65 | # The main loop
66 | start = time()
67 | for t in range(NR_ITER):
68 | iteration = str(t+1).zfill(len(str(NR_ITER)))
69 | print("Iteration: " + iteration)
70 | # Update export parameters
71 | params = '{}_n{}_s{}_r{}_g{}'.format(
72 | identifier, iteration, SIGMA, RHO, GAMMA)
73 | params = params.replace('.', 'pt')
74 |
75 | # Smoothing
76 | if SIGMA == 0:
77 | ima_temp = np.copy(ima)
78 | else:
79 | ima_temp = gaussian_filter(
80 | ima, mode='constant', cval=0.0,
81 | sigma=[SIGMA/norm_vres[0], SIGMA/norm_vres[1], SIGMA/norm_vres[2]])
82 |
83 | # Compute gradient
84 | gra = np.transpose(np.gradient(ima_temp), [1, 2, 3, 0])
85 | ima_temp = None
86 |
87 | print(' Constructing structure tensors...')
88 | struct = self_outer_product(gra)
89 |
90 | # Gaussian smoothing on tensor components
91 | struct = smooth_matrix_image(struct, RHO=RHO, vres=norm_vres)
92 |
93 | print(' Running eigen decomposition...')
94 | struct = struct.reshape([np.prod(dims), 3, 3])
95 | struct = struct[idx_msk_flat, :, :]
96 | eigvals, eigvecs = eigh(struct)
97 | struct = None
98 |
99 | print(' Constructing diffusion tensors...')
100 | mu = compute_diffusion_weights(eigvals, mode=MODE, LAMBDA=LAMBDA,
101 | ALPHA=ALPHA, M=M)
102 | difft = construct_diffusion_tensors(eigvecs, weights=mu)
103 | eigvecs, eigvals, mu = None, None, None
104 |
105 | # Reshape processed voxels (not masked) back to image space
106 | temp = np.zeros([np.prod(dims), 3, 3])
107 | temp[idx_msk_flat, :, :] = difft
108 | difft = temp.reshape(dims + (3, 3))
109 | temp = None
110 |
111 | # Weickert, 1998, eq. 1.1 (Fick's law).
112 | negative_flux = dot_product_matrix_vector(difft, gra)
113 | difft, gra = None, None
114 | # Weickert, 1998, eq. 1.2 (continuity equation)
115 | diffusion_difference = divergence(negative_flux)
116 | negative_flux = None
117 |
118 | # Update image (diffuse image using the difference)
119 | ima += GAMMA*diffusion_difference
120 | diffusion_difference = None
121 |
122 | # Convenient exports for intermediate outputs
123 | if (t+1) % SAVE_EVERY == 0 and (t+1) != NR_ITER:
124 | QC_export(ima, basename, params, nii)
125 | duration = time() - start
126 | mins, secs = int(duration / 60), int(duration % 60)
127 | print(' Image saved (took {} min {} sec)'.format(mins, secs))
128 |
129 | if cfg.downsampling > 1: # TODO: work in progress
130 | print(' Final upsampling...')
131 | residual = ima - orig
132 | residual = zoom(residual, cfg.downsampling)
133 | ima = (nii.get_fdata()).astype('float32') + residual
134 | else:
135 | pass
136 |
137 |
138 | print('Saving final image...')
139 | QC_export(ima, basename, params, nii)
140 |
141 | duration = time() - start
142 | mins, secs = int(duration / 60), int(duration % 60)
143 | print(' Finished (Took: {} min {} sec).'.format(mins, secs))
144 |
--------------------------------------------------------------------------------
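As a sanity check on the update rule above (negative flux = D·∇u, then ima += GAMMA·div(flux)): with the identity diffusion tensor the step reduces to linear heat diffusion, i.e. adding a fraction of the Laplacian. A minimal sketch using the same helpers, on random data purely for illustration:

```python
import numpy as np
from segmentator.filters_utils import divergence

ima = np.random.rand(8, 8, 8)
gra = np.transpose(np.gradient(ima), [1, 2, 3, 0])  # shape (8, 8, 8, 3)

# With D = identity, the negative flux is just the gradient, so the update
# term div(D * grad u) becomes the Laplacian of the image.
update = divergence(gra)
ima_smoothed = ima + 1.0 * update  # GAMMA = 1.0
```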
/segmentator/filters_ui.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Entry point.
3 |
4 | Use config_filters.py to hold parameters.
5 | """
6 |
7 | from __future__ import print_function
8 | import argparse
9 | import segmentator.config_filters as cfg
10 | from segmentator import __version__
11 |
12 |
13 | def main():
14 | """Command line call argument parsing."""
15 | # Instantiate argument parser object:
16 | parser = argparse.ArgumentParser()
17 |
18 | # Add arguments to namespace:
19 | parser.add_argument(
20 | 'filename', metavar='path',
21 | help="Path to input. A nifti file with image data."
22 | )
23 | parser.add_argument(
24 | "--smoothing", metavar=str(cfg.smoothing), required=False,
25 | default=cfg.smoothing,
26 | help="Variants are CURED and STEDI. Use CURED for removing tubular,\
27 | honeycomb-like structures as well as smoothing isotropic areas.\
28 | STEDI is the more conservative version that retains more edges."
29 | )
30 | # parser.add_argument(
31 | # "--edge_thr", metavar=str(cfg.edge_thr), required=False,
32 | # type=float, default=cfg.edge_thr,
33 | # help="Lambda, edge threshold, lower values preserves more edges. Not\
34 | # used in CURED and STEDI."
35 | # )
36 | parser.add_argument(
37 | "--noise_scale", metavar=str(cfg.noise_scale), required=False,
38 | type=float, default=cfg.noise_scale,
39 | help="Sigma, determines the spatial scale of the noise that will be \
40 | corrected. Recommended lower bound is 0.5. Adjusted for each axis \
41 | to account for non-isotropic voxels. For example, if the selected \
42 | value is 1 for an image with [0.7, 0.7, 1.4] mm voxels, sigma is \
43 | adjusted to be [1, 1, 0.5] in the corresponding axes."
44 | )
45 | parser.add_argument(
46 | "--feature_scale", metavar=str(cfg.feature_scale), required=False,
47 | type=float, default=cfg.feature_scale,
48 | help="Rho, determines the spatial scale of the features that will be \
49 | enhanced. Recommended lower bound is 0.5. Adjusted for each axis \
50 | to account for non-isotropic voxels same way as in --noise_scale."
51 | )
52 | parser.add_argument(
53 | "--gamma", metavar=str(cfg.gamma), required=False,
54 | type=float, default=cfg.gamma,
55 | help="Strength of the updates in every iteration. Recommended range is\
56 | 0.5 to 2."
57 | )
58 | parser.add_argument(
59 | "--nr_iterations", metavar=str(cfg.nr_iterations), required=False,
60 | type=int, default=cfg.nr_iterations,
61 | help="Number of maximum iterations. More iterations will produce \
62 | smoother images."
63 | )
64 | parser.add_argument(
65 | "--save_every", metavar=str(cfg.save_every), required=False,
66 | type=int, default=cfg.save_every,
67 | help="Save every Nth iterations. Useful to track the effect of \
68 | smoothing as it evolves."
69 | )
70 | parser.add_argument(
71 | "--downsampling", metavar=str(cfg.downsampling), required=False,
72 | type=int, default=cfg.downsampling,
73 | help="(!WIP!) Downsampling factor, use integers > 1. E.g. factor of 2 \
74 | reduces the amount of voxels 8 times."
75 | )
76 | parser.add_argument(
77 | "--no_nonpositive_mask", action='store_true',
78 | help="(!WIP!) Do not mask out non-positive values."
79 | )
80 |
81 | # set cfg file variables to be accessed from other scripts
82 | args = parser.parse_args()
83 | cfg.filename = args.filename
84 | cfg.smoothing = args.smoothing
85 | # cfg.edge_thr = args.edge_thr # lambda
86 | cfg.noise_scale = args.noise_scale # sigma
87 | cfg.feature_scale = args.feature_scale # rho
88 | cfg.gamma = args.gamma
89 | cfg.nr_iterations = args.nr_iterations
90 | cfg.save_every = args.save_every
91 | cfg.downsampling = args.downsampling
92 | cfg.no_nonpositive_mask = args.no_nonpositive_mask
93 |
94 | welcome_str = 'Segmentator {}'.format(__version__)
95 | welcome_decor = '=' * len(welcome_str)
96 | print('{}\n{}\n{}'.format(welcome_decor, welcome_str, welcome_decor))
97 | print('Filters initiated...')
98 |
99 | import segmentator.filter
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
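The per-axis adjustment described in the `--noise_scale` help corresponds to dividing sigma by the normalized voxel resolution, as done in `filter.py`. A minimal sketch with the example values from the help text (illustrative only):

```python
vres = [0.7, 0.7, 1.4]                       # voxel resolution in mm
norm_vres = [r / min(vres) for r in vres]    # -> [1.0, 1.0, 2.0]

sigma = 1.0                                  # --noise_scale value
sigma_per_axis = [sigma / n for n in norm_vres]
print(sigma_per_axis)                        # [1.0, 1.0, 0.5]
```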
/segmentator/filters_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Common functions used in filters."""
3 |
4 | from __future__ import division
5 | import numpy as np
6 | from scipy.ndimage.filters import gaussian_filter
7 |
8 |
9 | def self_outer_product(vector_field):
10 | """Vectorized computation of outer product.
11 |
12 | Parameters
13 | ----------
14 | vector_field: np.ndarray, shape(..., 3)
15 |
16 | Returns
17 | -------
18 | outer: np.ndarray, shape(..., 3, 3)
19 | """
20 | dims = vector_field.shape
21 | outer = np.repeat(vector_field, dims[-1], axis=-1)
22 | outer *= outer[..., [0, 3, 6, 1, 4, 7, 2, 5, 8]]
23 | outer = outer.reshape(dims[:-1] + (dims[-1], dims[-1]))
24 | return outer
25 |
26 |
27 | def dot_product_matrix_vector(matrix_field, vector_field):
28 | """Vectorized computation of dot product."""
29 | dims = vector_field.shape
30 | dotp = np.repeat(vector_field, dims[-1], axis=-1)
31 | dotp = dotp.reshape(dims[:-1] + (dims[-1], dims[-1]))
32 | idx_dims = tuple(range(dotp.ndim))
33 | dotp = dotp.transpose(idx_dims[:-2] + (idx_dims[-1], idx_dims[-2]))
34 | np.multiply(matrix_field, dotp, out=dotp)
35 | dotp = np.sum(dotp, axis=-1)
36 | return dotp
37 |
38 |
39 | def divergence(vector_field):
40 | """Vectorized computation of divergence, also called Laplacian."""
41 | dims = vector_field.shape
42 | result = np.zeros(dims[:-1])
43 | for i in range(dims[-1]):
44 | result += np.gradient(vector_field[..., i], axis=i)
45 | return result
46 |
47 |
48 | def compute_diffusion_weights(eigvals, mode, LAMBDA=0.001, ALPHA=0.001, M=4):
49 | """Vectorized computation diffusion weights.
50 |
51 | References
52 | ----------
53 | - Weickert, J. (1998). Anisotropic diffusion in image processing.
54 | Image Rochester NY, 256(3), 170.
55 | - Mirebeau, J.-M., Fehrenbach, J., Risser, L., & Tobji, S. (2015).
56 | Anisotropic Diffusion in ITK, 1-9.
57 | """
58 | idx_pos_e2 = eigvals[:, 1] > 0 # positive indices of second eigen value
59 | c = (1. - ALPHA) # related to matrix condition
60 |
61 | if mode in ['EED', 'cEED', 'iEED']: # TODO: I might remove these.
62 |
63 | if mode == 'EED': # edge enhancing diffusion
64 | mu = np.ones(eigvals.shape)
65 | term1 = LAMBDA
66 | term2 = eigvals[idx_pos_e2, 1:] - eigvals[idx_pos_e2, 0, None]
67 | mu[idx_pos_e2, 1:] = 1. - c * np.exp(-(term1/term2)**M)
68 | # weights for the non-positive eigen values
69 | mu[~idx_pos_e2, 2] = ALPHA # surely surfels
70 |
71 | elif mode == 'cEED': # FIXME: Not working at all, for now...
72 | term1 = LAMBDA
73 | term2 = eigvals
74 | mu = 1. - c * np.exp(-(term1/term2)**M)
75 |
76 | elif mode in ['CED', 'cCED']:
77 |
78 | if mode == 'CED': # coherence enhancing diffusion (FIXME: not tested)
79 | mu = np.ones(eigvals.shape) * ALPHA
80 | term1 = LAMBDA
81 | term2 = eigvals[:, 2, None] - eigvals[:, :-1]
82 | mu[:, :-1] = ALPHA + c * np.exp(-(term1/term2)**M)
83 |
84 | elif mode == 'cCED': # conservative coherence enhancing diffusion
85 | mu = np.ones(eigvals.shape) * ALPHA
86 | term1 = LAMBDA + eigvals[:, 0:2]
87 | term2 = eigvals[:, 2, None] - eigvals[:, 0:2]
88 | mu[:, 0:2] = ALPHA + c * np.exp(-(term1/term2)**M)
89 |
90 | elif mode == 'CURED': # NOTE: Somewhat experimental
91 | import compoda.core as coda
92 | mu = np.ones(eigvals.shape)
93 | mu[idx_pos_e2, :] = 1. - coda.closure(eigvals[idx_pos_e2, :])
94 |
95 | elif mode == 'STEDI': # NOTE: Somewhat more experimental
96 | import compoda.core as coda
97 | mu = np.ones(eigvals.shape)
98 | eigs = eigvals[idx_pos_e2, :]
99 | term1 = coda.closure(eigs)
100 | term2 = np.abs((np.max(term1, axis=-1) - np.min(term1, axis=-1)) - 0.5)
101 | term2 += 0.5
102 | mu[idx_pos_e2, :] = np.abs(term2[:, None] - term1)
103 |
104 | else:
105 | mu = np.ones(eigvals.shape)
106 | print(' Invalid smoothing method. Weights are all set to ones.')
107 |
108 | return mu
109 |
110 |
111 | def construct_diffusion_tensors(eigvecs, weights):
112 | """Vectorized consruction of diffusion tensors."""
113 | dims = eigvecs.shape
114 | D = np.zeros(dims[:-2] + (dims[-1], dims[-1]))
115 | for i in range(dims[-1]): # weight vectors
116 | D += weights[:, i, None, None] * self_outer_product(eigvecs[..., i])
117 | return D
118 |
119 |
120 | def smooth_matrix_image(matrix_image, RHO=0, vres=None):
121 | """Gaussian smoothing applied to matrix image."""
122 | if vres is None:
123 | vres = [1., 1., 1.]
124 | if RHO == 0:
125 | return matrix_image
126 | else:
127 | dims = matrix_image.shape
128 | for x in range(dims[-2]):
129 | for y in range(dims[-1]):
130 | gaussian_filter(matrix_image[..., x, y],
131 | sigma=[RHO/vres[0], RHO/vres[1], RHO/vres[2]],
132 | mode='constant', cval=0.0,
133 | output=matrix_image[..., x, y])
134 | return matrix_image
135 |
--------------------------------------------------------------------------------
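A minimal check (illustrative, assuming the package is importable) of the shape contract documented for `self_outer_product`, compared against an `einsum` reference:

```python
import numpy as np
from segmentator.filters_utils import self_outer_product

v = np.random.rand(4, 5, 6, 3)
outer = self_outer_product(v)                    # shape (4, 5, 6, 3, 3)
reference = np.einsum('...i,...j->...ij', v, v)  # outer[..., i, j] = v_i * v_j
print(outer.shape, np.allclose(outer, reference))
```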
/segmentator/future/readme.md:
--------------------------------------------------------------------------------
1 | This is a folder for work in progress components.
2 |
--------------------------------------------------------------------------------
/segmentator/future/wip_arcweld_mp2rage.py:
--------------------------------------------------------------------------------
1 | """Full automation experiments for T1w-like data (MPRAGE & MP2RAGE).
2 |
3 | TODO:
4 | - Put anisotropic diffusion based smoothing to segmentator utilities.
5 | - MP2RAGE: two types of gray matter are giving issues; barycentric weights
6 | might be useful to deal with this issue.
7 |
8 | """
9 |
10 | import os
11 | import peakutils
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from nibabel import load, Nifti1Image, save
15 | from scipy.ndimage.filters import gaussian_filter1d
16 | from segmentator.utils import compute_gradient_magnitude, aniso_diff_3D
17 |
18 | # load
19 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/arcweld/mp2rage_S001_restore.nii.gz')
20 | ima = nii.get_fdata()
21 | basename = nii.get_filename().split(os.extsep, 1)[0]
22 |
23 | # non-zero mask
24 | msk = (ima != 0) # TODO: Parametrize
25 | #
26 | ima[msk] = ima[msk] + np.min(ima)
27 |
28 | # aniso. diff. filter
29 | ima = aniso_diff_3D(ima, niter=2, kappa=500, gamma=0.1, option=1)
30 |
31 | # calculate gradient magnitude
32 | gra = compute_gradient_magnitude(ima, method='3D_sobel')
33 |
34 | # save for debugging
35 | # out = Nifti1Image(gra.reshape(nii.shape), affine=nii.affine)
36 | # save(out, basename + '_gra' + '.nii.gz')
37 |
38 | ima_max = np.percentile(ima, 100)
39 |
40 | # reshape for histograms
41 | ima, gra, msk = ima.flatten(), gra.flatten(), msk.flatten()
42 | gra_thr = np.percentile(gra[msk], 20)
43 |
44 | # select low gradient magnitude regime (lr)
45 | msk_lr = (gra < gra_thr) & (msk)
46 | n, bins, _ = plt.hist(ima[msk_lr], 200, range=(0, ima_max))
47 |
48 | # smooth histogram (good for peak detection)
49 | n = gaussian_filter1d(n, 1)
50 |
51 | # detect 'pure tissue' peaks (TODO: Checks for always finding 3 peaks)
52 | peaks = peakutils.indexes(n, thres=0.01/max(n), min_dist=40)
53 | # peaks = peaks[0:-1]
54 | tissues = []
55 | for p in peaks:
56 | tissues.append(bins[p])
57 | tissues = np.array(tissues)
58 | print(peaks)
59 | print(tissues)
60 |
61 | # insert zero-max arc
62 | zmax_max = ima_max
63 | zmax_center = (0 + zmax_max) / 2.
64 | zmax_radius = zmax_max - zmax_center
65 | tissues = np.append(tissues, zmax_center)
66 |
67 | # create smooth maps (distance to pure tissue)
68 | voxels = np.vstack([ima, gra])
69 | soft = [] # to hold soft tissue membership maps
70 | for i, t in enumerate(tissues):
71 | tissue = np.array([t, 0])
72 | # euclidean distance
73 | edist = np.sqrt(np.sum((voxels - tissue[:, None])**2., axis=0))
74 | soft.append(edist)
75 | # save intermediate maps
76 | # out = Nifti1Image(edist.reshape(nii.shape), affine=nii.affine)
77 | # save(out, basename + '_t' + str(i) + '.nii.gz')
78 | soft = np.array(soft)
79 |
80 | # interface translation (shift zero circle to another radius)
81 | soft[-1, :] = soft[-1, :] - zmax_radius
82 | # zmax_neg = soft[-1, :] < 0 # voxels fall inside zero-max arc
83 |
84 | # weight zero-max arc to not coincide with pure class regions
85 | # zmax_weight = (gra[zmax_neg] / zmax_radius)**-1
86 | zmax_weight = (gra / zmax_radius)**-1
87 | # soft[-1, :][zmax_neg] = zmax_weight*np.abs(soft[-1, :][zmax_neg])
88 | soft[-1, :] = zmax_weight*np.abs(soft[-1, :])
89 | # save for debugging
90 | # out = Nifti1Image(soft[-1, :].reshape(nii.shape), affine=nii.affine)
91 | # save(out, basename + '_zmaxarc' + '.nii.gz')
92 |
93 | # arbitrary weighting (TODO: Can be turned into config file of some sort)
94 | # save these values for MP2RAGE UNI
95 | # soft[0, :] = soft[0, :] * 1 # csf
96 | # soft[-1, :] = soft[-1, :] * 0.5 # zero-max arc
97 |
98 | # hard tissue membership maps
99 | hard = np.argmin(soft, axis=0)
100 |
101 | # append masked out areas
102 | hard = hard + 1
103 | hard[~msk] = 0
104 |
105 | # save hard classification
106 | out = Nifti1Image(hard.reshape(nii.shape), affine=nii.affine)
107 | save(out, basename + '_arcweld' + '.nii.gz')
108 |
109 | # save segmentator polish mask (not GM and WM)
110 | labels = np.unique(hard)[[0, 1, -1]]
111 | polish = np.ones(hard.shape)
112 | polish[hard == labels[0]] = 0
113 | polish[hard == labels[1]] = 0
114 | polish[hard == labels[2]] = 0
115 | out = Nifti1Image(polish.reshape(nii.shape), affine=nii.affine)
116 | save(out, basename + '_arcweld_mask' + '.nii.gz')
117 | print('Finished.')
118 |
119 | # plt.show()
120 |
--------------------------------------------------------------------------------
/segmentator/future/wip_arcweld_mprage.py:
--------------------------------------------------------------------------------
1 | """Full automation experiments for T1w-like data (MPRAGE & MP2RAGE).
2 |
3 | TODO:
4 | - Put anisotropic diffusion based smoothing to segmentator utilities.
5 | - MP2RAGE: two types of gray matter are giving issues; barycentric weights
6 | might be useful to deal with this issue.
7 |
8 | """
9 |
10 | import os
11 | import peakutils
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from nibabel import load, Nifti1Image, save
15 | from scipy.ndimage.filters import gaussian_filter1d
16 | from segmentator.utils import compute_gradient_magnitude, aniso_diff_3D
17 |
18 | # load
19 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/arcweld/mprage_S02_restore.nii.gz')
20 | ima = nii.get_fdata()
21 | basename = nii.get_filename().split(os.extsep, 1)[0]
22 |
23 | # non-zero mask
24 | msk = ima > 0 # TODO: Parametrize
25 |
26 | # aniso. diff. filter
27 | ima = aniso_diff_3D(ima, niter=2, kappa=50, gamma=0.1, option=1)
28 |
29 | # calculate gradient magnitude
30 | gra = compute_gradient_magnitude(ima, method='3D_sobel')
31 |
32 | # save for debugging
33 | # out = Nifti1Image(gra.reshape(nii.shape), affine=nii.affine)
34 | # save(out, basename + '_gra' + '.nii.gz')
35 |
36 | ima_max = np.percentile(ima, 99)
37 |
38 | # reshape for histograms
39 | ima, gra, msk = ima.flatten(), gra.flatten(), msk.flatten()
40 | gra_thr = np.percentile(gra[msk], 20)
41 |
42 | # select low gradient magnitude regime (lr)
43 | msk_lr = (gra < gra_thr) & (msk)
44 | n, bins, _ = plt.hist(ima[msk_lr], 200, range=(0, ima_max))
45 |
46 | # smooth histogram (good for peak detection)
47 | n = gaussian_filter1d(n, 1)
48 |
49 | # detect 'pure tissue' peaks (TODO: Checks for always finding 3 peaks)
50 | peaks = peakutils.indexes(n, thres=0.01/max(n), min_dist=40)
51 | # peaks = peaks[0:-1]
52 | tissues = []
53 | for p in peaks:
54 | tissues.append(bins[p])
55 | tissues = np.array(tissues)
56 | print(peaks)
57 | print(tissues)
58 |
59 | # insert zero-max arc
60 | zmax_max = tissues[-1] + tissues[-1] - tissues[1] # first gm and wm
61 | zmax_center = (0 + zmax_max) / 2.
62 | zmax_radius = zmax_max - zmax_center
63 | tissues = np.append(tissues, zmax_center)
64 |
65 | # create smooth maps (distance to pure tissue)
66 | voxels = np.vstack([ima, gra])
67 | soft = [] # to hold soft tissue membership maps
68 | for i, t in enumerate(tissues):
69 | tissue = np.array([t, 0])
70 | # euclidean distance
71 | edist = np.sqrt(np.sum((voxels - tissue[:, None])**2., axis=0))
72 | soft.append(edist)
73 | # save intermediate maps
74 | # out = Nifti1Image(edist.reshape(nii.shape), affine=nii.affine)
75 | # save(out, basename + '_t' + str(i) + '.nii.gz')
76 | soft = np.array(soft)
77 |
78 | # interface translation (shift zero circle to another radius)
79 | soft[-1, :] = soft[-1, :] - zmax_radius
80 | # zmax_neg = soft[-1, :] < 0 # voxels fall inside zero-max arc
81 |
82 | # weight zero-max arc to not coincide with pure class regions
83 | # zmax_weight = (gra[zmax_neg] / zmax_radius)**-1
84 | zmax_weight = (gra / zmax_radius)**-1
85 | # soft[-1, :][zmax_neg] = zmax_weight*np.abs(soft[-1, :][zmax_neg])
86 | soft[-1, :] = zmax_weight*np.abs(soft[-1, :])
87 | # save for debugging
88 | # out = Nifti1Image(soft[-1, :].reshape(nii.shape), affine=nii.affine)
89 | # save(out, basename + '_zmaxarc' + '.nii.gz')
90 |
91 | # arbitrary weighting (TODO: Can be turned into config file of some sort)
92 | # save these values for MPRAGE T1w/PDw
93 | # soft[0, :] = soft[0, :] * 0.66 # csf
94 | # # soft[2, :] = soft[2, :] * 1.25 # wm
95 | # soft[3, :] = soft[3, :] * 0.5 # zero-max arcs
96 |
97 | # hard tissue membership maps
98 | hard = np.argmin(soft, axis=0)
99 |
100 | # append masked out areas
101 | hard = hard + 1
102 | hard[~msk] = 0
103 |
104 | # save intermediate maps
105 | out = Nifti1Image(hard.reshape(nii.shape), affine=nii.affine)
106 | save(out, basename + '_arcweld' + '.nii.gz')
107 |
108 | # save segmentator polish mask (not GM and WM)
109 | labels = np.unique(hard)[[0, 1, -1]]
110 | polish = np.ones(hard.shape)
111 | polish[hard == labels[0]] = 0
112 | polish[hard == labels[1]] = 0
113 | polish[hard == labels[2]] = 0
114 | out = Nifti1Image(polish.reshape(nii.shape), affine=nii.affine)
115 | save(out, basename + '_arcweld_mask' + '.nii.gz')
116 | print('Finished.')
117 |
118 | # plt.show()
119 |
--------------------------------------------------------------------------------
/segmentator/gui_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Functions covering the user interaction with the GUI."""
3 |
4 | from __future__ import division, print_function
5 | import os
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import segmentator.config as cfg
9 | from segmentator.utils import map_2D_hist_to_ima
10 | from nibabel import save, Nifti1Image
11 | from scipy.ndimage.morphology import binary_erosion
12 |
13 |
14 | class responsiveObj:
15 | """Stuff to interact in the user interface."""
16 |
17 | def __init__(self, **kwargs):
18 | """Initialize variables used acros functions here."""
19 | if kwargs is not None:
20 | for key, value in kwargs.items():
21 | setattr(self, key, value)
22 | self.basename = self.nii.get_filename().split(os.extsep, 1)[0]
23 | self.press = None
24 | self.ctrlHeld = False
25 | self.labelNr = 0
26 | self.imaSlcMskSwitch, self.volHistHighlightSwitch = 0, 0
27 | self.TranspVal = 0.5
28 | self.nrExports = 0
29 | self.borderSwitch = 0
30 | self.imaSlc = self.orig[:, :, self.sliceNr] # selected slice
31 | self.cycleCount = 0
32 | self.cycRotHistory = [[0, 0], [0, 0], [0, 0]]
33 | self.highlights = [[], []] # to hold image to histogram circles
34 |
35 | def remapMsks(self, remap_slice=True):
36 | """Update volume histogram to image mapping.
37 |
38 | Parameters
39 | ----------
40 | remap_slice : bool
41 | Do histogram to image mapping. Used to map displayed slice mask.
42 |
43 | """
44 | if self.segmType == 'main':
45 | self.volHistMask = self.sectorObj.binaryMask()
46 | self.volHistMask = self.lassoArr(self.volHistMask, self.idxLasso)
47 | self.volHistMaskH.set_data(self.volHistMask)
48 | elif self.segmType == 'ncut':
49 | self.labelContours()
50 | self.volHistMaskH.set_data(self.volHistMask)
51 | self.volHistMaskH.set_extent((0, self.nrBins, self.nrBins, 0))
52 | # histogram to image mapping
53 | if remap_slice:
54 | temp_slice = self.invHistVolume[:, :, self.sliceNr]
55 | image_slice_shape = self.invHistVolume[:, :, self.sliceNr].shape
56 | if cfg.discard_zeros:
57 | zmask = temp_slice != 0
58 | image_slice_mask = map_2D_hist_to_ima(temp_slice[zmask],
59 | self.volHistMask)
60 | # reshape to image slice shape
61 | self.imaSlcMsk = np.zeros(image_slice_shape)
62 | self.imaSlcMsk[zmask] = image_slice_mask
63 | else:
64 | image_slice_mask = map_2D_hist_to_ima(temp_slice.flatten(),
65 | self.volHistMask)
66 | # reshape to image slice shape
67 | self.imaSlcMsk = image_slice_mask.reshape(image_slice_shape)
68 |
69 | # for optional border visualization
70 | if self.borderSwitch == 1:
71 | self.imaSlcMsk = self.calcImaMaskBrd()
72 |
73 | def updatePanels(self, update_slice=True, update_rotation=False,
74 | update_extent=False):
75 | """Update histogram and image panels."""
76 | if update_rotation:
77 | self.checkRotation()
78 | if update_extent:
79 | self.updateImaExtent()
80 | if update_slice:
81 | self.imaSlcH.set_data(self.imaSlc)
82 | self.imaSlcMskH.set_data(self.imaSlcMsk)
83 | self.figure.canvas.draw()
84 |
85 | def connect(self):
86 | """Make the object responsive."""
87 | self.cidpress = self.figure.canvas.mpl_connect(
88 | 'button_press_event', self.on_press)
89 | self.cidrelease = self.figure.canvas.mpl_connect(
90 | 'button_release_event', self.on_release)
91 | self.cidmotion = self.figure.canvas.mpl_connect(
92 | 'motion_notify_event', self.on_motion)
93 | self.cidkeypress = self.figure.canvas.mpl_connect(
94 | 'key_press_event', self.on_key_press)
95 | self.cidkeyrelease = self.figure.canvas.mpl_connect(
96 | 'key_release_event', self.on_key_release)
97 |
98 | def on_key_press(self, event):
99 | """Determine what happens when a keyboard button is pressed."""
100 | if event.key == 'control':
101 | self.ctrlHeld = True
102 | elif event.key == '1':
103 | self.imaSlcMskIncr(-0.1)
104 | elif event.key == '2':
105 | self.imaSlcMskTransSwitch()
106 | elif event.key == '3':
107 | self.imaSlcMskIncr(0.1)
108 | elif event.key == '4':
109 | self.volHistHighlightTransSwitch()
110 | elif event.key == '5':
111 | self.borderSwitch = (self.borderSwitch + 1) % 2
112 | self.remapMsks()
113 | self.updatePanels(update_slice=False, update_rotation=True,
114 | update_extent=True)
115 |
116 | if self.segmType == 'main':
117 | if event.key == 'up':
118 | self.sectorObj.scale_r(1.05)
119 | self.remapMsks()
120 | self.updatePanels(update_slice=False, update_rotation=True,
121 | update_extent=False)
122 | elif event.key == 'down':
123 | self.sectorObj.scale_r(0.95)
124 | self.remapMsks()
125 | self.updatePanels(update_slice=False, update_rotation=True,
126 | update_extent=False)
127 | elif event.key == 'right':
128 | self.sectorObj.rotate(-10.0)
129 | self.remapMsks()
130 | self.updatePanels(update_slice=False, update_rotation=True,
131 | update_extent=False)
132 | elif event.key == 'left':
133 | self.sectorObj.rotate(10.0)
134 | self.remapMsks()
135 | self.updatePanels(update_slice=True, update_rotation=True,
136 | update_extent=False)
137 | else:
138 | return
139 |
140 | def on_key_release(self, event):
141 | """Determine what happens if key is released."""
142 | if event.key == 'control':
143 | self.ctrlHeld = False
144 |
145 | def findVoxInHist(self, event):
146 | """Find voxel's location in histogram."""
147 | self.press = event.xdata, event.ydata
148 | pixel_x = int(np.floor(event.xdata))
149 | pixel_y = int(np.floor(event.ydata))
150 | aoi = self.invHistVolume[:, :, self.sliceNr] # array of interest
151 | # Check rotation
152 | cyc_rot = self.cycRotHistory[self.cycleCount][1]
153 | if cyc_rot == 1: # 90
154 | aoi = np.rot90(aoi, axes=(0, 1))
155 | elif cyc_rot == 2: # 180
156 | aoi = aoi[::-1, ::-1]
157 | elif cyc_rot == 3: # 270
158 | aoi = np.rot90(aoi, axes=(1, 0))
159 | # Switch x and y voxel to get linear index since not Cartesian!!!
160 | pixelLin = aoi[pixel_y, pixel_x]
161 | # ind2sub
162 | xpix = (pixelLin / self.nrBins)
163 | ypix = (pixelLin % self.nrBins)
164 | # Switch x and y for circle centre since back to Cartesian.
165 | circle_colors = [np.array([8, 48, 107, 255])/255,
166 | np.array([33, 113, 181, 255])/255]
167 | self.highlights[0].append(plt.Circle((ypix, xpix), radius=1,
168 | edgecolor=None, color=circle_colors[0]))
169 | self.highlights[1].append(plt.Circle((ypix, xpix), radius=5,
170 | edgecolor=None, color=circle_colors[1]))
171 | self.axes.add_artist(self.highlights[0][-1]) # small circle
172 | self.axes.add_artist(self.highlights[1][-1]) # large circle
173 | self.figure.canvas.draw()
174 |
175 | def on_press(self, event):
176 | """Determine what happens if mouse button is clicked."""
177 | if self.segmType == 'main':
178 | if event.button == 1: # left button
179 | if event.inaxes == self.axes: # cursor in left plot (hist)
180 | if self.ctrlHeld is False: # ctrl no
181 |                     contains = self.contains(event)
182 |                     # return early if the cursor is outside the circle mask
183 |                     if not contains:
184 |                         print('cursor outside circle mask')
185 |                         return
186 | # get sector centre x and y positions
187 | x0 = self.sectorObj.cx
188 | y0 = self.sectorObj.cy
189 |                     # also get cursor x and y position and save to press
190 | self.press = x0, y0, event.xdata, event.ydata
191 | elif event.inaxes == self.axes2: # cursor in right plot (brow)
192 | self.findVoxInHist(event)
193 | else:
194 | return
195 | elif event.button == 2: # scroll button
196 | if event.inaxes != self.axes: # outside axes
197 | return
198 | # increase/decrease radius of the sector mask
199 | if self.ctrlHeld is False: # ctrl no
200 | self.sectorObj.scale_r(1.05)
201 | self.remapMsks()
202 | self.updatePanels(update_slice=False, update_rotation=True,
203 | update_extent=False)
204 | elif self.ctrlHeld is True: # ctrl yes
205 | self.sectorObj.rotate(10.0)
206 | self.remapMsks()
207 | self.updatePanels(update_slice=False, update_rotation=True,
208 | update_extent=False)
209 | else:
210 | return
211 | elif event.button == 3: # right button
212 | if event.inaxes != self.axes:
213 | return
214 | # rotate the sector mask
215 | if self.ctrlHeld is False: # ctrl no
216 | self.sectorObj.scale_r(0.95)
217 | self.remapMsks()
218 | self.updatePanels(update_slice=False, update_rotation=True,
219 | update_extent=False)
220 | elif self.ctrlHeld is True: # ctrl yes
221 | self.sectorObj.rotate(-10.0)
222 | self.remapMsks()
223 | self.updatePanels(update_slice=False, update_rotation=True,
224 | update_extent=False)
225 | else:
226 | return
227 | elif self.segmType == 'ncut':
228 | if event.button == 1: # left mouse button
229 | if event.inaxes == self.axes: # cursor in left plot (hist)
230 | xbin = int(np.floor(event.xdata))
231 | ybin = int(np.floor(event.ydata))
232 | val = self.volHistMask[ybin][xbin]
233 |                     # increment counterField for the clicked subfield; at the
234 |                     # first click the entire field constitutes the subfield
235 | counter = int(self.counterField[ybin][xbin])
236 | if counter+1 >= self.ima_ncut_labels.shape[2]:
237 | print("already at maximum ncut dimension")
238 | return
239 | self.counterField[(
240 | self.ima_ncut_labels[:, :, counter] ==
241 | self.ima_ncut_labels[[ybin], [xbin], counter])] += 1
242 | print("counter:" + str(counter+1))
243 | # define arrays with old and new labels for later indexing
244 | oLabels = self.ima_ncut_labels[:, :, counter]
245 | nLabels = self.ima_ncut_labels[:, :, counter+1]
246 | # replace old values with new values (in clicked subfield)
247 | self.volHistMask[oLabels == val] = np.copy(
248 | nLabels[oLabels == val])
249 | self.remapMsks()
250 | self.updatePanels(update_slice=False, update_rotation=True,
251 | update_extent=False)
252 |
253 | elif event.inaxes == self.axes2: # cursor in right plot (brow)
254 | self.findVoxInHist(event)
255 | else:
256 | return
257 | elif event.button == 3: # right mouse button
258 | if event.inaxes == self.axes: # cursor in left plot (hist)
259 | xbin = int(np.floor(event.xdata))
260 | ybin = int(np.floor(event.ydata))
261 | val = self.volHistMask[ybin][xbin]
262 | # fetch the slider value to get label nr
263 | self.volHistMask[self.volHistMask == val] = \
264 | np.copy(self.labelNr)
265 | self.remapMsks()
266 | self.updatePanels(update_slice=False, update_rotation=True,
267 | update_extent=False)
268 |
269 | def on_motion(self, event):
270 | """Determine what happens if mouse button moves."""
271 | if self.segmType == 'main':
272 | # ... button is pressed
273 | if self.press is None:
274 | return
275 | # ... cursor is in left plot
276 | if event.inaxes != self.axes:
277 | return
278 | # get former sector centre x and y positions,
279 | # cursor x and y positions
280 | x0, y0, xpress, ypress = self.press
281 |             # calculate the difference between the cursor position on click
282 |             # and the new position during motion
283 |             # (swap x and y because volHistMask is not Cartesian)
284 | dy = event.xdata - xpress
285 | dx = event.ydata - ypress
286 | # update x and y position of sector,
287 | # based on past motion of cursor
288 | self.sectorObj.set_x(x0 + dx)
289 | self.sectorObj.set_y(y0 + dy)
290 | # update masks
291 | self.remapMsks()
292 | self.updatePanels(update_slice=False, update_rotation=True,
293 | update_extent=False)
294 | else:
295 | return
296 |
297 | def on_release(self, event):
298 | """Determine what happens if mouse button is released."""
299 | self.press = None
300 | # Remove highlight circle
301 | if self.highlights[1]:
302 | self.highlights[1][-1].set_visible(False)
303 | self.figure.canvas.draw()
304 |
305 | def disconnect(self):
306 | """Make the object unresponsive."""
307 | self.figure.canvas.mpl_disconnect(self.cidpress)
308 | self.figure.canvas.mpl_disconnect(self.cidrelease)
309 | self.figure.canvas.mpl_disconnect(self.cidmotion)
310 |
311 | def updateColorBar(self, val):
312 | """Update slider for scaling log colorbar in 2D hist."""
313 | histVMax = np.power(10, self.sHistC.val)
314 | plt.clim(vmax=histVMax)
315 |
316 | def updateSliceNr(self):
317 | """Update slice number and the selected slice."""
318 | self.sliceNr = int(self.sSliceNr.val*self.orig.shape[2])
319 | self.imaSlc = self.orig[:, :, self.sliceNr]
320 |
321 | def updateImaBrowser(self, val):
322 |         """Update image browser."""
323 | # scale slider value [0,1) to dimension index
324 | self.updateSliceNr()
325 | self.remapMsks()
326 | self.updatePanels(update_slice=True, update_rotation=True,
327 | update_extent=True)
328 |
329 | def updateImaExtent(self):
330 | """Update both image and mask extent in image browser."""
331 | self.imaSlcH.set_extent((0, self.imaSlc.shape[1],
332 | self.imaSlc.shape[0], 0))
333 | self.imaSlcMskH.set_extent((0, self.imaSlc.shape[1],
334 | self.imaSlc.shape[0], 0))
335 |
336 | def cycleView(self, event):
337 | """Cycle through views."""
338 | self.cycleCount = (self.cycleCount + 1) % 3
339 | # transpose data
340 | self.orig = np.transpose(self.orig, (2, 0, 1))
341 | # transpose ima2volHistMap
342 | self.invHistVolume = np.transpose(self.invHistVolume, (2, 0, 1))
343 | # updates
344 | self.updateSliceNr()
345 | self.remapMsks()
346 | self.updatePanels(update_slice=True, update_rotation=True,
347 | update_extent=True)
348 |
349 | def rotateIma90(self, axes=(0, 1)):
350 | """Rotate image slice 90 degrees."""
351 | self.imaSlc = np.rot90(self.imaSlc, axes=axes)
352 | self.imaSlcMsk = np.rot90(self.imaSlcMsk, axes=axes)
353 |
354 | def changeRotation(self, event):
355 | """Change rotation of image after clicking the button."""
356 | self.cycRotHistory[self.cycleCount][1] += 1
357 | self.cycRotHistory[self.cycleCount][1] %= 4
358 | self.rotateIma90()
359 | self.updatePanels(update_slice=True, update_rotation=False,
360 | update_extent=True)
361 |
362 | def checkRotation(self):
363 |         """Check rotation and update if changed."""
364 | cyc_rot = self.cycRotHistory[self.cycleCount][1]
365 | if cyc_rot == 1: # 90
366 | self.rotateIma90(axes=(0, 1))
367 | elif cyc_rot == 2: # 180
368 | self.imaSlc = self.imaSlc[::-1, ::-1]
369 | self.imaSlcMsk = self.imaSlcMsk[::-1, ::-1]
370 | elif cyc_rot == 3: # 270
371 | self.rotateIma90(axes=(1, 0))
372 |
373 | def exportNifti(self, event):
374 | """Export labels in the image browser as a nifti file."""
375 | print(" Exporting nifti file...")
376 | # put the permuted indices back to their original format
377 | cycBackPerm = (self.cycleCount, (self.cycleCount+1) % 3,
378 | (self.cycleCount+2) % 3)
379 |         # assign unique integers (for ncut labels)
380 | out_volHistMask = np.copy(self.volHistMask)
381 | labels = np.unique(self.volHistMask)
382 | intLabels = [i for i in range(labels.size)]
383 | for label, newLabel in zip(labels, intLabels):
384 | out_volHistMask[out_volHistMask == label] = intLabels[newLabel]
385 | # get 3D brain mask
386 | volume_image = np.transpose(self.invHistVolume, cycBackPerm)
387 | if cfg.discard_zeros:
388 | zmask = volume_image != 0
389 | temp_labeled_image = map_2D_hist_to_ima(volume_image[zmask],
390 | out_volHistMask)
391 | out_nii = np.zeros(volume_image.shape)
392 | out_nii[zmask] = temp_labeled_image # put back flat labels
393 | else:
394 | out_nii = map_2D_hist_to_ima(volume_image.flatten(),
395 | out_volHistMask)
396 | out_nii = out_nii.reshape(volume_image.shape)
397 | # save mask image as nii
398 | new_image = Nifti1Image(out_nii, header=self.nii.header,
399 | affine=self.nii.affine)
400 | # get new flex file name and check for overwriting
401 | labels_out = '{}_labels_{}.nii.gz'.format(
402 | self.basename, self.nrExports)
403 | while os.path.isfile(labels_out):
404 | self.nrExports += 1
405 | labels_out = '{}_labels_{}.nii.gz'.format(
406 | self.basename, self.nrExports)
407 | save(new_image, labels_out)
408 | print(" Saved as: {}".format(labels_out))
409 |
410 | def clearOverlays(self):
411 | """Clear overlaid items such as circle highlights."""
412 |         if self.highlights[0]:
413 |             for h in self.highlights[0] + self.highlights[1]:
414 |                 h.remove()
415 |         self.highlights[0], self.highlights[1] = [], []
416 |
417 | def resetGlobal(self, event):
418 |         """Reset GUI elements to their initial state."""
419 | # reset highlights
420 | self.clearOverlays()
421 | # reset color bar
422 | self.sHistC.reset()
423 | # reset transparency
424 | self.TranspVal = 0.5
425 | if self.segmType == 'main':
426 | if self.lassoSwitchCount == 1: # reset only lasso drawing
427 | self.idxLasso = np.zeros(self.nrBins*self.nrBins, dtype=bool)
428 | else:
429 | # reset theta sliders
430 | self.sThetaMin.reset()
431 | self.sThetaMax.reset()
432 | # reset values for mask
433 | self.sectorObj.set_x(cfg.init_centre[0])
434 | self.sectorObj.set_y(cfg.init_centre[1])
435 | self.sectorObj.set_r(cfg.init_radius)
436 | self.sectorObj.tmin, self.sectorObj.tmax = np.deg2rad(
437 | cfg.init_theta)
438 |
439 | elif self.segmType == 'ncut':
440 | self.sLabelNr.reset()
441 | # reset ncut labels
442 | self.ima_ncut_labels = np.copy(self.orig_ncut_labels)
443 | # reset values for volHistMask
444 | self.volHistMask = self.ima_ncut_labels[:, :, 0].reshape(
445 | (self.nrBins, self.nrBins))
446 | # reset counter field
447 | self.counterField = np.zeros((self.nrBins, self.nrBins))
448 | # reset political borders
449 | self.pltMap = np.zeros((self.nrBins, self.nrBins))
450 | self.pltMapH.set_data(self.pltMap)
451 | self.updateSliceNr()
452 | self.remapMsks()
453 | self.updatePanels(update_slice=False, update_rotation=True,
454 | update_extent=False)
455 |
456 | def updateThetaMin(self, val):
457 | """Update theta (min) in volume histogram mask."""
458 | if self.segmType == 'main':
459 | theta_val = self.sThetaMin.val # get theta value from slider
460 | self.sectorObj.theta_min(theta_val)
461 | self.remapMsks()
462 | self.updatePanels(update_slice=False, update_rotation=True,
463 | update_extent=False)
464 | else:
465 | return
466 |
467 | def updateThetaMax(self, val):
468 | """Update theta(max) in volume histogram mask."""
469 | if self.segmType == 'main':
470 | theta_val = self.sThetaMax.val # get theta value from slider
471 | self.sectorObj.theta_max(theta_val)
472 | self.remapMsks()
473 | self.updatePanels(update_slice=False, update_rotation=True,
474 | update_extent=False)
475 | else:
476 | return
477 |
478 | def exportNyp(self, event):
479 | """Export histogram counts as a numpy array."""
480 | print(" Exporting numpy file...")
481 | outFileName = '{}_identifier_pcMax{}_pcMin{}_sc{}'.format(
482 | self.basename, cfg.perc_max, cfg.perc_min, int(cfg.scale))
483 | if self.segmType == 'ncut':
484 | outFileName = outFileName.replace('identifier', 'volHistLabels')
485 | out_data = self.volHistMask
486 | elif self.segmType == 'main':
487 | outFileName = outFileName.replace('identifier', 'volHist')
488 | out_data = self.counts
489 | outFileName = outFileName.replace('.', 'pt')
490 | np.save(outFileName, out_data)
491 | print(" Saved as: {}{}".format(outFileName, '.npy'))
492 |
493 | def updateLabels(self, val):
494 | """Update labels in volume histogram with slider."""
495 | if self.segmType == 'ncut':
496 | self.labelNr = self.sLabelNr.val
497 | else: # NOTE: might be used in the future
498 | return
499 |
500 | def imaSlcMskIncr(self, incr):
501 | """Update transparency of image mask by increment."""
502 | if (self.TranspVal + incr >= 0) & (self.TranspVal + incr <= 1):
503 | self.TranspVal += incr
504 | self.imaSlcMskH.set_alpha(self.TranspVal)
505 | self.figure.canvas.draw()
506 |
507 | def imaSlcMskTransSwitch(self):
508 |         """Toggle transparency of the image mask."""
509 | self.imaSlcMskSwitch = (self.imaSlcMskSwitch+1) % 2
510 | if self.imaSlcMskSwitch == 1: # set imaSlcMsk transp
511 | self.imaSlcMskH.set_alpha(0)
512 | else: # set imaSlcMsk opaque
513 | self.imaSlcMskH.set_alpha(self.TranspVal)
514 | self.figure.canvas.draw()
515 |
516 | def volHistHighlightTransSwitch(self):
517 |         """Toggle visibility of the volume histogram highlights."""
518 |         self.volHistHighlightSwitch = (self.volHistHighlightSwitch+1) % 2
519 |         # 0 shows the highlights, 1 hides them
520 |         show = (self.volHistHighlightSwitch == 0)
521 |         if self.highlights[0]:
522 |             for h in self.highlights[0]:
523 |                 h.set_visible(show)
524 | self.figure.canvas.draw()
525 |
526 | def updateLabelsRadio(self, val):
527 | """Update labels with radio buttons."""
528 | labelScale = self.lMax / 6. # nr of non-zero radio buttons
529 | self.labelNr = int(float(val) * labelScale)
530 |
531 | def labelContours(self):
532 | """Plot political borders used in ncut version."""
533 | grad = np.gradient(self.volHistMask)
534 | self.pltMap = np.greater(np.sqrt(np.power(grad[0], 2) +
535 | np.power(grad[1], 2)), 0)
536 | self.pltMapH.set_data(self.pltMap)
537 | self.pltMapH.set_extent((0, self.nrBins, self.nrBins, 0))
538 |
539 | def lassoArr(self, array, indices):
540 | """Update lasso volume histogram mask."""
541 | lin = np.arange(array.size)
542 | newArray = array.flatten()
543 | newArray[lin[indices]] = True
544 | return newArray.reshape(array.shape)
545 |
546 | def calcImaMaskBrd(self):
547 | """Calculate borders of image mask slice."""
548 | return self.imaSlcMsk - binary_erosion(self.imaSlcMsk)
549 |
550 |
551 | class sector_mask:
552 | """A pacman-like shape with useful parameters.
553 |
554 | Disclaimer
555 | ----------
556 | This script is adapted from a stackoverflow post by user ali_m:
557 | [1] http://stackoverflow.com/questions/18352973/mask-a-circular-sector-in-a-numpy-array
558 |
559 | """
560 |
561 | def __init__(self, shape, centre, radius, angle_range):
562 |         """Initialize variables used across functions here."""
563 | self.radius, self.shape = radius, shape
564 | self.x, self.y = np.ogrid[:shape[0], :shape[1]]
565 | self.cx, self.cy = centre
566 | self.tmin, self.tmax = np.deg2rad(angle_range)
567 | # ensure stop angle > start angle
568 | if self.tmax < self.tmin:
569 | self.tmax += 2 * np.pi
570 | # convert cartesian to polar coordinates
571 | self.r2 = (self.x - self.cx) * (self.x - self.cx) + (
572 | self.y-self.cy) * (self.y - self.cy)
573 | self.theta = np.arctan2(self.x - self.cx, self.y - self.cy) - self.tmin
574 | # wrap angles between 0 and 2*pi
575 | self.theta %= 2 * np.pi
576 |
577 | def set_polCrd(self):
578 | """Convert cartesian to polar coordinates."""
579 | self.r2 = (self.x-self.cx)*(self.x-self.cx) + (
580 | self.y-self.cy)*(self.y-self.cy)
581 | self.theta = np.arctan2(self.x-self.cx, self.y-self.cy) - self.tmin
582 | # wrap angles between 0 and 2*pi
583 | self.theta %= (2*np.pi)
584 |
585 | def set_x(self, x):
586 |         """Set x coordinate of the sector centre."""
587 | self.cx = x
588 | self.set_polCrd() # update polar coordinates
589 |
590 | def set_y(self, y):
591 |         """Set y coordinate of the sector centre."""
592 | self.cy = y
593 | self.set_polCrd() # update polar coordinates
594 |
595 | def set_r(self, radius):
596 | """Set radius of the circle."""
597 | self.radius = radius
598 |
599 | def scale_r(self, scale):
600 | """Scale (multiply) the radius."""
601 | self.radius = self.radius * scale
602 |
603 | def rotate(self, degree):
604 | """Rotate shape."""
605 | rad = np.deg2rad(degree)
606 | self.tmin += rad
607 | self.tmax += rad
608 | self.set_polCrd() # update polar coordinates
609 |
610 | def theta_min(self, degree):
611 |         """Set the start angle of the cut-out piece in the circular mask."""
612 | rad = np.deg2rad(degree)
613 | self.tmin = rad
614 | # ensure stop angle > start angle
615 | if self.tmax <= self.tmin:
616 | self.tmax += 2*np.pi
617 |         # ensure stop angle minus 2*np.pi is not larger than start angle
618 | if self.tmax - 2*np.pi >= self.tmin:
619 | self.tmax -= 2*np.pi
620 | # update polar coordinates
621 | self.set_polCrd()
622 |
623 | def theta_max(self, degree):
624 |         """Set the stop angle of the cut-out piece in the circular mask."""
625 | rad = np.deg2rad(degree)
626 | self.tmax = rad
627 | # ensure stop angle > start angle
628 | if self.tmax <= self.tmin:
629 | self.tmax += 2*np.pi
630 |         # ensure stop angle minus 2*np.pi is not larger than start angle
631 | if self.tmax - 2*np.pi >= self.tmin:
632 | self.tmax -= 2*np.pi
633 | # update polar coordinates
634 | self.set_polCrd()
635 |
636 | def binaryMask(self):
637 | """Return a boolean mask for a circular sector."""
638 | # circular mask
639 | self.circmask = self.r2 <= self.radius*self.radius
640 | # angular mask
641 | self.anglemask = self.theta <= (self.tmax-self.tmin)
642 | # return binary mask
643 | return self.circmask*self.anglemask
644 |
645 | def contains(self, event):
646 | """Check if a cursor pointer is inside the sector mask."""
647 |         xbin = int(np.floor(event.xdata))
648 |         ybin = int(np.floor(event.ydata))
649 |         Mask = self.binaryMask()
650 |         # switch xbin and ybin because volHistMask is not Cartesian
651 |         if Mask[ybin, xbin]:
652 |             return True
653 |         else:
654 |             return False
655 |
656 | def draw(self, ax, cmap='Reds', alpha=0.2, vmin=0.1, zorder=0,
657 | interpolation='nearest', origin='lower', extent=[0, 100, 0, 100]):
658 | """Draw sector mask."""
659 | BinMask = self.binaryMask()
660 | FigObj = ax.imshow(BinMask, cmap=cmap, alpha=alpha, vmin=vmin,
661 | interpolation=interpolation, origin=origin,
662 | extent=extent, zorder=zorder)
663 | return (FigObj, BinMask)
664 |
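# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): a minimal standalone use
# of sector_mask, assuming a 100x100 histogram grid. All parameter values
# below are arbitrary examples chosen for illustration.
if __name__ == "__main__":
    sector = sector_mask(shape=(100, 100), centre=(50, 50), radius=30,
                         angle_range=(0, 90))
    mask = sector.binaryMask()  # True inside the circular sector
    print(mask.shape, int(mask.sum()))
    sector.rotate(45.0)  # rotate the cut-out piece by 45 degrees
    sector.scale_r(1.1)  # grow the radius by 10 percent
    mask = sector.binaryMask()  # recompute the mask after the updates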
--------------------------------------------------------------------------------
/segmentator/hist2d_counts.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Save 2D histogram image without displaying GUI."""
3 |
4 | from __future__ import print_function
5 | import os
6 | import numpy as np
7 | import segmentator.config as cfg
8 | from segmentator.utils import truncate_range, scale_range, check_data
9 | from segmentator.utils import set_gradient_magnitude, prep_2D_hist
10 | from nibabel import load
11 |
12 | # load data
13 | nii = load(cfg.filename)
14 | basename = nii.get_filename().split(os.extsep, 1)[0]
15 |
16 | # data processing
17 | orig, _ = check_data(nii.get_fdata(), cfg.force_original_precision)
18 | orig, _, _ = truncate_range(orig, percMin=cfg.perc_min, percMax=cfg.perc_max)
19 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001)
20 | gra = set_gradient_magnitude(orig, cfg.gramag)
21 |
22 | # flatten ima (a bit more intuitive for voxel-wise operations)
23 | ima = np.ndarray.flatten(orig)
24 | gra = np.ndarray.flatten(gra)
25 |
26 | counts, _, _, _, _, _ = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros)
27 | outName = '{}_volHist_pcMax{}_pcMin{}_sc{}'.format(
28 | basename, cfg.perc_max, cfg.perc_min, int(cfg.scale))
29 | outName = outName.replace('.', 'pt')
30 | np.save(outName, counts)
31 | print(' Image saved as:\n {}'.format(outName))
32 |
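# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original script): optionally also write
# the histogram as a PNG with log scaling similar to the GUI display, reusing
# the outName computed above. Assumes matplotlib is available.
import matplotlib.pyplot as plt
plt.imsave('{}.png'.format(outName), np.log10(counts.T + 1.), origin='lower',
           cmap='Greys')
print(' Preview saved as:\n  {}.png'.format(outName))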
--------------------------------------------------------------------------------
/segmentator/ncut_prepare.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Normalized graph cuts for segmentator (experimental).
3 |
4 | TODO: Replace the functionality using scikit-learn?
5 | """
6 |
7 | import os
8 | import numpy as np
9 | from matplotlib import animation
10 | from matplotlib import pyplot as plt
11 | from skimage.future import graph
12 | from skimage.segmentation import slic
13 | import segmentator.config as cfg
14 |
15 |
16 | def norm_grap_cut(image, max_edge=10000000, max_rec=4, compactness=2,
17 | nrSupPix=2000):
18 | """Normalized graph cut wrapper for 2D numpy arrays.
19 |
20 | Parameters
21 | ----------
22 | image: np.ndarray (2D)
23 | Volume histogram.
24 | max_edge: float
25 | The maximum possible value of an edge in the RAG. This corresponds
26 | to an edge between identical regions. This is used to put self
27 | edges in the RAG.
28 | compactness: float
29 | From skimage slic_superpixels.py slic function:
30 | Balances color proximity and space proximity. Higher values give
31 | more weight to space proximity, making superpixel shapes more
32 | square/cubic. This parameter depends strongly on image contrast and
33 | on the shapes of objects in the image.
34 | nrSupPix: int, positive
35 | The (approximate) number of superpixels in the region adjacency
36 | graph.
37 |
38 | Returns
39 | -------
40 | labels2, labels1: np.ndarray (2D)
41 | Segmented volume histogram mask image. Each label has a unique
42 | identifier.
43 |
44 | """
45 | # scale for uint8 conversion
46 | image = np.round(255 / image.max() * image)
47 | image = image.astype('uint8')
48 |
49 | # scikit implementation expects rgb format (shape: NxMx3)
50 | image = np.tile(image, (3, 1, 1))
51 | image = np.transpose(image, (1, 2, 0))
52 |
53 | labels1 = slic(image, compactness=compactness, n_segments=nrSupPix,
54 | sigma=2)
55 | # region adjacency graph (rag)
56 |     g = graph.rag_mean_color(image, labels1, mode='similarity_and_proximity')
57 | labels2 = graph.cut_normalized(labels1, g, max_edge=max_edge,
58 | num_cuts=1000, max_rec=max_rec)
59 | return labels2, labels1
60 |
61 |
62 | path = cfg.filename
63 | basename = path.split(os.extsep, 1)[0]
64 |
65 | # load data
66 | img = np.load(path)
67 | # take logarithm of every count to make it similar to what is seen in gui
68 | img = np.log10(img+1.)
69 |
70 | # truncate very high values
71 | img_max = cfg.cbar_init
72 | img[img > img_max] = img_max
73 |
74 | max_recursion = cfg.max_rec
75 | ncut = np.zeros((img.shape[0], img.shape[1], max_recursion + 1))
76 | for i in range(0, max_recursion + 1):
77 | msk, regions = norm_grap_cut(img, max_rec=i,
78 | nrSupPix=cfg.nr_sup_pix,
79 | compactness=cfg.compactness)
80 | ncut[:, :, i] = msk
81 |
82 | # plots
83 | if cfg.ncut_figs:
84 | fig = plt.figure()
85 | ax1 = fig.add_subplot(121)
86 | ax2 = fig.add_subplot(122)
87 |
88 | # ax1.imshow(img.T, origin="lower", cmap=plt.cm.inferno)
89 | ax1.imshow(regions.T, origin="lower", cmap=plt.cm.inferno)
90 | ax2.imshow(msk.T, origin="lower", cmap=plt.cm.nipy_spectral)
91 |
92 | ax1.set_title('Source')
93 | ax2.set_title('Ncut')
94 |
95 | plt.show()
96 |
97 | fig = plt.figure()
98 | unq = np.unique(msk)
99 | idx = -1
100 |
101 | im = plt.imshow(msk.T, origin="lower", cmap=plt.cm.flag,
102 | animated=True)
103 |
104 | def updatefig(*args):
105 | """Animate the plot."""
106 | global unq, msk, idx, tmp
107 | idx += 1
108 | idx = idx % ncut.shape[2]
109 | tmp = np.copy(ncut[:, :, idx])
110 | im.set_array(tmp.T)
111 | return im,
112 |
113 | ani = animation.FuncAnimation(fig, updatefig, interval=750, blit=True)
114 | plt.show()
115 |
116 | # save output
117 | outName = '{}_ncut_sp{}_c{}'.format(basename, cfg.nr_sup_pix, cfg.compactness)
118 | outName = outName.replace('.', 'pt')
119 | np.save(outName, ncut)
120 | print(" Saved as: {}{}".format(outName, '.npy'))
121 |
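# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original script): summarize how many
# labels each recursion level of the saved array contains. Deeper levels
# correspond to finer parcellations of the 2D histogram.
for level in range(ncut.shape[2]):
    nr_labels = np.unique(ncut[:, :, level]).size
    print(' Recursion level {}: {} labels'.format(level, nr_labels))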
--------------------------------------------------------------------------------
/segmentator/segmentator_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Processing input and plotting."""
3 |
4 | from __future__ import division, print_function
5 | import numpy as np
6 | import segmentator.config as cfg
7 | import matplotlib
8 | matplotlib.use(cfg.matplotlib_backend)
9 | print("Matplotlib backend: {}".format(matplotlib.rcParams['backend']))
10 | import matplotlib.pyplot as plt
11 | from matplotlib.colors import LogNorm
12 | from matplotlib.widgets import Slider, Button, LassoSelector
13 | from matplotlib import path
14 | from nibabel import load
15 | from segmentator.utils import map_ima_to_2D_hist, prep_2D_hist
16 | from segmentator.utils import truncate_range, scale_range, check_data
17 | from segmentator.utils import set_gradient_magnitude
18 | from segmentator.utils import export_gradient_magnitude_image
19 | from segmentator.gui_utils import sector_mask, responsiveObj
20 | from segmentator.config_gui import palette, axcolor, hovcolor
21 |
22 | #
23 | """Data Processing"""
24 | nii = load(cfg.filename)
25 | orig, dims = check_data(nii.get_fdata(), cfg.force_original_precision)
26 | # Save min and max truncation thresholds to be used in axis labels
27 | if np.isnan(cfg.valmin) or np.isnan(cfg.valmax):
28 | orig, pMin, pMax = truncate_range(orig, percMin=cfg.perc_min,
29 | percMax=cfg.perc_max)
30 | else: # TODO: integrate this into truncate range function
31 | orig[orig < cfg.valmin] = cfg.valmin
32 | orig[orig > cfg.valmax] = cfg.valmax
33 | pMin, pMax = cfg.valmin, cfg.valmax
34 |
35 | # Continue with scaling the original truncated image and recomputing gradient
36 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001)
37 | gra = set_gradient_magnitude(orig, cfg.gramag)
38 | if cfg.export_gramag:
39 | export_gradient_magnitude_image(gra, nii.get_filename(), cfg.gramag,
40 | nii.affine)
41 | # Reshape for voxel-wise operations
42 | ima = np.copy(orig.flatten())
43 | gra = gra.flatten()
44 |
45 | #
46 | """Plots"""
47 | print("Preparing GUI...")
48 | # Plot 2D histogram
49 | fig = plt.figure(facecolor='0.775')
50 | ax = fig.add_subplot(121)
51 |
52 | counts, volHistH, d_min, d_max, nr_bins, bin_edges \
53 | = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros)
54 |
55 | # Set x-y axis range to the same (x-axis range)
56 | ax.set_xlim(d_min, d_max)
57 | ax.set_ylim(d_min, d_max)
58 | ax.set_xlabel("Intensity f(x)")
59 | ax.set_ylabel("Gradient Magnitude f'(x)")
60 | ax.set_title("2D Histogram")
61 |
62 | # Plot colorbar for 2D hist
63 | volHistH.set_norm(LogNorm(vmax=np.power(10, cfg.cbar_init)))
64 | fig.colorbar(volHistH, fraction=0.046, pad=0.04) # magical scaling
65 |
66 | # Plot 3D ima by default
67 | ax2 = fig.add_subplot(122)
68 | sliceNr = int(0.5*dims[2])
69 | imaSlcH = ax2.imshow(orig[:, :, sliceNr], cmap=plt.cm.gray, vmin=ima.min(),
70 | vmax=ima.max(), interpolation='none',
71 | extent=[0, dims[1], dims[0], 0], zorder=0)
72 |
73 | imaSlcMsk = np.ones(dims[0:2])
74 | imaSlcMskH = ax2.imshow(imaSlcMsk, cmap=palette, vmin=0.1,
75 | interpolation='none', alpha=0.5,
76 | extent=[0, dims[1], dims[0], 0], zorder=1)
77 |
78 | # Adjust subplots on figure
79 | bottom = 0.30
80 | fig.subplots_adjust(bottom=bottom)
81 | fig.canvas.manager.set_window_title(nii.get_filename())
82 | plt.axis('off')
83 |
84 | #
85 | """Initialisation"""
86 | # Create first instance of sector mask
87 | sectorObj = sector_mask((nr_bins, nr_bins), cfg.init_centre, cfg.init_radius,
88 | cfg.init_theta)
89 |
90 | # Draw sector mask for the first time
91 | volHistMaskH, volHistMask = sectorObj.draw(ax, cmap=palette, alpha=0.2,
92 | vmin=0.1, interpolation='nearest',
93 | origin='lower', zorder=1,
94 | extent=[0, nr_bins, 0, nr_bins])
95 |
96 | # Initiate a flexible figure object, pass to it useful properties
97 | idxLasso = np.zeros(nr_bins*nr_bins, dtype=bool)
98 | lassoSwitchCount = 0
99 | lassoErase = 1 # 1 for drawing, 0 for erasing
100 | flexFig = responsiveObj(figure=ax.figure, axes=ax.axes, axes2=ax2.axes,
101 | segmType='main', orig=orig, nii=nii,
102 | sectorObj=sectorObj,
103 | nrBins=nr_bins,
104 | sliceNr=sliceNr,
105 | imaSlcH=imaSlcH,
106 | imaSlcMsk=imaSlcMsk, imaSlcMskH=imaSlcMskH,
107 | volHistMask=volHistMask, volHistMaskH=volHistMaskH,
108 | contains=volHistMaskH.contains,
109 | counts=counts,
110 | idxLasso=idxLasso,
111 | lassoSwitchCount=lassoSwitchCount,
112 | lassoErase=lassoErase)
113 |
114 | # Make the figure responsive to clicks
115 | flexFig.connect()
116 | ima2volHistMap = map_ima_to_2D_hist(xinput=ima, yinput=gra, bins_arr=bin_edges)
117 | flexFig.invHistVolume = np.reshape(ima2volHistMap, dims)
118 | ima, gra = None, None
119 |
120 | #
121 | """Sliders and Buttons"""
122 | # Colorbar slider
123 | axHistC = plt.axes([0.15, bottom-0.20, 0.25, 0.025], facecolor=axcolor)
124 | flexFig.sHistC = Slider(axHistC, 'Colorbar', 1, cfg.cbar_max,
125 | valinit=cfg.cbar_init, valfmt='%0.1f')
126 |
127 | # Image browser slider
128 | axSliceNr = plt.axes([0.6, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
129 | flexFig.sSliceNr = Slider(axSliceNr, 'Slice', 0, 0.999, valinit=0.5,
130 | valfmt='%0.2f')
131 |
132 | # Theta sliders
133 | aThetaMin = plt.axes([0.15, bottom-0.10, 0.25, 0.025], facecolor=axcolor)
134 | flexFig.sThetaMin = Slider(aThetaMin, 'ThetaMin', 0, 359.9,
135 | valinit=cfg.init_theta[0], valfmt='%0.1f')
136 | aThetaMax = plt.axes([0.15, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
137 | flexFig.sThetaMax = Slider(aThetaMax, 'ThetaMax', 0, 359.9,
138 | valinit=cfg.init_theta[1]-0.1, valfmt='%0.1f')
139 |
140 | # Cycle button
141 | cycleax = plt.axes([0.55, bottom-0.2475, 0.075, 0.0375])
142 | flexFig.bCycle = Button(cycleax, 'Cycle',
143 | color=axcolor, hovercolor=hovcolor)
144 |
145 | # Rotate button
146 | rotateax = plt.axes([0.55, bottom-0.285, 0.075, 0.0375])
147 | flexFig.bRotate = Button(rotateax, 'Rotate',
148 | color=axcolor, hovercolor=hovcolor)
149 |
150 | # Reset button
151 | resetax = plt.axes([0.65, bottom-0.285, 0.075, 0.075])
152 | flexFig.bReset = Button(resetax, 'Reset', color=axcolor, hovercolor=hovcolor)
153 |
154 | # Export nii button
155 | exportax = plt.axes([0.75, bottom-0.285, 0.075, 0.075])
156 | flexFig.bExport = Button(exportax, 'Export\nNifti',
157 | color=axcolor, hovercolor=hovcolor)
158 |
159 | # Export nyp button
160 | exportax = plt.axes([0.85, bottom-0.285, 0.075, 0.075])
161 | flexFig.bExportNyp = Button(exportax, 'Export\nHist',
162 | color=axcolor, hovercolor=hovcolor)
163 |
164 | #
165 | """Updates"""
166 | flexFig.sHistC.on_changed(flexFig.updateColorBar)
167 | flexFig.sSliceNr.on_changed(flexFig.updateImaBrowser)
168 | flexFig.sThetaMin.on_changed(flexFig.updateThetaMin)
169 | flexFig.sThetaMax.on_changed(flexFig.updateThetaMax)
170 | flexFig.bCycle.on_clicked(flexFig.cycleView)
171 | flexFig.bRotate.on_clicked(flexFig.changeRotation)
172 | flexFig.bExport.on_clicked(flexFig.exportNifti)
173 | flexFig.bExportNyp.on_clicked(flexFig.exportNyp)
174 | flexFig.bReset.on_clicked(flexFig.resetGlobal)
175 |
176 |
177 | # TODO: Temporary solution for displaying original x-y axis labels
178 | def update_axis_labels(event):
179 | """Swap histogram bin indices with original values."""
180 | xlabels = [item.get_text() for item in ax.get_xticklabels()]
181 | orig_range_labels = np.linspace(pMin, pMax, len(xlabels))
182 |
183 | # Adjust displayed decimals based on data range
184 | data_range = pMax - pMin
185 | if data_range > 200: # arbitrary value
186 | xlabels = [('%i' % i) for i in orig_range_labels]
187 | elif data_range > 20:
188 | xlabels = [('%.1f' % i) for i in orig_range_labels]
189 | elif data_range > 2:
190 | xlabels = [('%.2f' % i) for i in orig_range_labels]
191 | else:
192 | xlabels = [('%.3f' % i) for i in orig_range_labels]
193 |
194 | ax.set_xticklabels(xlabels)
195 | ax.set_yticklabels(xlabels) # limits of y axis assumed to be the same as x
196 |
197 |
198 | fig.canvas.mpl_connect('resize_event', update_axis_labels)
199 |
200 | #
201 | """Lasso selection"""
202 | # Lasso button
203 | lassoax = plt.axes([0.15, bottom-0.285, 0.075, 0.075])
204 | bLasso = Button(lassoax, 'Lasso\nOff', color=axcolor, hovercolor=hovcolor)
205 |
206 | # Lasso draw/erase
207 | lassoEraseAx = plt.axes([0.25, bottom-0.285, 0.075, 0.075])
208 | bLassoErase = Button(lassoEraseAx, 'Erase\nOff', color=axcolor,
209 | hovercolor=hovcolor)
210 | bLassoErase.ax.patch.set_visible(False)
211 | bLassoErase.label.set_visible(False)
212 | bLassoErase.ax.axis('off')
213 |
214 |
215 | def lassoSwitch(event):
216 |     """Enable or disable the lasso tool."""
217 | global lasso
218 | lasso = []
219 | flexFig.lassoSwitchCount = (flexFig.lassoSwitchCount+1) % 2
220 | if flexFig.lassoSwitchCount == 1: # enable lasso
221 | flexFig.disconnect() # disable drag function of sector mask
222 | lasso = LassoSelector(ax, onselect)
223 | bLasso.label.set_text("Lasso\nOn")
224 |         # Make erase button appear in lasso mode
225 | bLassoErase.ax.patch.set_visible(True)
226 | bLassoErase.label.set_visible(True)
227 | bLassoErase.ax.axis('on')
228 |
229 | else: # disable lasso
230 | flexFig.connect() # enable drag function of sector mask
231 | bLasso.label.set_text("Lasso\nOff")
232 | # Make erase button disappear
233 | bLassoErase.ax.patch.set_visible(False)
234 | bLassoErase.label.set_visible(False)
235 | bLassoErase.ax.axis('off')
236 |
237 | # Pixel coordinates
238 | pix = np.arange(nr_bins)
239 | xv, yv = np.meshgrid(pix, pix)
240 | pix = np.vstack((xv.flatten(), yv.flatten())).T
241 |
242 |
243 | def onselect(verts):
244 | """Lasso related."""
245 | global pix
246 | p = path.Path(verts)
247 | newLasIdx = p.contains_points(pix, radius=1.5) # New lasso indices
248 | flexFig.idxLasso[newLasIdx] = flexFig.lassoErase # Update lasso indices
249 | flexFig.remapMsks() # Update volume histogram mask
250 | flexFig.updatePanels(update_slice=False, update_rotation=True,
251 | update_extent=True)
252 |
253 |
254 | def lassoEraseSwitch(event):
255 |     """Enable or disable the lasso erase function."""
256 | flexFig.lassoErase = (flexFig.lassoErase + 1) % 2
257 | if flexFig.lassoErase == 1:
258 | bLassoErase.label.set_text("Erase\nOff")
259 | elif flexFig.lassoErase == 0:
260 | bLassoErase.label.set_text("Erase\nOn")
261 |
262 |
263 | bLasso.on_clicked(lassoSwitch) # lasso on/off
264 | bLassoErase.on_clicked(lassoEraseSwitch) # lasso erase on/off
265 | flexFig.remapMsks()
266 | flexFig.updatePanels(update_slice=True, update_rotation=False,
267 | update_extent=False)
268 |
269 | print("GUI is ready.")
270 | plt.show()
271 |
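# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original script): the core of the lasso
# selection without any GUI. A hard-coded triangle stands in for the vertices
# that LassoSelector passes to onselect(); names ending in _demo are made up.
def lasso_demo(nr_bins_demo=100):
    """Turn a polygon into a boolean mask over the histogram bin grid."""
    pix_1d = np.arange(nr_bins_demo)
    xv_demo, yv_demo = np.meshgrid(pix_1d, pix_1d)
    pix_demo = np.vstack((xv_demo.flatten(), yv_demo.flatten())).T
    verts_demo = [(10, 10), (80, 20), (40, 70)]  # stand-in for lasso vertices
    inside = path.Path(verts_demo).contains_points(pix_demo, radius=1.5)
    return inside.reshape(nr_bins_demo, nr_bins_demo)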
--------------------------------------------------------------------------------
/segmentator/segmentator_ncut.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Processing input and plotting, for experimental ncut feature.
3 |
4 | TODO: Lots of code repetition, will be integrated better in the future.
5 | """
6 |
7 | from __future__ import division, print_function
8 | import numpy as np
9 | import segmentator.config as cfg
10 | import matplotlib
11 | matplotlib.use('TkAgg')
12 | import matplotlib.pyplot as plt
13 | from matplotlib.colors import LogNorm, ListedColormap, BoundaryNorm
14 | from matplotlib.widgets import Slider, Button, RadioButtons
15 | from nibabel import load
16 | from segmentator.utils import map_ima_to_2D_hist, prep_2D_hist
17 | from segmentator.utils import truncate_range, scale_range, check_data
18 | from segmentator.utils import set_gradient_magnitude
19 | from segmentator.utils import export_gradient_magnitude_image
20 | from segmentator.gui_utils import responsiveObj
21 | from segmentator.config_gui import palette, axcolor, hovcolor
22 |
23 | #
24 | """Load Data"""
25 | nii = load(cfg.filename)
26 | ncut_labels = np.load(cfg.ncut)
27 |
28 | # transpose the labels
29 | ncut_labels = np.transpose(ncut_labels, (1, 0, 2))
30 | nrTotal_labels = sum([2**x for x in range(ncut_labels.shape[2])])
31 | total_labels = np.arange(nrTotal_labels)
32 | total_labels[1::2] = total_labels[-2:0:-2]
33 |
34 | # relabel the labels from ncut, assign ascending integers starting with counter
35 | counter = 0
36 | for ind in np.arange(ncut_labels.shape[2]):
37 | tmp = ncut_labels[:, :, ind]
38 | uniqueVals = np.unique(tmp)
39 | nrUniqueVals = len(uniqueVals)
40 | newInd = np.arange(nrUniqueVals) + counter
41 | newVals = total_labels[newInd]
42 | tmp2 = np.zeros((tmp.shape))
43 | for ind2, val in enumerate(uniqueVals):
44 | tmp2[tmp == val] = newVals[ind2]
45 | counter = counter + nrUniqueVals
46 | ncut_labels[:, :, ind] = tmp2
47 | lMax = np.max(ncut_labels)
48 |
49 | orig_ncut_labels = ncut_labels.copy()
50 | ima_ncut_labels = ncut_labels.copy()
51 |
52 |
53 | #
54 | """Data Processing"""
55 | orig, dims = check_data(nii.get_fdata(), cfg.force_original_precision)
56 | # Save min and max truncation thresholds to be used in axis labels
57 | orig, pMin, pMax = truncate_range(orig, percMin=cfg.perc_min,
58 | percMax=cfg.perc_max)
59 | # Continue with scaling the original truncated image and recomputing gradient
60 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001)
61 | gra = set_gradient_magnitude(orig, cfg.gramag)
62 | if cfg.export_gramag:
63 |     export_gradient_magnitude_image(gra, nii.get_filename(), cfg.gramag,
64 |                                     nii.affine)
65 | # Flatten ima (more intuitive for voxel-wise operations)
66 | ima = np.ndarray.flatten(orig)
67 | gra = np.ndarray.flatten(gra)
68 |
69 | #
70 | """Plots"""
71 | print("Preparing GUI...")
72 | # Plot 2D histogram
73 | fig = plt.figure(facecolor='0.775')
74 | ax = fig.add_subplot(121)
75 |
76 | counts, volHistH, d_min, d_max, nr_bins, bin_edges \
77 | = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros)
78 |
79 | ax.set_xlim(d_min, d_max)
80 | ax.set_ylim(d_min, d_max)
81 | ax.set_xlabel("Intensity f(x)")
82 | ax.set_ylabel("Gradient Magnitude f'(x)")
83 | ax.set_title("2D Histogram")
84 |
85 | # Plot map for political borders
86 | pltMap = np.zeros((nr_bins, nr_bins, 1)).repeat(4, 2)
87 | cmapPltMap = ListedColormap([[1, 1, 1, 0], # transparent zeros
88 | [0, 0, 0, 0.75], # political borders
89 | [1, 0, 0, 0.5], # other colors for future use
90 | [0, 0, 1, 0.5]])
91 | boundsPltMap = [0, 1, 2, 3, 4]
92 | normPltMap = BoundaryNorm(boundsPltMap, cmapPltMap.N)
93 | pltMapH = ax.imshow(pltMap, cmap=cmapPltMap, norm=normPltMap,
94 | vmin=boundsPltMap[1], vmax=boundsPltMap[-1],
95 | extent=[0, nr_bins, nr_bins, 0], interpolation='none')
96 |
97 | # Plot colorbar for 2d hist
98 | volHistH.set_norm(LogNorm(vmax=np.power(10, cfg.cbar_init)))
99 | fig.colorbar(volHistH, fraction=0.046, pad=0.04) # magical perfect scaling
100 |
101 | # Set up a colormap for ncut labels
102 | ncut_palette = plt.cm.gist_rainbow
103 | ncut_palette.set_under('w', 0)
104 |
105 | # Plot hist mask (with ncut labels)
106 | volHistMask = np.squeeze(ncut_labels[:, :, 0])
107 | volHistMaskH = ax.imshow(volHistMask, interpolation='none',
108 | alpha=0.2, cmap=ncut_palette,
109 | vmin=np.min(ncut_labels)+1, # to make 0 transparent
110 | vmax=lMax,
111 | extent=[0, nr_bins, nr_bins, 0])
112 |
113 | # Plot 3D ima by default
114 | ax2 = fig.add_subplot(122)
115 | sliceNr = int(0.5*dims[2])
116 | imaSlcH = ax2.imshow(orig[:, :, sliceNr], cmap=plt.cm.gray,
117 | vmin=ima.min(), vmax=ima.max(), interpolation='none',
118 | extent=[0, dims[1], dims[0], 0])
119 | imaSlcMsk = np.zeros(dims[0:2])*total_labels[1]
120 | imaSlcMskH = ax2.imshow(imaSlcMsk, interpolation='none', alpha=0.5,
121 | cmap=ncut_palette, vmin=np.min(ncut_labels)+1,
122 | vmax=lMax,
123 | extent=[0, dims[1], dims[0], 0])
124 |
125 | # Adjust subplots on figure
126 | bottom = 0.30
127 | fig.subplots_adjust(bottom=bottom)
128 | fig.canvas.manager.set_window_title(nii.get_filename())
129 | plt.axis('off')
130 |
131 |
132 | # %%
133 | """Initialisation"""
134 | # Initiate a flexible figure object, pass to it useful properties
135 | flexFig = responsiveObj(figure=ax.figure, axes=ax.axes, axes2=ax2.axes,
136 | segmType='ncut', orig=orig, nii=nii, ima=ima,
137 | nrBins=nr_bins,
138 | sliceNr=sliceNr,
139 | imaSlcH=imaSlcH,
140 | imaSlcMsk=imaSlcMsk, imaSlcMskH=imaSlcMskH,
141 | volHistMask=volHistMask,
142 | volHistMaskH=volHistMaskH,
143 | pltMap=pltMap, pltMapH=pltMapH,
144 | counterField=np.zeros((nr_bins, nr_bins)),
145 | orig_ncut_labels=orig_ncut_labels,
146 | ima_ncut_labels=ima_ncut_labels,
147 | lMax=lMax)
148 |
149 | # Make the figure responsive to clicks
150 | flexFig.connect()
151 | # Get mapping from image slice to volume histogram
152 | ima2volHistMap = map_ima_to_2D_hist(xinput=ima, yinput=gra, bins_arr=bin_edges)
153 | flexFig.invHistVolume = np.reshape(ima2volHistMap, dims)
154 |
155 | # %%
156 | """Sliders and Buttons"""
157 | axcolor, hovcolor = '0.875', '0.975'
158 |
159 | # Radio buttons (ugly but good enough for now)
160 | rax = plt.axes([0.91, 0.35, 0.08, 0.5], facecolor=(0.75, 0.75, 0.75))
161 | flexFig.radio = RadioButtons(rax, [str(i) for i in range(7)],
162 | activecolor=(0.25, 0.25, 0.25))
163 |
164 | # Colorbar slider
165 | axHistC = plt.axes([0.15, bottom-0.230, 0.25, 0.025], facecolor=axcolor)
166 | flexFig.sHistC = Slider(axHistC, 'Colorbar', 1, cfg.cbar_max,
167 | valinit=cfg.cbar_init, valfmt='%0.1f')
168 |
169 | # Label slider
170 | axLabels = plt.axes([0.15, bottom-0.270, 0.25, 0.025], facecolor=axcolor)
171 | flexFig.sLabelNr = Slider(axLabels, 'Labels', 0, lMax,
172 | valinit=lMax, valfmt='%i')
173 |
174 | # Image browser slider
175 | axSliceNr = plt.axes([0.6, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
176 | flexFig.sSliceNr = Slider(axSliceNr, 'Slice', 0, 0.999,
177 | valinit=0.5, valfmt='%0.3f')
178 |
179 | # Cycle button
180 | cycleax = plt.axes([0.55, bottom-0.285, 0.075, 0.075])
181 | flexFig.bCycle = Button(cycleax, 'Cycle\nView',
182 | color=axcolor, hovercolor=hovcolor)
183 | flexFig.cycleCount = 0
184 |
185 | # Rotate button
186 | rotateax = plt.axes([0.55, bottom-0.285, 0.075, 0.0375])
187 | flexFig.bRotate = Button(rotateax, 'Rotate',
188 | color=axcolor, hovercolor=hovcolor)
189 | flexFig.rotationCount = 0
190 |
191 | # Export nii button
192 | exportax = plt.axes([0.75, bottom-0.285, 0.075, 0.075])
193 | flexFig.bExport = Button(exportax, 'Export\nNifti',
194 | color=axcolor, hovercolor=hovcolor)
195 |
196 | # Export nyp button
197 | exportax = plt.axes([0.85, bottom-0.285, 0.075, 0.075])
198 | flexFig.bExportNyp = Button(exportax, 'Export\nHist',
199 | color=axcolor, hovercolor=hovcolor)
200 |
201 | # Reset button
202 | resetax = plt.axes([0.65, bottom-0.285, 0.075, 0.075])
203 | flexFig.bReset = Button(resetax, 'Reset', color=axcolor, hovercolor=hovcolor)
204 |
205 |
206 | # %%
207 | """Updates"""
208 | flexFig.sHistC.on_changed(flexFig.updateColorBar)
209 | flexFig.sSliceNr.on_changed(flexFig.updateImaBrowser)
210 | flexFig.sLabelNr.on_changed(flexFig.updateLabels)
211 | flexFig.bCycle.on_clicked(flexFig.cycleView)
212 | flexFig.bRotate.on_clicked(flexFig.changeRotation)
213 | flexFig.bExport.on_clicked(flexFig.exportNifti)
214 | flexFig.bExportNyp.on_clicked(flexFig.exportNyp)
215 | flexFig.bReset.on_clicked(flexFig.resetGlobal)
216 | flexFig.radio.on_clicked(flexFig.updateLabelsRadio)
217 |
218 |
219 | # TODO: Temporary solution for displaying original x-y axis labels
220 | def update_axis_labels(event):
221 | """Swap histogram bin indices with original values."""
222 | xlabels = [item.get_text() for item in ax.get_xticklabels()]
223 | orig_range_labels = np.linspace(pMin, pMax, len(xlabels))
224 |
225 | # Adjust displayed decimals based on data range
226 | data_range = pMax - pMin
227 | if data_range > 200: # arbitrary value
228 | xlabels = [('%i' % i) for i in orig_range_labels]
229 | elif data_range > 20:
230 | xlabels = [('%.1f' % i) for i in orig_range_labels]
231 | elif data_range > 2:
232 | xlabels = [('%.2f' % i) for i in orig_range_labels]
233 | else:
234 | xlabels = [('%.3f' % i) for i in orig_range_labels]
235 |
236 | ax.set_xticklabels(xlabels)
237 | ax.set_yticklabels(xlabels) # limits of y axis assumed to be the same as x
238 |
239 |
240 | fig.canvas.mpl_connect('resize_event', update_axis_labels)
241 |
242 | plt.show()
243 |
--------------------------------------------------------------------------------
/segmentator/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | """Test utility functions."""
2 |
3 | import numpy as np
4 | from segmentator.utils import truncate_range, scale_range
5 |
6 |
7 | def test_truncate_range():
8 | """Test range truncation."""
9 | # Given
10 | data = np.random.random(100)
11 | data.ravel()[np.random.choice(data.size, 10, replace=False)] = 0
12 | data.ravel()[np.random.choice(data.size, 5, replace=False)] = np.nan
13 | p_min, p_max = 2.5, 97.5
14 | expected = np.nanpercentile(data, [p_min, p_max])
15 | # When
16 | output, _, _ = truncate_range(data, percMin=p_min, percMax=p_max,
17 | discard_zeros=False)
18 | # Then
19 | assert all(np.nanpercentile(output, [0, 100]) == expected)
20 |
21 |
22 | def test_scale_range():
23 | """Test range scaling."""
24 | # Given
25 | data = np.random.random(100) - 0.5
26 | data.ravel()[np.random.choice(data.size, 10, replace=False)] = 0.
27 | data.ravel()[np.random.choice(data.size, 5, replace=False)] = np.nan
28 | s = 42. # scaling factor
29 | expected = [0., s] # min and max
30 | # When
31 | output = scale_range(data, scale_factor=s, delta=0.01, discard_zeros=False)
32 | # Then
33 | assert all([np.nanmin(output) >= expected[0],
34 | np.nanmax(output) < expected[1]])
35 |
--------------------------------------------------------------------------------
/segmentator/tests/wip_test_arcweld.py:
--------------------------------------------------------------------------------
1 | """Test and demonstrate arcweld classification.
2 |
3 | TODO: Turn this into unit tests for arcweld.
4 |
5 | """
6 |
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | # create toy data
11 | res = 500 # resolution
12 | data = np.mgrid[0:res, 0:res].astype('float')
13 | ima, gra = data[0, :, :].flatten(), data[1, :, :].flatten()
14 | dims = data.shape
15 |
16 | # have 3 arbitrary classes (standing for classes csf, gm, wm)
17 | classes = np.array([res/7, res/7*4, res/7*6])
18 |
19 | # find arc anchor
20 | arc_center = (data.max() + data.min()) / 2.
21 | arc_radius = (data.max() - data.min()) / 2.
22 | arc_weight = (gra / arc_radius)**-1
23 | classes = np.hstack([classes, arc_center])
24 |
25 | # find euclidean distances to classes
26 | soft = []
27 | data = data.reshape(dims[0], dims[1]*dims[2])
28 | for i, a in enumerate(classes):
29 | c = np.array([a, 0])
30 | # euclidean distance
31 | edist = np.sqrt(np.sum((data - c[:, None])**2., axis=0))
32 | soft.append(edist)
33 | soft = np.asarray(soft)
34 |
35 | # arc translation
36 | soft[-1, :] = soft[-1, :] - arc_radius
37 | soft[-1, :] = arc_weight * np.abs(soft[-1, :])
38 |
39 | # arbitrary weights
40 | soft[0, :] = soft[0, :] * 0.66 # csf
41 | soft[-1, :] = soft[-1, :] * 0.5 # arc
42 |
43 | # hard class membership maps
44 | hard = np.argmin(soft, axis=0)
45 | hard = hard.reshape(dims[1], dims[2])
46 |
47 | plt.imshow(hard.T, origin="lower")
48 | plt.show()
49 |
--------------------------------------------------------------------------------
/segmentator/tests/wip_test_gradient_magnitude.py:
--------------------------------------------------------------------------------
1 | """Experiment with different gradient magnitude calculations.
2 |
3 | TODO: turn this into unit tests.
4 |
5 | """
6 |
7 | import os
8 | from nibabel import load, Nifti1Image, save
9 | from segmentator.utils import compute_gradient_magnitude
10 |
11 | # load
12 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/gramag_test/mprage_S02_restore.nii.gz')
13 | ima = nii.get_fdata()
14 | basename = nii.get_filename().split(os.extsep, 1)[0]
15 |
16 | # 3D Scharr gradient magnitude
17 | gra_mag = compute_gradient_magnitude(ima, method='scharr')
18 | out = Nifti1Image(gra_mag, affine=nii.affine)
19 | save(out, basename + '_scharr.nii.gz')
20 |
21 | # 3D Sobel gradient magnitude
22 | gra_mag = compute_gradient_magnitude(ima, method='sobel')
23 | out = Nifti1Image(gra_mag, affine=nii.affine)
24 | save(out, basename + '_sobel.nii.gz')
25 |
26 | # 3D Prewitt gradient magnitude
27 | gra_mag = compute_gradient_magnitude(ima, method='prewitt')
28 | out = Nifti1Image(gra_mag, affine=nii.affine)
29 | save(out, basename + '_prewitt.nii.gz')
30 |
31 | # numpy gradient magnitude
32 | gra_mag = compute_gradient_magnitude(ima, method='numpy')
33 | out = Nifti1Image(gra_mag, affine=nii.affine)
34 | save(out, basename + '_numpy_gradient.nii.gz')
35 |
--------------------------------------------------------------------------------
/segmentator/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Some utility functions."""
3 |
4 | from __future__ import division, print_function
5 | import os
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import segmentator.config as cfg
9 | from nibabel import load, Nifti1Image, save
10 | from scipy.ndimage import convolve
11 | from time import time
12 |
13 |
14 | def sub2ind(array_shape, rows, cols):
15 | """Pixel to voxel mapping (similar to matlab's function).
16 |
17 | Parameters
18 | ----------
19 | array_shape : TODO
20 | rows : TODO
21 | cols : TODO
22 |
23 | Returns
24 | -------
25 | TODO
26 |
27 | """
28 | # return (rows*array_shape + cols)
29 | return (cols*array_shape + rows)
30 |
31 |
32 | def map_ima_to_2D_hist(xinput, yinput, bins_arr):
33 | """Image to volume histogram mapping (kind of inverse histogram).
34 |
35 | Parameters
36 | ----------
37 | xinput : TODO
38 | First image, which is often the intensity image (eg. T1w).
39 | yinput : TODO
40 | Second image, which is often the gradient magnitude image
41 | derived from the first image.
42 | bins_arr : TODO
43 | Array of bins.
44 |
45 | Returns
46 | -------
47 | vox2pixMap : TODO
48 | Voxel to pixel mapping.
49 |
50 | """
51 | dgtzData = np.digitize(xinput, bins_arr)-1
52 | dgtzGra = np.digitize(yinput, bins_arr)-1
53 | nr_bins = len(bins_arr)-1 # subtract 1 (more borders than containers)
54 | vox2pixMap = sub2ind(nr_bins, dgtzData, dgtzGra) # 1D
55 | return vox2pixMap
56 |
57 |
58 | def map_2D_hist_to_ima(imaSlc2volHistMap, volHistMask):
59 | """Volume histogram to image mapping for slices (uses np.ind1).
60 |
61 | Parameters
62 | ----------
63 | imaSlc2volHistMap : 1D numpy array
64 |         Voxel to histogram-bin mapping of the flattened image slice.
65 | volHistMask : 1D numpy array
66 | Flattened volume histogram mask.
67 |
68 | Returns
69 | -------
70 | imaSlcMask : 1D numpy array
71 | Flat image slice mask based on labeled pixels in volume histogram.
72 |
73 | """
74 | imaSlcMask = np.zeros(imaSlc2volHistMap.shape)
75 | idxUnique = np.unique(volHistMask)
76 | for idx in idxUnique:
77 | linIndices = np.where(volHistMask.flatten() == idx)[0]
78 | # return logical array with length equal to nr of voxels
79 | voxMask = np.in1d(imaSlc2volHistMap, linIndices)
80 | # reset mask and apply logical indexing
81 | imaSlcMask[voxMask] = idx
82 | return imaSlcMask
83 |
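# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): round trip between image
# space and the 2D histogram using the two mappers above, on tiny synthetic
# data. Relies on the module-level numpy import; names ending in _demo are
# made up for this example.
def demo_hist_mapping():
    """Label one histogram bin and carry the label back to the voxels."""
    np.random.seed(0)
    nr_bins_demo = 10
    bins_demo = np.arange(nr_bins_demo + 1, dtype='float')  # 10 bins
    ima_demo = np.random.uniform(0, nr_bins_demo, size=50)  # fake intensities
    gra_demo = np.random.uniform(0, nr_bins_demo, size=50)  # fake gradients
    # every voxel gets the linear index of its (intensity, gradient) bin
    vox2pix_demo = map_ima_to_2D_hist(ima_demo, gra_demo, bins_demo)
    # label one histogram bin; rows index gradient bins, columns intensity bins
    hist_mask_demo = np.zeros((nr_bins_demo, nr_bins_demo))
    hist_mask_demo[3, 5] = 1
    vox_labels_demo = map_2D_hist_to_ima(vox2pix_demo, hist_mask_demo)
    print('{} voxels labelled'.format(int(np.sum(vox_labels_demo == 1))))
    return vox_labels_demo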
84 |
85 | def truncate_range(data, percMin=0.25, percMax=99.75, discard_zeros=True):
86 | """Truncate too low and too high values.
87 |
88 | Parameters
89 | ----------
90 | data : np.ndarray
91 | Image to be truncated.
92 | percMin : float
93 | Percentile minimum.
94 | percMax : float
95 | Percentile maximum.
96 | discard_zeros : bool
97 | Discard voxels with value 0 from truncation.
98 |
99 | Returns
100 | -------
101 | data : np.ndarray
102 | Truncated data.
103 | pMin : float
104 | Minimum truncation threshold which is used.
105 | pMax : float
106 | Maximum truncation threshold which is used.
107 |
108 | """
109 | if discard_zeros:
110 | msk = ~np.isclose(data, 0.)
111 | pMin, pMax = np.nanpercentile(data[msk], [percMin, percMax])
112 | else:
113 | pMin, pMax = np.nanpercentile(data, [percMin, percMax])
114 | temp = data[~np.isnan(data)]
115 | temp[temp < pMin], temp[temp > pMax] = pMin, pMax # truncate min and max
116 | data[~np.isnan(data)] = temp
117 | if discard_zeros:
118 | data[~msk] = 0 # put back masked out voxels
119 | return data, pMin, pMax
120 |
121 |
122 | def scale_range(data, scale_factor=500, delta=0, discard_zeros=True):
123 | """Scale values as a preprocessing step.
124 |
125 | Parameters
126 | ----------
127 | data : np.ndarray
128 | Image to be scaled.
129 | scale_factor : float
130 |         Lower scale factors provide a faster interface by lowering the
131 |         resolution of the 2D histogram (500 seems fast enough).
132 | delta : float
133 | Delta ensures that the max data points fall inside the last bin
134 | when this function is used with histograms.
135 | discard_zeros : bool
136 | Discard voxels with value 0 from truncation.
137 |
138 | Returns
139 | -------
140 | data: np.ndarray
141 | Scaled image.
142 |
143 | """
144 | if discard_zeros:
145 | msk = ~np.isclose(data, 0)
146 | else:
147 | msk = np.ones(data.shape, dtype=bool)
148 | scale_factor = scale_factor - delta
149 | data[msk] = data[msk] - np.nanmin(data[msk])
150 | data[msk] = scale_factor / np.nanmax(data[msk]) * data[msk]
151 | if discard_zeros:
152 | data[~msk] = 0 # put back masked out voxels
153 | return data
154 |
155 |
156 | def check_data(data, force_original_precision=True):
157 | """Do type casting here."""
158 | data = np.squeeze(data) # to prevent singular dimension error
159 | dims = data.shape
160 | print('Input image data type is {}.'.format(data.dtype.name))
161 | if force_original_precision:
162 | pass
163 | elif data.dtype != 'float32':
164 | data = data.astype('float32')
165 |         print('  Data type is cast to {}.'.format(data.dtype.name))
166 | return data, dims
167 |
168 |
169 | def prep_2D_hist(ima, gra, discard_zeros=True):
170 | """Prepare 2D histogram related variables.
171 |
172 | Parameters
173 | ----------
174 | ima : np.ndarray
175 | First image, which is often the intensity image (eg. T1w).
176 | gra : np.ndarray
177 | Second image, which is often the gradient magnitude image
178 | derived from the first image.
179 |
180 | Returns
181 | -------
182 | counts : integer
183 | volHistH : TODO
184 | d_min : float
185 | Minimum of the first image.
186 | d_max : float
187 | Maximum of the first image.
188 | nr_bins : integer
189 | Number of one dimensional bins (not the pixels).
190 | bin_edges : TODO
191 |
192 | Notes
193 | -----
194 | This function is modularized to be called from the terminal.
195 |
196 | """
197 | if discard_zeros:
198 | gra = gra[~np.isclose(ima, 0)]
199 | ima = ima[~np.isclose(ima, 0)]
200 | d_min, d_max = np.round(np.nanpercentile(ima, [0, 100]))
201 | nr_bins = int(d_max - d_min)
202 | bin_edges = np.arange(d_min, d_max+1)
203 | counts, _, _, volHistH = plt.hist2d(ima, gra, bins=bin_edges, cmap='Greys')
204 | return counts, volHistH, d_min, d_max, nr_bins, bin_edges
205 |
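# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): how the scaling step
# controls the 2D histogram resolution. With scale_factor=500 the data span
# roughly [0, 500], so prep_2D_hist ends up with about 500 bins per axis.
# Synthetic data only; relies on the module-level numpy import.
def demo_prep_2D_hist():
    """Scale synthetic data and build the 2D histogram from it."""
    np.random.seed(1)
    ima = np.random.random(10000) * 3.2  # fake intensity values, flattened
    gra = np.random.random(10000) * 1.5  # fake gradient magnitudes, flattened
    ima = scale_range(ima, scale_factor=500, delta=0.0001,
                      discard_zeros=False)
    gra = scale_range(gra, scale_factor=500, delta=0.0001,
                      discard_zeros=False)
    counts, _, d_min, d_max, nr_bins, _ = prep_2D_hist(ima, gra,
                                                       discard_zeros=False)
    print('{} bins between {} and {}'.format(nr_bins, d_min, d_max))
    return counts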
206 |
207 | def create_3D_kernel(operator='scharr'):
208 | """Create various 3D kernels.
209 |
210 | Parameters
211 | ----------
212 |     operator : string or np.ndarray, shape=(3, 3, 3)
213 |         Input can be 'scharr' (default), 'sobel', 'prewitt' or a 3D array.
214 |
215 | Returns
216 | -------
217 |     kernel : np.ndarray, shape=(3, 3, 3, 3)
218 |
219 | """
220 | if operator == 'sobel':
221 | operator = np.array([[[1, 2, 1], [2, 4, 2], [1, 2, 1]],
222 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
223 | [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]]],
224 | dtype='float32')
225 | elif operator == 'prewitt':
226 | operator = np.array([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
227 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
228 | [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]]],
229 | dtype='float32')
230 | elif operator == 'scharr':
231 | operator = np.array([[[9, 30, 9], [30, 100, 30], [9, 30, 9]],
232 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
233 | [[-9, -30, -9], [-30, -100, -30], [-9, -30, -9]]],
234 | dtype='float32')
235 | scale_normalization_factor = np.sum(np.abs(operator))
236 | operator = np.divide(operator, scale_normalization_factor)
237 |
238 | # create permutations operator that will be used in gradient computation
239 | kernel = np.zeros([3, 3, 3, 3])
240 | kernel[0, ...] = operator
241 | kernel[1, ...] = np.transpose(kernel[0, ...], [2, 0, 1])
242 | kernel[2, ...] = np.transpose(kernel[0, ...], [1, 2, 0])
243 | return kernel
244 |
245 |
246 | def compute_gradient_magnitude(ima, method='scharr'):
247 | """Compute gradient magnitude of images.
248 |
249 | Parameters
250 | ----------
251 | ima : np.ndarray
252 | First image, which is often the intensity image (eg. T1w).
253 | method : string
254 | Gradient computation method. Available options are 'scharr',
255 |         'sobel', 'prewitt', 'numpy', 'deriche'.
256 | Returns
257 | -------
258 | gra_mag : np.ndarray
259 | Second image, which is often the gradient magnitude image
260 |         derived from the first image.
261 |
262 | """
263 | start = time()
264 | print(' Computing gradients...')
265 | if method.lower() == 'sobel': # magnitude scale is similar to numpy method
266 | kernel = create_3D_kernel(operator=method)
267 | gra = np.zeros(ima.shape + (kernel.shape[0],))
268 | for d in range(kernel.shape[0]):
269 | gra[..., d] = convolve(ima, kernel[d, ...])
270 | # compute generic gradient magnitude with normalization
271 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.)
272 | elif method.lower() == 'prewitt':
273 | kernel = create_3D_kernel(operator=method)
274 | gra = np.zeros(ima.shape + (kernel.shape[0],))
275 | for d in range(kernel.shape[0]):
276 | gra[..., d] = convolve(ima, kernel[d, ...])
277 | # compute generic gradient magnitude with normalization
278 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.)
279 | elif method.lower() == 'scharr':
280 | kernel = create_3D_kernel(operator=method)
281 | gra = np.zeros(ima.shape + (kernel.shape[0],))
282 | for d in range(kernel.shape[0]):
283 | gra[..., d] = convolve(ima, kernel[d, ...])
284 | # compute generic gradient magnitude with normalization
285 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.)
286 | elif method.lower() == 'numpy':
287 | gra = np.asarray(np.gradient(ima))
288 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=0))
289 | elif method.lower() == 'deriche':
290 | from segmentator.deriche_prepare import Deriche_Gradient_Magnitude
291 | alpha = cfg.deriche_alpha
292 | print(' Selected alpha: {}'.format(alpha))
293 | ima = np.ascontiguousarray(ima, dtype=np.float32)
294 | gra_mag = Deriche_Gradient_Magnitude(ima, alpha, normalize=True)
295 | else:
296 |         raise ValueError('Gradient magnitude method is invalid: ' + method)
297 | end = time()
298 | print(" Gradient magnitude computed in: " + str(int(end-start))
299 | + " seconds.")
300 | return gra_mag
301 |
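# Minimal usage sketch for compute_gradient_magnitude(), added for
# illustration; the random volume below is a hypothetical stand-in for real
# image data.
def _example_compute_gradient_magnitude():
    """Show gradient magnitude computation with the default Scharr kernel."""
    ima = np.random.rand(32, 32, 32).astype('float32')
    gra_mag = compute_gradient_magnitude(ima, method='scharr')
    print(gra_mag.shape)  # same shape as the input volume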
302 |
303 | def set_gradient_magnitude(image, gramag_option):
304 | """Set gradient magnitude based on the command line flag.
305 |
306 | Parameters
307 | ----------
308 | image : np.ndarray
309 |         First image, which is often the intensity image (e.g. T1w).
310 | gramag_option : string
311 | A keyword string or a path to a nifti file.
312 |
313 | Returns
314 | -------
315 | gra_mag : np.ndarray
316 | Gradient magnitude image, which has the same shape as image.
317 |
318 | """
319 | if gramag_option not in cfg.gramag_options:
320 | print("Selected gradient magnitude method is not available,"
321 | + " interpreting as a file path...")
322 | gra_mag_nii = load(gramag_option)
323 | gra_mag = np.squeeze(gra_mag_nii.get_fdata())
324 | gra_mag, _ = check_data(gra_mag, cfg.force_original_precision)
325 | gra_mag, _, _ = truncate_range(gra_mag, percMin=cfg.perc_min,
326 | percMax=cfg.perc_max)
327 | gra_mag = scale_range(gra_mag, scale_factor=cfg.scale, delta=0.0001)
328 |
329 | else:
330 | print('{} gradient method is selected.'.format(gramag_option.title()))
331 | gra_mag = compute_gradient_magnitude(image, method=gramag_option)
332 | return gra_mag
333 |
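# Minimal usage sketch for set_gradient_magnitude(), added for illustration;
# it assumes 'scharr' is listed in cfg.gramag_options, and the nifti path in
# the commented line is hypothetical.
def _example_set_gradient_magnitude():
    """Show keyword versus file path interpretation of the gramag flag."""
    ima = np.random.rand(32, 32, 32).astype('float32')
    gra_mag = set_gradient_magnitude(ima, 'scharr')  # keyword: compute
    # gra_mag = set_gradient_magnitude(ima, '/path/to/gramag.nii.gz')  # load
    print(gra_mag.shape)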
334 |
335 | def export_gradient_magnitude_image(img, filename, filtername, affine):
336 | """Export computed gradient magnitude image as a nifti file."""
337 | basename = filename.split(os.extsep, 1)[0]
338 | out_img = Nifti1Image(img, affine=affine)
339 | if filtername == 'deriche': # add alpha as suffix for extra information
340 | filtername = '{}_alpha{}'.format(filtername.title(),
341 | cfg.deriche_alpha)
342 | filtername = filtername.replace('.', 'pt')
343 | else:
344 | filtername = filtername.title()
345 | out_path = '{}_GraMag{}.nii.gz'.format(basename, filtername)
346 | print("Exporting gradient magnitude image...")
347 | save(out_img, out_path)
348 |     print(' Gradient magnitude image exported to this path:\n ' + out_path)
349 |
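# Minimal usage sketch for export_gradient_magnitude_image(), added for
# illustration; the output file name and identity affine below are
# hypothetical.
def _example_export_gradient_magnitude_image():
    """Show export of a gradient magnitude image next to its source."""
    gra_mag = np.random.rand(32, 32, 32).astype('float32')
    affine = np.eye(4)
    # Writes '/tmp/example_GraMagScharr.nii.gz' (suffix derives from the
    # filter name).
    export_gradient_magnitude_image(gra_mag, '/tmp/example.nii.gz', 'scharr',
                                    affine)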
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Segmentator setup.
2 |
3 | To install for development, run the following from the command line:
4 | pip install -e /path/to/segmentator
5 |
6 | """
7 |
8 | from setuptools import setup
9 | from setuptools.extension import Extension
10 | from os.path import join
11 | import numpy
12 |
13 | ext_modules = [Extension(
14 | "segmentator.deriche_3D", [join('segmentator', 'cython', 'deriche_3D.c')],
15 | include_dirs=[numpy.get_include()])
16 | ]
17 |
18 | setup(name='segmentator',
19 | version='1.6.1',
20 |       description=('Multi-dimensional data exploration and segmentation '
21 |                    'for 3D images.'),
22 | url='https://github.com/ofgulban/segmentator',
23 | author='Omer Faruk Gulban',
24 | author_email='faruk.gulban@maastrichtuniversity.nl',
25 | license='BSD-3-clause',
26 | packages=['segmentator'],
27 | install_requires=['numpy>=1.17', 'matplotlib>=3.1', 'scipy>=1.3', 'compoda>=0.3'],
28 | keywords=['mri', 'segmentation', 'image', 'voxel'],
29 | zip_safe=True,
30 | entry_points={
31 | 'console_scripts': [
32 | 'segmentator = segmentator.__main__:main',
33 | 'segmentator_filters = segmentator.filters_ui:main',
34 | ]},
35 | ext_modules=ext_modules,
36 | )
37 |
--------------------------------------------------------------------------------
/visuals/animation_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/visuals/animation_01.gif
--------------------------------------------------------------------------------
/visuals/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/visuals/logo.png
--------------------------------------------------------------------------------
/visuals/logo.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/visuals/logo.svg
--------------------------------------------------------------------------------