├── .DS_Store ├── .appveyor.yml ├── .coveragerc ├── .gitignore ├── .travis.yml ├── CITATION.cff ├── LICENSE ├── README.md ├── contributing.md ├── requirements.txt ├── sample_data └── readme.md ├── segmentator ├── __init__.py ├── __main__.py ├── config.py ├── config_filters.py ├── config_gui.py ├── cython │ ├── deriche_3D.c │ └── deriche_3D.pyx ├── deriche_prepare.py ├── filter.py ├── filters_ui.py ├── filters_utils.py ├── future │ ├── readme.md │ ├── wip_arcweld_mp2rage.py │ └── wip_arcweld_mprage.py ├── gui_utils.py ├── hist2d_counts.py ├── ncut_prepare.py ├── segmentator_main.py ├── segmentator_ncut.py ├── tests │ ├── test_utils.py │ ├── wip_test_arcweld.py │ └── wip_test_gradient_magnitude.py └── utils.py ├── setup.py └── visuals ├── animation_01.gif ├── logo.png └── logo.svg /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/.DS_Store -------------------------------------------------------------------------------- /.appveyor.yml: -------------------------------------------------------------------------------- 1 | build: false 2 | 3 | environment: 4 | matrix: 5 | - PYTHON_VERSION: 3.7 6 | MINICONDA: C:\Miniconda 7 | 8 | init: 9 | - "ECHO %PYTHON_VERSION% %MINICONDA%" 10 | 11 | install: 12 | - "set PATH=%MINICONDA%;%MINICONDA%\\Scripts;%PATH%" 13 | - conda config --set always_yes yes --set changeps1 no 14 | - conda update -q conda 15 | - conda config --add channels conda-forge 16 | - conda info -a 17 | - "conda env create -q -n segmentator python=%PYTHON_VERSION% --file requirements.txt" 18 | - activate segmentator 19 | - pip install compoda 20 | - python setup.py install 21 | 22 | test_script: 23 | - py.test 24 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = segmentator 3 | 4 
| omit = 5 | *__init__* 6 | *__main__* 7 | *config* 8 | *tests/* 9 | segmentator/future/* 10 | segmentator/cython/* 11 | segmentator/sample_data/* 12 | *ui.py 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | 64 | bin/ 65 | 66 | # MacOSX Stuff 67 | # General 68 | .DS_Store 69 | .AppleDouble 70 | .LSOverride 71 | 72 | # Icon must end with two \r 73 | Icon 74 | 75 | 76 | # Thumbnails 77 | ._* 78 | 79 | # Files that might appear in the root of a volume 80 | .DocumentRevisions-V100 81 | .fseventsd 82 | .Spotlight-V100 83 | .TemporaryItems 84 | .Trashes 85 | .VolumeIcon.icns 86 | .com.apple.timemachine.donotpresent 87 | 88 | # Directories potentially created on remote AFP share 89 | .AppleDB 90 | .AppleDesktop 91 | Network Trash Folder 92 | Temporary Items 93 | .apdisk 
-------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | os: 3 | - linux 4 | - osx 5 | matrix: 6 | allow_failures: 7 | - os: osx 8 | language: python 9 | python: 10 | - 2.7 11 | install: # command to install dependencies 12 | - pip install -r requirements.txt 13 | - python setup.py develop 14 | script: # command to run tests 15 | - py.test --cov=./segmentator 16 | after_success: 17 | - bash <(curl -s https://codecov.io/bash) 18 | notifications: 19 | email: false 20 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Gulban 5 | given-names: Omer Faruk 6 | orcid: https://orcid.org/0000-0001-7761-3727 7 | 8 | - family-names: Schneider 9 | given-names: Marian 10 | orcid: http://orcid.org/0000-0003-3192-5316 11 | 12 | - family-names: Marquardt 13 | given-names: Ingo 14 | orcid: http://orcid.org/0000-0001-5178-9951 15 | 16 | - family-names: Haast 17 | given-names: Roy 18 | orcid: http://orcid.org/0000-0001-8543-2467 19 | 20 | - family-names: De Martino 21 | given-names: Federico 22 | orcid: https://orcid.org/0000-0002-0352-0648 23 | 24 | title: "A scalable method to improve gray matter segmentation at ultra high field MRI" 25 | version: 1.6.0 26 | doi: https://doi.org/10.1371/journal.pone.0198335 27 | date-released: 2019-09-26 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Omer Faruk Gulban and Marian Schneider 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following 
conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/59303623.svg)](https://zenodo.org/badge/latestdoi/59303623) 2 | 3 | # Segmentator 4 | 5 | 6 | 7 | Segmentator is a free and open-source package for multi-dimensional data exploration and segmentation for 3D images. This application is mainly developed and tested using ultra-high field magnetic resonance imaging (MRI) brain data. 
8 | 9 | 10 | The goal is to provide a complementary tool to the already available brain tissue segmentation methods (to the best of our knowledge) in other software packages ([FSL](https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/), [CBS-Tools](https://www.cbs.mpg.de/institute/software/cbs-tools), [ITK-SNAP](http://www.itksnap.org/pmwiki/pmwiki.php), [Freesurfer](https://surfer.nmr.mgh.harvard.edu/), [SPM](http://www.fil.ion.ucl.ac.uk/spm/software/spm12/), [Brainvoyager](http://www.brainvoyager.com/), etc.). 11 | 12 | ### Citation: 13 | - Our paper can be accessed from __[this link.](https://doi.org/10.1371/journal.pone.0198335)__ 14 | - Released versions of this package can be cited by using our __[Zenodo DOI](https://zenodo.org/badge/latestdoi/59303623).__ 15 | 16 | 17 | 18 | ## Core dependencies 19 | **[Python 3.6](https://www.python.org/downloads/release/python-363/)** or **[Python 2.7](https://www.python.org/download/releases/2.7/)** (compatible with both). 20 | 21 | | Package | Tested version | 22 | |------------------------------------------------|----------------| 23 | | [matplotlib](http://matplotlib.org/) | 3.1.1 | 24 | | [NumPy](http://www.numpy.org/) | 1.22.0 | 25 | | [NiBabel](http://nipy.org/nibabel/) | 2.5.1 | 26 | | [SciPy](http://scipy.org/) | 1.3.1 | 27 | | [Compoda](https://github.com/ofgulban/compoda) | 0.3.5 | 28 | 29 | ## Installation & Quick Start 30 | - Download [the latest release](https://github.com/ofgulban/segmentator/releases) and unzip it. 
31 | - Change directory in your command line: 32 | ``` 33 | cd /path/to/segmentator 34 | ``` 35 | - Install the requirements by running the following command: 36 | ``` 37 | pip install -r requirements.txt 38 | ``` 39 | - Install Segmentator: 40 | ``` 41 | python setup.py install 42 | ``` 43 | - Simply call segmentator with a nifti file: 44 | ``` 45 | segmentator /path/to/file.nii.gz 46 | ``` 47 | - Or see the help for available options: 48 | ``` 49 | segmentator --help 50 | ``` 51 | 52 | Check out __[our wiki](https://github.com/ofgulban/segmentator/wiki)__ for further details such as [GUI controls](https://github.com/ofgulban/segmentator/wiki/Controls), [alternative installation methods](https://github.com/ofgulban/segmentator/wiki/Installation) and more... 53 | 54 | ## Support 55 | Please use [GitHub issues](https://github.com/ofgulban/segmentator/issues) for questions, bug reports or feature requests. 56 | 57 | ## License 58 | Copyright © 2019, [Omer Faruk Gulban](https://github.com/ofgulban) and [Marian Schneider](https://github.com/MSchnei). 59 | This project is licensed under [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause). 60 | 61 | ## References 62 | This application is mainly based on the following work: 63 | 64 | * Kniss, J., Kindlmann, G., & Hansen, C. D. (2005). Multidimensional transfer functions for volume rendering. Visualization Handbook, 189–209. 65 | 66 | ## Acknowledgements 67 | Since early 2020, development and maintenance of this project is being actively supported by [Brain Innovation](https://www.brainvoyager.com/) as the main developer ([Omer Faruk Gulban](https://github.com/ofgulban)) works there. 68 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | If you want to contribute to Segmentator and make it better, your help is very welcome. 
4 | 5 | ## Opening an issue 6 | 7 | - Please post your questions, bug reports and feature requests [here](https://github.com/ofgulban/segmentator/issues). 8 | - When reporting a bug, please specify your system and include a complete copy of the error message that you are getting. 9 | 10 | ## Making a pull request 11 | 12 | - Create a personal fork of the project on GitHub. 13 | - Clone the fork on your local machine. 14 | - Create a new branch to work on. Branch from `devel`. 15 | - Implement/fix your feature, comment your code. 16 | - Follow the code style of the project. 17 | - Add or change the documentation as needed. 18 | - Push your branch to your fork on GitHub. 19 | - From your fork open a pull request in the correct branch. Target the `devel` branch of the original Segmentator repository. 20 | 21 | __Note:__ Please write your commit messages in the present tense. Your commit message should describe what the commit, when applied, does to the code – not what you did to the code. 22 | 23 | ___ 24 | This guideline is adapted from [link](https://github.com/MarcDiethelm/contributing/blob/master/README.md) 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | scipy==1.3.1 3 | matplotlib==3.1.1 4 | nibabel==2.5.1 5 | pytest-cov==2.7.1 6 | compoda==0.3.5 7 | -------------------------------------------------------------------------------- /sample_data/readme.md: -------------------------------------------------------------------------------- 1 | The dataset used for developing Segmentator can be found at: 2 | 3 | [https://doi.org/10.5281/zenodo.1117858](https://doi.org/10.5281/zenodo.1117858) 4 | 5 | -------------------------------------------------------------------------------- /segmentator/__init__.py: -------------------------------------------------------------------------------- 1 | """Make the version number
available.""" 2 | 3 | import pkg_resources # part of setuptools 4 | 5 | __version__ = pkg_resources.require("segmentator")[0].version 6 | -------------------------------------------------------------------------------- /segmentator/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Entry point. 3 | 4 | Mostly following this example: 5 | https://chriswarrick.com/blog/2014/09/15/python-apps-the-right-way-entry_points-and-scripts/ 6 | 7 | Use config.py to hold arguments to be accessed by imported scripts. 8 | """ 9 | 10 | from __future__ import print_function 11 | import argparse 12 | import segmentator.config as cfg 13 | from segmentator import __version__ 14 | 15 | 16 | def main(): 17 | """Command line call argument parsing.""" 18 | # Instantiate argument parser object: 19 | parser = argparse.ArgumentParser() 20 | 21 | # Add arguments to namespace: 22 | parser.add_argument( 23 | 'filename', metavar='path', 24 | help="Path to input. Mostly a nifti file with image data." 25 | ) 26 | parser.add_argument( 27 | "--gramag", metavar=str(cfg.gramag), required=False, 28 | default=cfg.gramag, 29 | help="'scharr', 'deriche', 'sobel', 'prewitt', 'numpy' \ 30 | or path to a gradient magnitude nifti." 31 | ) 32 | # used in Deriche filter gradient magnitude computation 33 | parser.add_argument( 34 | "--deriche_alpha", required=False, type=float, 35 | default=cfg.deriche_alpha, metavar=cfg.deriche_alpha, 36 | help="Used only in Deriche gradient magnitude option. Smaller alpha \ 37 | values suppress more noise but can dislocate edges. Useful when there \ 38 | is strong noise in the input image or the features of interest are at \ 39 | a different scale compared to original image resolution." 40 | ) 41 | parser.add_argument( 42 | "--scale", metavar=str(cfg.scale), required=False, type=float, 43 | default=cfg.scale, 44 | help="Determines nr of bins. Data is scaled between 0 to this number." 
45 | ) 46 | parser.add_argument( 47 | "--percmin", metavar=str(cfg.perc_min), required=False, type=float, 48 | default=cfg.perc_min, 49 | help="Minimum percentile used in truncation." 50 | ) 51 | parser.add_argument( 52 | "--percmax", metavar=str(cfg.perc_max), required=False, type=float, 53 | default=cfg.perc_max, 54 | help="Maximum percentile used in truncation." 55 | ) 56 | parser.add_argument( 57 | "--valmin", metavar=str(cfg.valmin), required=False, type=float, 58 | default=cfg.valmin, 59 | help="Minimum value, overwrites percentile." 60 | ) 61 | parser.add_argument( 62 | "--valmax", metavar=str(cfg.valmax), required=False, type=float, 63 | default=cfg.valmax, 64 | help="Maximum value, overwrites percentile." 65 | ) 66 | parser.add_argument( 67 | "--cbar_max", metavar=str(cfg.cbar_max), required=False, type=float, 68 | default=cfg.cbar_max, 69 | help="Maximum value (power of 10) of the colorbar slider." 70 | ) 71 | parser.add_argument( 72 | "--cbar_init", metavar=str(cfg.cbar_init), required=False, 73 | type=float, default=cfg.cbar_init, 74 | help="Initial value (power of 10) of the colorbar slider. \ 75 | Also used with --ncut_prepare flag." 76 | ) 77 | parser.add_argument( 78 | "--ncut", metavar='path', required=False, 79 | help="Path to nyp file with ncut labels. Initiates N-cut GUI mode." 80 | ) 81 | parser.add_argument( 82 | "--nogui", action='store_true', 83 | help="Only save 2D histogram image without showing GUI." 84 | ) 85 | parser.add_argument( 86 | "--include_zeros", action='store_true', 87 | help="Include image zeros in histograms. Not used by default." 88 | ) 89 | parser.add_argument( 90 | "--export_gramag", action='store_true', 91 | help="Export the gradient magnitude image. Not used by default." 92 | ) 93 | parser.add_argument( 94 | "--force_original_precision", action='store_true', 95 | help="Do not change the data type of the input image. Can be useful \ 96 | for very large images. Off by default." 
97 | ) 98 | parser.add_argument( 99 | "--matplotlib_backend", metavar=str(cfg.matplotlib_backend), 100 | default=cfg.matplotlib_backend, required=False, 101 | help="Change in case of issues during startup or visual glitches. \ 102 | Some options are: qt5agg, qt4agg, wxagg, webagg." 103 | ) 104 | 105 | # used in ncut preparation 106 | parser.add_argument( 107 | "--ncut_prepare", action='store_true', 108 | help=("------------------(utility feature)------------------ \ 109 | Use this flag with the following arguments:") 110 | ) 111 | parser.add_argument( 112 | "--ncut_figs", action='store_true', 113 | help="Figures are presented (useful for debugging)." 114 | ) 115 | parser.add_argument( 116 | "--ncut_maxRec", required=False, type=int, 117 | default=cfg.max_rec, metavar=cfg.max_rec, 118 | help="Maximum number of recursions." 119 | ) 120 | parser.add_argument( 121 | "--ncut_nrSupPix", required=False, type=int, 122 | default=cfg.nr_sup_pix, metavar=cfg.nr_sup_pix, 123 | help="Number of regions/superpixels." 124 | ) 125 | parser.add_argument( 126 | "--ncut_compactness", required=False, type=float, 127 | default=cfg.compactness, metavar=cfg.compactness, 128 | help="Compactness balances intensity proximity and space \ 129 | proximity of the superpixels. \ 130 | Higher values give more weight to space proximity, making \ 131 | superpixel shapes more square/cubic. This parameter \ 132 | depends strongly on image contrast and on the shapes of \ 133 | objects in the image." 
134 | ) 135 | 136 | # set cfg file variables to be accessed from other scripts 137 | args = parser.parse_args() 138 | # used in all 139 | cfg.filename = args.filename 140 | # used in segmentator GUI (main and ncut) 141 | cfg.gramag = args.gramag 142 | cfg.scale = args.scale 143 | cfg.perc_min = args.percmin 144 | cfg.perc_max = args.percmax 145 | cfg.valmin = args.valmin 146 | cfg.valmax = args.valmax 147 | cfg.cbar_max = args.cbar_max 148 | cfg.cbar_init = args.cbar_init 149 | if args.include_zeros: 150 | cfg.discard_zeros = False 151 | cfg.export_gramag = args.export_gramag 152 | cfg.force_original_precision = args.force_original_precision 153 | cfg.matplotlib_backend = args.matplotlib_backend 154 | # used in ncut preparation 155 | cfg.ncut_figs = args.ncut_figs 156 | cfg.max_rec = args.ncut_maxRec 157 | cfg.nr_sup_pix = args.ncut_nrSupPix 158 | cfg.compactness = args.ncut_compactness 159 | # used in ncut 160 | cfg.ncut = args.ncut 161 | # used in deriche filter 162 | cfg.deriche_alpha = args.deriche_alpha 163 | 164 | welcome_str = 'Segmentator {}'.format(__version__) 165 | welcome_decor = '=' * len(welcome_str) 166 | print('{}\n{}\n{}'.format(welcome_decor, welcome_str, welcome_decor)) 167 | 168 | # Call other scripts with import method (couldn't find a better way). 169 | if args.nogui: 170 | print('No GUI option is selected. 
Saving 2D histogram image...') 171 | import segmentator.hist2d_counts 172 | elif args.ncut_prepare: 173 | print('Preparing N-cut file...') 174 | import segmentator.ncut_prepare 175 | elif args.ncut: 176 | print('N-cut GUI is selected.') 177 | import segmentator.segmentator_ncut 178 | else: 179 | print('Default GUI is selected.') 180 | import segmentator.segmentator_main 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /segmentator/config.py: -------------------------------------------------------------------------------- 1 | """This file contains variables that are shared by several modules. 2 | 3 | Useful to hold command line arguments. 4 | 5 | """ 6 | 7 | # Define variables used to initialise the sector mask 8 | init_centre = (0, 0) 9 | init_radius = 100 10 | init_theta = (0, 360) 11 | 12 | # Segmentator main command line variables 13 | filename = 'sample_filename_here' 14 | gramag = 'scharr' 15 | deriche_alpha = 3.0 16 | perc_min = 2.5 17 | perc_max = 97.5 18 | valmin = float('nan') 19 | valmax = float('nan') 20 | scale = 400 21 | cbar_max = 5.0 22 | cbar_init = 3.0 23 | discard_zeros = True 24 | export_gramag = False 25 | force_original_precision = False 26 | 27 | # Change in case of glitches in the host operating system 28 | matplotlib_backend = 'tkagg' 29 | 30 | # Possible gradient magnitude computation keyword options 31 | gramag_options = ['scharr', 'sobel', 'prewitt', 'numpy', 'deriche'] 32 | 33 | # Used in segmentator ncut 34 | ncut = False 35 | max_rec = 8 36 | nr_sup_pix = 2500 37 | compactness = 2 38 | -------------------------------------------------------------------------------- /segmentator/config_filters.py: -------------------------------------------------------------------------------- 1 | """This file contains variables that are used in filters.""" 2 | 3 | filename = 'sample_filename_here' 4 | 5 | # filter defult parameters 6 | smoothing = 'STEDI' 7 
| noise_scale = 0.5 8 | feature_scale = 0.5 9 | nr_iterations = 20 10 | save_every = 10 11 | edge_thr = 0.001 12 | gamma = 1 13 | downsampling = 0 14 | no_nonpositive_mask = False 15 | -------------------------------------------------------------------------------- /segmentator/config_gui.py: -------------------------------------------------------------------------------- 1 | """This file contains GUI related variables.""" 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | # Colormap for mask overlays 6 | palette = plt.cm.Reds 7 | palette.set_over('r', 1.0) 8 | palette.set_under('w', 0) 9 | palette.set_bad('m', 1.0) 10 | 11 | # Slider colors 12 | axcolor = '0.875' 13 | hovcolor = '0.975' 14 | -------------------------------------------------------------------------------- /segmentator/cython/deriche_3D.pyx: -------------------------------------------------------------------------------- 1 | """3D Deriche filter implementation.""" 2 | 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | DTYPE = np.int 7 | ctypedef np.int_t DTYPE_t 8 | 9 | def deriche_3D(np.ndarray[float, mode="c", ndim=3] inputData, float alpha=1): 10 | """Reference: Monga et al. 
1991.""" 11 | # c definitions (note: P is declared but never used in this function) 12 | cdef float s, a_0, a_1, a_2, a_3 , b_1, b_2 13 | cdef np.ndarray[double, mode="c", ndim=3] S_p, S_n, R_p, R_n, T_p, T_n 14 | cdef np.ndarray[double, mode="c", ndim=3] S, R, P 15 | cdef int imax = inputData.shape[0] 16 | cdef int jmax = inputData.shape[1] 17 | cdef int kmax = inputData.shape[2] 18 | cdef int i, j, k 19 | 20 | # constants: recursive filter coefficients, all derived from alpha 21 | s = np.power((1 - np.exp(-alpha)), 2) / \ 22 | (1 + 2 * alpha * np.exp(-alpha) - np.exp(-2*alpha)) 23 | a_0 = s 24 | a_1 = s * (alpha - 1) * np.exp(-alpha) 25 | b_1 = -2 * np.exp(-alpha) 26 | b_2 = np.exp(-2 * alpha) 27 | a_2 = a_1 - s * b_1 28 | a_3 = -s * b_2 29 | 30 | # Recursive filter implementation. Pass 1 runs along axis 0 and combines a causal (S_p) and an anti-causal (S_n) recursion as alpha*(S_p - S_n); passes 2 and 3 run the a_0..a_3 / b_1, b_2 recursions along axes 1 and 2 on the previous pass's output. NOTE(review): cython/deriche_3D.c is presumably generated from this file, so any change here also needs that file regenerated -- confirm against setup.py. 31 | S_p = np.zeros((imax, jmax, kmax)) 32 | S_n = np.zeros((imax, jmax, kmax)) 33 | for k in range(kmax): 34 | for j in range(jmax): 35 | for i in range(0, imax): 36 | if i > 2: # NOTE(review): index 2 is never computed here, while the backward pass below starts at imax-3; `i >= 2` would make both passes symmetric -- confirm intent (same pattern on lines 53 and 68). 37 | S_p[i, j, k] = inputData[i-1, j, k] \ 38 | - b_1*S_p[i-1, j, k] \ 39 | - b_2*S_p[i-2, j, k] 40 | i = i*-1 % inputData.shape[0] # inverse index. As in Python, reassigning i does not change the range() iteration. NOTE(review): (-i) % imax visits 0, imax-1, ..., 1, so index 0 is filled first (from still-zero S_n[1] and S_n[2]) instead of last; `imax - 1 - i` looks intended -- confirm (same pattern on lines 56 and 71). 41 | if i < inputData.shape[0]-2: 42 | S_n[i, j, k] = inputData[i+1, j, k] \ 43 | - b_1*S_n[i+1, j, k] \ 44 | - b_2*S_n[i+2, j, k] 45 | 46 | S = alpha*(S_p - S_n) 47 | 48 | R_p = np.zeros((imax, jmax, kmax)) 49 | R_n = np.zeros((imax, jmax, kmax)) 50 | for k in range(kmax): 51 | for i in range(imax): 52 | for j in range(0, jmax): 53 | if j > 2: 54 | R_p[i, j, k] = a_0*S[i, j, k] + a_1*S[i, j-1, k] \ 55 | - b_1*R_p[i, j-1, k] - b_2*R_p[i, j-2, k] 56 | j = j*-1 % inputData.shape[1] # inverse index 57 | if j < inputData.shape[1]-2: 58 | R_n[i, j, k] = a_2*S[i, j+1, k] + a_3*S[i, j+2, k] \ 59 | - b_1*R_n[i, j+1, k] - b_2*R_n[i, j+2, k] 60 | 61 | R = R_n + R_p 62 | 63 | T_p = np.zeros((imax, jmax, kmax)) 64 | T_n = np.zeros((imax, jmax, kmax)) 65 | for i in range(imax): 66 | for j in range(jmax): 67 | for k in range(0, kmax): 68 | if k > 2: 69 | T_p[i, j, k] = a_0*R[i, j, k] + a_1*R[i, j, k-1] \ 70 | - b_1*T_p[i, j, k-1] - b_2*T_p[i, j, k-2] 71 | k = k*-1 % inputData.shape[2] #
inverse index 72 | if k < inputData.shape[2]-2: 73 | T_n[i, j, k] = a_2*R[i, j, k+1] + a_3*R[i, j, k+2] \ 74 | - b_1*T_n[i, j, k+1] - b_2*T_n[i, j, k+2] 75 | 76 | T = T_n + T_p 77 | return T 78 | -------------------------------------------------------------------------------- /segmentator/deriche_prepare.py: -------------------------------------------------------------------------------- 1 | """Calculate gradient magnitude with 3D Deriche filter. 2 | 3 | You can use --graMag flag to pass resulting nifti files from this script. 4 | """ 5 | 6 | from segmentator.deriche_3D import deriche_3D 7 | import numpy as np 8 | 9 | 10 | def Deriche_Gradient_Magnitude(image, alpha, normalize=False, 11 | return_gradients=False): 12 | """Compute Deriche gradient magnitude of a volumetric image.""" 13 | # calculate gradients 14 | image = np.ascontiguousarray(image, dtype=np.float32) 15 | gra_x = deriche_3D(image, alpha=alpha) 16 | image = np.transpose(image, (2, 0, 1)) 17 | image = np.ascontiguousarray(image, dtype=np.float32) 18 | gra_y = deriche_3D(image, alpha=alpha) 19 | gra_y = np.transpose(gra_y, (1, 2, 0)) 20 | image = np.transpose(image, (2, 0, 1)) 21 | image = np.ascontiguousarray(image, dtype=np.float32) 22 | gra_z = deriche_3D(image, alpha=alpha) 23 | gra_z = np.transpose(gra_z, (2, 0, 1)) 24 | 25 | # Put the image gradients in 4D format 26 | gradients = np.array([gra_x, gra_y, gra_z]) 27 | gradients = np.transpose(gradients, (1, 2, 3, 0)) 28 | 29 | if return_gradients: 30 | return gradients 31 | 32 | else: # Deriche gradient magnitude 33 | gra_mag = np.sqrt(np.power(gradients[:, :, :, 0], 2.0) + 34 | np.power(gradients[:, :, :, 1], 2.0) + 35 | np.power(gradients[:, :, :, 2], 2.0)) 36 | if normalize: 37 | min_ima, max_ima = np.percentile(image, [0, 100]) 38 | min_der, max_der = np.percentile(gra_mag, [0, 100]) 39 | range_ima, range_der = max_ima - min_ima, max_der - min_der 40 | 41 | gra_mag = gra_mag * (range_ima / range_der) 42 | return gra_mag 43 | 
-------------------------------------------------------------------------------- /segmentator/filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Diffusion based image smoothing.""" 3 | 4 | from __future__ import division 5 | import os 6 | import numpy as np 7 | import segmentator.config_filters as cfg 8 | from nibabel import load, Nifti1Image, save 9 | from numpy.linalg import eigh 10 | from scipy.ndimage import gaussian_filter 11 | from time import time 12 | from segmentator.filters_utils import ( 13 | self_outer_product, dot_product_matrix_vector, divergence, 14 | compute_diffusion_weights, construct_diffusion_tensors, 15 | smooth_matrix_image) 16 | from scipy.ndimage.interpolation import zoom 17 | 18 | 19 | def QC_export(image, basename, identifier, nii): 20 | """Quality control exports.""" 21 | out = Nifti1Image(image, affine=nii.affine, header=nii.header) 22 | save(out, '{}_{}.nii.gz'.format(basename, identifier)) 23 | 24 | 25 | # Input 26 | file_name = cfg.filename 27 | 28 | # Primary parameters 29 | MODE = cfg.smoothing 30 | NR_ITER = cfg.nr_iterations 31 | SAVE_EVERY = cfg.save_every 32 | SIGMA = cfg.noise_scale 33 | RHO = cfg.feature_scale 34 | LAMBDA = cfg.edge_thr 35 | 36 | # Secondary parameters 37 | GAMMA = cfg.gamma 38 | ALPHA = 0.001 39 | M = 4 40 | 41 | # Export parameters 42 | identifier = MODE 43 | 44 | # Load data 45 | basename = file_name.split(os.extsep, 1)[0] 46 | nii = load(file_name) 47 | vres = nii.header['pixdim'][1:4] # voxel resolution x y z 48 | norm_vres = [r/min(vres) for r in vres] # normalized voxel resolutions 49 | ima = (nii.get_fdata()).astype('float32') 50 | 51 | if cfg.downsampling > 1: # TODO: work in progress 52 | print(' Applying initial downsampling...') 53 | ima = zoom(ima, 1./cfg.downsampling) 54 | orig = np.copy(ima) 55 | else: 56 | pass 57 | 58 | if cfg.no_nonpositive_mask: # TODO: work in progress 59 | idx_msk_flat = np.ones(ima.size, dtype=bool) 
60 | else: # mask out non positive voxels 61 | idx_msk_flat = ima.flatten() > 0 62 | 63 | dims = ima.shape 64 | 65 | # The main loop 66 | start = time() 67 | for t in range(NR_ITER): 68 | iteration = str(t+1).zfill(len(str(NR_ITER))) 69 | print("Iteration: " + iteration) 70 | # Update export parameters 71 | params = '{}_n{}_s{}_r{}_g{}'.format( 72 | identifier, iteration, SIGMA, RHO, GAMMA) 73 | params = params.replace('.', 'pt') 74 | 75 | # Smoothing 76 | if SIGMA == 0: 77 | ima_temp = np.copy(ima) 78 | else: 79 | ima_temp = gaussian_filter( 80 | ima, mode='constant', cval=0.0, 81 | sigma=[SIGMA/norm_vres[0], SIGMA/norm_vres[1], SIGMA/norm_vres[2]]) 82 | 83 | # Compute gradient 84 | gra = np.transpose(np.gradient(ima_temp), [1, 2, 3, 0]) 85 | ima_temp = None 86 | 87 | print(' Constructing structure tensors...') 88 | struct = self_outer_product(gra) 89 | 90 | # Gaussian smoothing on tensor components 91 | struct = smooth_matrix_image(struct, RHO=RHO, vres=norm_vres) 92 | 93 | print(' Running eigen decomposition...') 94 | struct = struct.reshape([np.prod(dims), 3, 3]) 95 | struct = struct[idx_msk_flat, :, :] 96 | eigvals, eigvecs = eigh(struct) 97 | struct = None 98 | 99 | print(' Constructing diffusion tensors...') 100 | mu = compute_diffusion_weights(eigvals, mode=MODE, LAMBDA=LAMBDA, 101 | ALPHA=ALPHA, M=M) 102 | difft = construct_diffusion_tensors(eigvecs, weights=mu) 103 | eigvecs, eigvals, mu = None, None, None 104 | 105 | # Reshape processed voxels (not masked) back to image space 106 | temp = np.zeros([np.prod(dims), 3, 3]) 107 | temp[idx_msk_flat, :, :] = difft 108 | difft = temp.reshape(dims + (3, 3)) 109 | temp = None 110 | 111 | # Weickert, 1998, eq. 1.1 (Fick's law). 112 | negative_flux = dot_product_matrix_vector(difft, gra) 113 | difft, gra = None, None 114 | # Weickert, 1998, eq. 
1.2 (continuity equation) 115 | diffusion_difference = divergence(negative_flux) 116 | negative_flux = None 117 | 118 | # Update image (diffuse image using the difference) 119 | ima += GAMMA*diffusion_difference 120 | diffusion_difference = None 121 | 122 | # Convenient exports for intermediate outputs 123 | if (t+1) % SAVE_EVERY == 0 and (t+1) != NR_ITER: 124 | QC_export(ima, basename, params, nii) 125 | duration = time() - start 126 | mins, secs = int(duration / 60), int(duration % 60) 127 | print(' Image saved (took {} min {} sec)'.format(mins, secs)) 128 | 129 | if cfg.downsampling > 1: # TODO: work in progress 130 | print(' Final upsampling...') 131 | residual = ima - orig 132 | residual = zoom(residual, cfg.downsampling) 133 | ima = (nii.get_fdata()).astype('float32') + residual 134 | else: 135 | pass 136 | 137 | 138 | print('Saving final image...') 139 | QC_export(ima, basename, params, nii) 140 | 141 | duration = time() - start 142 | mins, secs = int(duration / 60), int(duration % 60) 143 | print(' Finished (Took: {} min {} sec).'.format(mins, secs)) 144 | -------------------------------------------------------------------------------- /segmentator/filters_ui.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Entry point. 3 | 4 | Use config_filters.py to hold parameters. 5 | """ 6 | 7 | from __future__ import print_function 8 | import argparse 9 | import segmentator.config_filters as cfg 10 | from segmentator import __version__ 11 | 12 | 13 | def main(): 14 | """Command line call argument parsing.""" 15 | # Instantiate argument parser object: 16 | parser = argparse.ArgumentParser() 17 | 18 | # Add arguments to namespace: 19 | parser.add_argument( 20 | 'filename', metavar='path', 21 | help="Path to input. A nifti file with image data." 
22 | ) 23 | parser.add_argument( 24 | "--smoothing", metavar=str(cfg.smoothing), required=False, 25 | default=cfg.smoothing, 26 | help="Variants are CURED and STEDI. Use CURED for removing tubular,\ 27 | honeycomb-like structures as well as smoothing isotropic areas.\ 28 | STEDI is the more conservative version that retains more edges." 29 | ) 30 | # parser.add_argument( 31 | # "--edge_thr", metavar=str(cfg.edge_thr), required=False, 32 | # type=float, default=cfg.edge_thr, 33 | # help="Lambda, edge threshold, lower values preserves more edges. Not\ 34 | # used in CURED and STEDI." 35 | # ) 36 | parser.add_argument( 37 | "--noise_scale", metavar=str(cfg.noise_scale), required=False, 38 | type=float, default=cfg.noise_scale, 39 | help="Sigma, determines the spatial scale of the noise that will be \ 40 | corrected. Recommended lower bound is 0.5. Adjusted for each axis \ 41 | to account for non-isotropic voxels. For example, if the selected \ 42 | value is 1 for an image with [0.7, 0.7, 1.4] mm voxels, sigma is \ 43 | adjusted to be [1, 1, 0.5] in the corresponding axes." 44 | ) 45 | parser.add_argument( 46 | "--feature_scale", metavar=str(cfg.feature_scale), required=False, 47 | type=float, default=cfg.noise_scale, 48 | help="Rho, determines the spatial scale of the features that will be \ 49 | enhanced. Recommended lower bound is 0.5. Adjusted for each axis \ 50 | to account for non-isotropic voxels same way as in --noise_scale." 51 | ) 52 | parser.add_argument( 53 | "--gamma", metavar=str(cfg.gamma), required=False, 54 | type=float, default=cfg.gamma, 55 | help="Strength of the updates in every iteration. Recommended range is\ 56 | 0.5 to 2." 57 | ) 58 | parser.add_argument( 59 | "--nr_iterations", metavar=str(cfg.nr_iterations), required=False, 60 | type=int, default=cfg.nr_iterations, 61 | help="Number of maximum iterations. More iterations will produce \ 62 | smoother images." 
63 | ) 64 | parser.add_argument( 65 | "--save_every", metavar=str(cfg.save_every), required=False, 66 | type=int, default=cfg.save_every, 67 | help="Save every Nth iterations. Useful to track the effect of \ 68 | smoothing as it evolves." 69 | ) 70 | parser.add_argument( 71 | "--downsampling", metavar=str(cfg.downsampling), required=False, 72 | type=int, default=cfg.downsampling, 73 | help="(!WIP!) Downsampling factor, use integers > 1. E.g. factor of 2 \ 74 | reduces the amount of voxels 8 times." 75 | ) 76 | parser.add_argument( 77 | "--no_nonpositive_mask", action='store_true', 78 | help="(!WIP!) Do not mask out non-positive values." 79 | ) 80 | 81 | # set cfg file variables to be accessed from other scripts 82 | args = parser.parse_args() 83 | cfg.filename = args.filename 84 | cfg.smoothing = args.smoothing 85 | # cfg.edge_thr = args.edge_thr # lambda 86 | cfg.noise_scale = args.noise_scale # sigma 87 | cfg.feature_scale = args.feature_scale # rho 88 | cfg.gamma = args.gamma 89 | cfg.nr_iterations = args.nr_iterations 90 | cfg.save_every = args.save_every 91 | cfg.downsampling = args.downsampling 92 | cfg.no_nonpositive_mask = args.no_nonpositive_mask 93 | 94 | welcome_str = 'Segmentator {}'.format(__version__) 95 | welcome_decor = '=' * len(welcome_str) 96 | print('{}\n{}\n{}'.format(welcome_decor, welcome_str, welcome_decor)) 97 | print('Filters initiated...') 98 | 99 | import segmentator.filter 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /segmentator/filters_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Common functions used in filters.""" 3 | 4 | from __future__ import division 5 | import numpy as np 6 | from scipy.ndimage.filters import gaussian_filter 7 | 8 | 9 | def self_outer_product(vector_field): 10 | """Vectorized computation of outer product. 
11 | 12 | Parameters 13 | ---------- 14 | vector_field: np.ndarray, shape(..., 3) 15 | 16 | Returns 17 | ------- 18 | outer: np.ndarray, shape(..., 3, 3) 19 | """ 20 | dims = vector_field.shape 21 | outer = np.repeat(vector_field, dims[-1], axis=-1) 22 | outer *= outer[..., [0, 3, 6, 1, 4, 7, 2, 5, 8]] 23 | outer = outer.reshape(dims[:-1] + (dims[-1], dims[-1])) 24 | return outer 25 | 26 | 27 | def dot_product_matrix_vector(matrix_field, vector_field): 28 | """Vectorized computation of dot product.""" 29 | dims = vector_field.shape 30 | dotp = np.repeat(vector_field, dims[-1], axis=-1) 31 | dotp = dotp.reshape(dims[:-1] + (dims[-1], dims[-1])) 32 | idx_dims = tuple(range(dotp.ndim)) 33 | dotp = dotp.transpose(idx_dims[:-2] + (idx_dims[-1], idx_dims[-2])) 34 | np.multiply(matrix_field, dotp, out=dotp) 35 | dotp = np.sum(dotp, axis=-1) 36 | return dotp 37 | 38 | 39 | def divergence(vector_field): 40 | """Vectorized computation of divergence, also called Laplacian.""" 41 | dims = vector_field.shape 42 | result = np.zeros(dims[:-1]) 43 | for i in range(dims[-1]): 44 | result += np.gradient(vector_field[..., i], axis=i) 45 | return result 46 | 47 | 48 | def compute_diffusion_weights(eigvals, mode, LAMBDA=0.001, ALPHA=0.001, M=4): 49 | """Vectorized computation diffusion weights. 50 | 51 | References 52 | ---------- 53 | - Weickert, J. (1998). Anisotropic diffusion in image processing. 54 | Image Rochester NY, 256(3), 170. 55 | - Mirebeau, J.-M., Fehrenbach, J., Risser, L., & Tobji, S. (2015). 56 | Anisotropic Diffusion in ITK, 1-9. 57 | """ 58 | idx_pos_e2 = eigvals[:, 1] > 0 # positive indices of second eigen value 59 | c = (1. - ALPHA) # related to matrix condition 60 | 61 | if mode in ['EED', 'cEED', 'iEED']: # TODO: I might remove these. 62 | 63 | if mode == 'EED': # edge enhancing diffusion 64 | mu = np.ones(eigvals.shape) 65 | term1 = LAMBDA 66 | term2 = eigvals[idx_pos_e2, 1:] - eigvals[idx_pos_e2, 0, None] 67 | mu[idx_pos_e2, 1:] = 1. 
- c * np.exp(-(term1/term2)**M) 68 | # weights for the non-positive eigen values 69 | mu[~idx_pos_e2, 2] = ALPHA # surely surfels 70 | 71 | elif mode == 'cEED': # FIXME: Not working at all, for now... 72 | term1 = LAMBDA 73 | term2 = eigvals 74 | mu = 1. - c * np.exp(-(term1/term2)**M) 75 | 76 | elif mode in ['CED', 'cCED']: 77 | 78 | if mode == 'CED': # coherence enhancing diffusion (FIXME: not tested) 79 | mu = np.ones(eigvals.shape) * ALPHA 80 | term1 = LAMBDA 81 | term2 = eigvals[:, 2, None] - eigvals[:, :-1] 82 | mu[:, :-1] = ALPHA + c * np.exp(-(term1/term2)**M) 83 | 84 | elif mode == 'cCED': # conservative coherence enhancing diffusion 85 | mu = np.ones(eigvals.shape) * ALPHA 86 | term1 = LAMBDA + eigvals[:, 0:2] 87 | term2 = eigvals[:, 2, None] - eigvals[:, 0:2] 88 | mu[:, 0:2] = ALPHA + c * np.exp(-(term1/term2)**M) 89 | 90 | elif mode == 'CURED': # NOTE: Somewhat experimental 91 | import compoda.core as coda 92 | mu = np.ones(eigvals.shape) 93 | mu[idx_pos_e2, :] = 1. - coda.closure(eigvals[idx_pos_e2, :]) 94 | 95 | elif mode == 'STEDI': # NOTE: Somewhat more experimental 96 | import compoda.core as coda 97 | mu = np.ones(eigvals.shape) 98 | eigs = eigvals[idx_pos_e2, :] 99 | term1 = coda.closure(eigs) 100 | term2 = np.abs((np.max(term1, axis=-1) - np.min(term1, axis=-1)) - 0.5) 101 | term2 += 0.5 102 | mu[idx_pos_e2, :] = np.abs(term2[:, None] - term1) 103 | 104 | else: 105 | mu = np.ones(eigvals.shape) 106 | print(' Invalid smoothing mesthod. 
Weights are all set to ones.') 107 | 108 | return mu 109 | 110 | 111 | def construct_diffusion_tensors(eigvecs, weights): 112 | """Vectorized consruction of diffusion tensors.""" 113 | dims = eigvecs.shape 114 | D = np.zeros(dims[:-2] + (dims[-1], dims[-1])) 115 | for i in range(dims[-1]): # weight vectors 116 | D += weights[:, i, None, None] * self_outer_product(eigvecs[..., i]) 117 | return D 118 | 119 | 120 | def smooth_matrix_image(matrix_image, RHO=0, vres=None): 121 | """Gaussian smoothing applied to matrix image.""" 122 | if vres is None: 123 | vres = [1., 1., 1.] 124 | if RHO == 0: 125 | return matrix_image 126 | else: 127 | dims = matrix_image.shape 128 | for x in range(dims[-2]): 129 | for y in range(dims[-1]): 130 | gaussian_filter(matrix_image[..., x, y], 131 | sigma=[RHO/vres[0], RHO/vres[1], RHO/vres[2]], 132 | mode='constant', cval=0.0, 133 | output=matrix_image[..., x, y]) 134 | return matrix_image 135 | -------------------------------------------------------------------------------- /segmentator/future/readme.md: -------------------------------------------------------------------------------- 1 | This is a folder for work in progress components. 2 | -------------------------------------------------------------------------------- /segmentator/future/wip_arcweld_mp2rage.py: -------------------------------------------------------------------------------- 1 | """Full automation experiments for T1w-like data (MPRAGE & MP2RAGE). 2 | 3 | TODO: 4 | - Put anisotropic diffusion based smoothing to segmentator utilities. 5 | - MP2RAGE, 2 type of gray matter giving issues, barycentric weights might 6 | be useful to deal with this issue. 
7 | 8 | """ 9 | 10 | import os 11 | import peakutils 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from nibabel import load, Nifti1Image, save 15 | from scipy.ndimage.filters import gaussian_filter1d 16 | from segmentator.utils import compute_gradient_magnitude, aniso_diff_3D 17 | 18 | # load 19 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/arcweld/mp2rage_S001_restore.nii.gz') 20 | ima = nii.get_fdata() 21 | basename = nii.get_filename().split(os.extsep, 1)[0] 22 | 23 | # non-zero mask 24 | msk = (ima != 0) # TODO: Parametrize 25 | # 26 | ima[msk] = ima[msk] + np.min(ima) 27 | 28 | # aniso. diff. filter 29 | ima = aniso_diff_3D(ima, niter=2, kappa=500, gamma=0.1, option=1) 30 | 31 | # calculate gradient magnitude 32 | gra = compute_gradient_magnitude(ima, method='3D_sobel') 33 | 34 | # save for debugging 35 | # out = Nifti1Image(gra.reshape(nii.shape), affine=nii.affine) 36 | # save(out, basename + '_gra' + '.nii.gz') 37 | 38 | ima_max = np.percentile(ima, 100) 39 | 40 | # reshape for histograms 41 | ima, gra, msk = ima.flatten(), gra.flatten(), msk.flatten() 42 | gra_thr = np.percentile(gra[msk], 20) 43 | 44 | # select low gradient magnitude regime (lr) 45 | msk_lr = (gra < gra_thr) & (msk) 46 | n, bins, _ = plt.hist(ima[msk_lr], 200, range=(0, ima_max)) 47 | 48 | # smooth histogram (good for peak detection) 49 | n = gaussian_filter1d(n, 1) 50 | 51 | # detect 'pure tissue' peaks (TODO: Checks for always finding 3 peaks) 52 | peaks = peakutils.indexes(n, thres=0.01/max(n), min_dist=40) 53 | # peaks = peaks[0:-1] 54 | tissues = [] 55 | for p in peaks: 56 | tissues.append(bins[p]) 57 | tissues = np.array(tissues) 58 | print peaks 59 | print tissues 60 | 61 | # insert zero-max arc 62 | zmax_max = ima_max 63 | zmax_center = (0 + zmax_max) / 2. 
64 | zmax_radius = zmax_max - zmax_center 65 | tissues = np.append(tissues, zmax_center) 66 | 67 | # create smooth maps (distance to pure tissue) 68 | voxels = np.vstack([ima, gra]) 69 | soft = [] # to hold soft tissue membership maps 70 | for i, t in enumerate(tissues): 71 | tissue = np.array([t, 0]) 72 | # euclidean distance 73 | edist = np.sqrt(np.sum((voxels - tissue[:, None])**2., axis=0)) 74 | soft.append(edist) 75 | # save intermediate maps 76 | # out = Nifti1Image(edist.reshape(nii.shape), affine=nii.affine) 77 | # save(out, basename + '_t' + str(i) + '.nii.gz') 78 | soft = np.array(soft) 79 | 80 | # interface translation (shift zero circle to another radius) 81 | soft[-1, :] = soft[-1, :] - zmax_radius 82 | # zmax_neg = soft[-1, :] < 0 # voxels fall inside zero-max arc 83 | 84 | # weight zero-max arc to not coincide with pure class regions 85 | # zmax_weight = (gra[zmax_neg] / zmax_radius)**-1 86 | zmax_weight = (gra / zmax_radius)**-1 87 | # soft[-1, :][zmax_neg] = zmax_weight*np.abs(soft[-1, :][zmax_neg]) 88 | soft[-1, :] = zmax_weight*np.abs(soft[-1, :]) 89 | # save for debugging 90 | # out = Nifti1Image(soft[-1, :].reshape(nii.shape), affine=nii.affine) 91 | # save(out, basename + '_zmaxarc' + '.nii.gz') 92 | 93 | # arbitrary weighting (TODO: Can be turned into config file of some sort) 94 | # save these values for MP2RAGE UNI 95 | # soft[0, :] = soft[0, :] * 1 # csf 96 | # soft[-1, :] = soft[-1, :] * 0.5 # zero-max arc 97 | 98 | # hard tissue membership maps 99 | hard = np.argmin(soft, axis=0) 100 | 101 | # append masked out areas 102 | hard = hard + 1 103 | hard[~msk] = 0 104 | 105 | # save hard classification 106 | out = Nifti1Image(hard.reshape(nii.shape), affine=nii.affine) 107 | save(out, basename + '_arcweld' + '.nii.gz') 108 | 109 | # save segmentator polish mask (not GM and WM) 110 | labels = np.unique(hard)[[0, 1, -1]] 111 | polish = np.ones(hard.shape) 112 | polish[hard == labels[0]] = 0 113 | polish[hard == labels[1]] = 0 114 | polish[hard 
== labels[2]] = 0 115 | out = Nifti1Image(polish.reshape(nii.shape), affine=nii.affine) 116 | save(out, basename + '_arcweld_mask' + '.nii.gz') 117 | print 'Finished.' 118 | 119 | # plt.show() 120 | -------------------------------------------------------------------------------- /segmentator/future/wip_arcweld_mprage.py: -------------------------------------------------------------------------------- 1 | """Full automation experiments for T1w-like data (MPRAGE & MP2RAGE). 2 | 3 | TODO: 4 | - Put anisotropic diffusion based smoothing to segmentator utilities. 5 | - MP2RAGE, 2 type of gray matter giving issues, barycentric weights might 6 | be useful to deal with this issue. 7 | 8 | """ 9 | 10 | import os 11 | import peakutils 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from nibabel import load, Nifti1Image, save 15 | from scipy.ndimage.filters import gaussian_filter1d 16 | from segmentator.utils import compute_gradient_magnitude, aniso_diff_3D 17 | 18 | # load 19 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/arcweld/mprage_S02_restore.nii.gz') 20 | ima = nii.get_fdata() 21 | basename = nii.get_filename().split(os.extsep, 1)[0] 22 | 23 | # non-zero mask 24 | msk = ima > 0 # TODO: Parametrize 25 | 26 | # aniso. diff. 
filter 27 | ima = aniso_diff_3D(ima, niter=2, kappa=50, gamma=0.1, option=1) 28 | 29 | # calculate gradient magnitude 30 | gra = compute_gradient_magnitude(ima, method='3D_sobel') 31 | 32 | # save for debugging 33 | # out = Nifti1Image(gra.reshape(nii.shape), affine=nii.affine) 34 | # save(out, basename + '_gra' + '.nii.gz') 35 | 36 | ima_max = np.percentile(ima, 99) 37 | 38 | # reshape for histograms 39 | ima, gra, msk = ima.flatten(), gra.flatten(), msk.flatten() 40 | gra_thr = np.percentile(gra[msk], 20) 41 | 42 | # select low gradient magnitude regime (lr) 43 | msk_lr = (gra < gra_thr) & (msk) 44 | n, bins, _ = plt.hist(ima[msk_lr], 200, range=(0, ima_max)) 45 | 46 | # smooth histogram (good for peak detection) 47 | n = gaussian_filter1d(n, 1) 48 | 49 | # detect 'pure tissue' peaks (TODO: Checks for always finding 3 peaks) 50 | peaks = peakutils.indexes(n, thres=0.01/max(n), min_dist=40) 51 | # peaks = peaks[0:-1] 52 | tissues = [] 53 | for p in peaks: 54 | tissues.append(bins[p]) 55 | tissues = np.array(tissues) 56 | print peaks 57 | print tissues 58 | 59 | # insert zero-max arc 60 | zmax_max = tissues[-1] + tissues[-1] - tissues[1] # first gm and wm 61 | zmax_center = (0 + zmax_max) / 2. 
62 | zmax_radius = zmax_max - zmax_center 63 | tissues = np.append(tissues, zmax_center) 64 | 65 | # create smooth maps (distance to pure tissue) 66 | voxels = np.vstack([ima, gra]) 67 | soft = [] # to hold soft tissue membership maps 68 | for i, t in enumerate(tissues): 69 | tissue = np.array([t, 0]) 70 | # euclidean distance 71 | edist = np.sqrt(np.sum((voxels - tissue[:, None])**2., axis=0)) 72 | soft.append(edist) 73 | # save intermediate maps 74 | # out = Nifti1Image(edist.reshape(nii.shape), affine=nii.affine) 75 | # save(out, basename + '_t' + str(i) + '.nii.gz') 76 | soft = np.array(soft) 77 | 78 | # interface translation (shift zero circle to another radius) 79 | soft[-1, :] = soft[-1, :] - zmax_radius 80 | # zmax_neg = soft[-1, :] < 0 # voxels fall inside zero-max arc 81 | 82 | # weight zero-max arc to not coincide with pure class regions 83 | # zmax_weight = (gra[zmax_neg] / zmax_radius)**-1 84 | zmax_weight = (gra / zmax_radius)**-1 85 | # soft[-1, :][zmax_neg] = zmax_weight*np.abs(soft[-1, :][zmax_neg]) 86 | soft[-1, :] = zmax_weight*np.abs(soft[-1, :]) 87 | # save for debugging 88 | # out = Nifti1Image(soft[-1, :].reshape(nii.shape), affine=nii.affine) 89 | # save(out, basename + '_zmaxarc' + '.nii.gz') 90 | 91 | # arbitrary weighting (TODO: Can be turned into config file of some sort) 92 | # save these values for MPRAGE T1w/PDw 93 | # soft[0, :] = soft[0, :] * 0.66 # csf 94 | # # soft[2, :] = soft[2, :] * 1.25 # wm 95 | # soft[3, :] = soft[3, :] * 0.5 # zero-max arcs 96 | 97 | # hard tissue membership maps 98 | hard = np.argmin(soft, axis=0) 99 | 100 | # append masked out areas 101 | hard = hard + 1 102 | hard[~msk] = 0 103 | 104 | # save intermediate maps 105 | out = Nifti1Image(hard.reshape(nii.shape), affine=nii.affine) 106 | save(out, basename + '_arcweld' + '.nii.gz') 107 | 108 | # save segmentator polish mask (not GM and WM) 109 | labels = np.unique(hard)[[0, 1, -1]] 110 | polish = np.ones(hard.shape) 111 | polish[hard == labels[0]] = 0 112 | 
polish[hard == labels[1]] = 0 113 | polish[hard == labels[2]] = 0 114 | out = Nifti1Image(polish.reshape(nii.shape), affine=nii.affine) 115 | save(out, basename + '_arcweld_mask' + '.nii.gz') 116 | print 'Finished.' 117 | 118 | # plt.show() 119 | -------------------------------------------------------------------------------- /segmentator/gui_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Functions covering the user interaction with the GUI.""" 3 | 4 | from __future__ import division, print_function 5 | import os 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import segmentator.config as cfg 9 | from segmentator.utils import map_2D_hist_to_ima 10 | from nibabel import save, Nifti1Image 11 | from scipy.ndimage.morphology import binary_erosion 12 | 13 | 14 | class responsiveObj: 15 | """Stuff to interact in the user interface.""" 16 | 17 | def __init__(self, **kwargs): 18 | """Initialize variables used acros functions here.""" 19 | if kwargs is not None: 20 | for key, value in kwargs.items(): 21 | setattr(self, key, value) 22 | self.basename = self.nii.get_filename().split(os.extsep, 1)[0] 23 | self.press = None 24 | self.ctrlHeld = False 25 | self.labelNr = 0 26 | self.imaSlcMskSwitch, self.volHistHighlightSwitch = 0, 0 27 | self.TranspVal = 0.5 28 | self.nrExports = 0 29 | self.borderSwitch = 0 30 | self.imaSlc = self.orig[:, :, self.sliceNr] # selected slice 31 | self.cycleCount = 0 32 | self.cycRotHistory = [[0, 0], [0, 0], [0, 0]] 33 | self.highlights = [[], []] # to hold image to histogram circles 34 | 35 | def remapMsks(self, remap_slice=True): 36 | """Update volume histogram to image mapping. 37 | 38 | Parameters 39 | ---------- 40 | remap_slice : bool 41 | Do histogram to image mapping. Used to map displayed slice mask. 
42 | 43 | """ 44 | if self.segmType == 'main': 45 | self.volHistMask = self.sectorObj.binaryMask() 46 | self.volHistMask = self.lassoArr(self.volHistMask, self.idxLasso) 47 | self.volHistMaskH.set_data(self.volHistMask) 48 | elif self.segmType == 'ncut': 49 | self.labelContours() 50 | self.volHistMaskH.set_data(self.volHistMask) 51 | self.volHistMaskH.set_extent((0, self.nrBins, self.nrBins, 0)) 52 | # histogram to image mapping 53 | if remap_slice: 54 | temp_slice = self.invHistVolume[:, :, self.sliceNr] 55 | image_slice_shape = self.invHistVolume[:, :, self.sliceNr].shape 56 | if cfg.discard_zeros: 57 | zmask = temp_slice != 0 58 | image_slice_mask = map_2D_hist_to_ima(temp_slice[zmask], 59 | self.volHistMask) 60 | # reshape to image slice shape 61 | self.imaSlcMsk = np.zeros(image_slice_shape) 62 | self.imaSlcMsk[zmask] = image_slice_mask 63 | else: 64 | image_slice_mask = map_2D_hist_to_ima(temp_slice.flatten(), 65 | self.volHistMask) 66 | # reshape to image slice shape 67 | self.imaSlcMsk = image_slice_mask.reshape(image_slice_shape) 68 | 69 | # for optional border visualization 70 | if self.borderSwitch == 1: 71 | self.imaSlcMsk = self.calcImaMaskBrd() 72 | 73 | def updatePanels(self, update_slice=True, update_rotation=False, 74 | update_extent=False): 75 | """Update histogram and image panels.""" 76 | if update_rotation: 77 | self.checkRotation() 78 | if update_extent: 79 | self.updateImaExtent() 80 | if update_slice: 81 | self.imaSlcH.set_data(self.imaSlc) 82 | self.imaSlcMskH.set_data(self.imaSlcMsk) 83 | self.figure.canvas.draw() 84 | 85 | def connect(self): 86 | """Make the object responsive.""" 87 | self.cidpress = self.figure.canvas.mpl_connect( 88 | 'button_press_event', self.on_press) 89 | self.cidrelease = self.figure.canvas.mpl_connect( 90 | 'button_release_event', self.on_release) 91 | self.cidmotion = self.figure.canvas.mpl_connect( 92 | 'motion_notify_event', self.on_motion) 93 | self.cidkeypress = self.figure.canvas.mpl_connect( 94 | 
'key_press_event', self.on_key_press) 95 | self.cidkeyrelease = self.figure.canvas.mpl_connect( 96 | 'key_release_event', self.on_key_release) 97 | 98 | def on_key_press(self, event): 99 | """Determine what happens when a keyboard button is pressed.""" 100 | if event.key == 'control': 101 | self.ctrlHeld = True 102 | elif event.key == '1': 103 | self.imaSlcMskIncr(-0.1) 104 | elif event.key == '2': 105 | self.imaSlcMskTransSwitch() 106 | elif event.key == '3': 107 | self.imaSlcMskIncr(0.1) 108 | elif event.key == '4': 109 | self.volHistHighlightTransSwitch() 110 | elif event.key == '5': 111 | self.borderSwitch = (self.borderSwitch + 1) % 2 112 | self.remapMsks() 113 | self.updatePanels(update_slice=False, update_rotation=True, 114 | update_extent=True) 115 | 116 | if self.segmType == 'main': 117 | if event.key == 'up': 118 | self.sectorObj.scale_r(1.05) 119 | self.remapMsks() 120 | self.updatePanels(update_slice=False, update_rotation=True, 121 | update_extent=False) 122 | elif event.key == 'down': 123 | self.sectorObj.scale_r(0.95) 124 | self.remapMsks() 125 | self.updatePanels(update_slice=False, update_rotation=True, 126 | update_extent=False) 127 | elif event.key == 'right': 128 | self.sectorObj.rotate(-10.0) 129 | self.remapMsks() 130 | self.updatePanels(update_slice=False, update_rotation=True, 131 | update_extent=False) 132 | elif event.key == 'left': 133 | self.sectorObj.rotate(10.0) 134 | self.remapMsks() 135 | self.updatePanels(update_slice=True, update_rotation=True, 136 | update_extent=False) 137 | else: 138 | return 139 | 140 | def on_key_release(self, event): 141 | """Determine what happens if key is released.""" 142 | if event.key == 'control': 143 | self.ctrlHeld = False 144 | 145 | def findVoxInHist(self, event): 146 | """Find voxel's location in histogram.""" 147 | self.press = event.xdata, event.ydata 148 | pixel_x = int(np.floor(event.xdata)) 149 | pixel_y = int(np.floor(event.ydata)) 150 | aoi = self.invHistVolume[:, :, self.sliceNr] # array of 
interest 151 | # Check rotation 152 | cyc_rot = self.cycRotHistory[self.cycleCount][1] 153 | if cyc_rot == 1: # 90 154 | aoi = np.rot90(aoi, axes=(0, 1)) 155 | elif cyc_rot == 2: # 180 156 | aoi = aoi[::-1, ::-1] 157 | elif cyc_rot == 3: # 270 158 | aoi = np.rot90(aoi, axes=(1, 0)) 159 | # Switch x and y voxel to get linear index since not Cartesian!!! 160 | pixelLin = aoi[pixel_y, pixel_x] 161 | # ind2sub 162 | xpix = (pixelLin / self.nrBins) 163 | ypix = (pixelLin % self.nrBins) 164 | # Switch x and y for circle centre since back to Cartesian. 165 | circle_colors = [np.array([8, 48, 107, 255])/255, 166 | np.array([33, 113, 181, 255])/255] 167 | self.highlights[0].append(plt.Circle((ypix, xpix), radius=1, 168 | edgecolor=None, color=circle_colors[0])) 169 | self.highlights[1].append(plt.Circle((ypix, xpix), radius=5, 170 | edgecolor=None, color=circle_colors[1])) 171 | self.axes.add_artist(self.highlights[0][-1]) # small circle 172 | self.axes.add_artist(self.highlights[1][-1]) # large circle 173 | self.figure.canvas.draw() 174 | 175 | def on_press(self, event): 176 | """Determine what happens if mouse button is clicked.""" 177 | if self.segmType == 'main': 178 | if event.button == 1: # left button 179 | if event.inaxes == self.axes: # cursor in left plot (hist) 180 | if self.ctrlHeld is False: # ctrl no 181 | contains = self.contains(event) 182 | if not contains: 183 | print('cursor outside circle mask') 184 | if not contains: 185 | return 186 | # get sector centre x and y positions 187 | x0 = self.sectorObj.cx 188 | y0 = self.sectorObj.cy 189 | # also get cursor x and y position and safe to press 190 | self.press = x0, y0, event.xdata, event.ydata 191 | elif event.inaxes == self.axes2: # cursor in right plot (brow) 192 | self.findVoxInHist(event) 193 | else: 194 | return 195 | elif event.button == 2: # scroll button 196 | if event.inaxes != self.axes: # outside axes 197 | return 198 | # increase/decrease radius of the sector mask 199 | if self.ctrlHeld is False: 
# ctrl no 200 | self.sectorObj.scale_r(1.05) 201 | self.remapMsks() 202 | self.updatePanels(update_slice=False, update_rotation=True, 203 | update_extent=False) 204 | elif self.ctrlHeld is True: # ctrl yes 205 | self.sectorObj.rotate(10.0) 206 | self.remapMsks() 207 | self.updatePanels(update_slice=False, update_rotation=True, 208 | update_extent=False) 209 | else: 210 | return 211 | elif event.button == 3: # right button 212 | if event.inaxes != self.axes: 213 | return 214 | # rotate the sector mask 215 | if self.ctrlHeld is False: # ctrl no 216 | self.sectorObj.scale_r(0.95) 217 | self.remapMsks() 218 | self.updatePanels(update_slice=False, update_rotation=True, 219 | update_extent=False) 220 | elif self.ctrlHeld is True: # ctrl yes 221 | self.sectorObj.rotate(-10.0) 222 | self.remapMsks() 223 | self.updatePanels(update_slice=False, update_rotation=True, 224 | update_extent=False) 225 | else: 226 | return 227 | elif self.segmType == 'ncut': 228 | if event.button == 1: # left mouse button 229 | if event.inaxes == self.axes: # cursor in left plot (hist) 230 | xbin = int(np.floor(event.xdata)) 231 | ybin = int(np.floor(event.ydata)) 232 | val = self.volHistMask[ybin][xbin] 233 | # increment counterField for values in clicked subfield, at 234 | # the first click the entire field constitutes the subfield 235 | counter = int(self.counterField[ybin][xbin]) 236 | if counter+1 >= self.ima_ncut_labels.shape[2]: 237 | print("already at maximum ncut dimension") 238 | return 239 | self.counterField[( 240 | self.ima_ncut_labels[:, :, counter] == 241 | self.ima_ncut_labels[[ybin], [xbin], counter])] += 1 242 | print("counter:" + str(counter+1)) 243 | # define arrays with old and new labels for later indexing 244 | oLabels = self.ima_ncut_labels[:, :, counter] 245 | nLabels = self.ima_ncut_labels[:, :, counter+1] 246 | # replace old values with new values (in clicked subfield) 247 | self.volHistMask[oLabels == val] = np.copy( 248 | nLabels[oLabels == val]) 249 | self.remapMsks() 
250 | self.updatePanels(update_slice=False, update_rotation=True, 251 | update_extent=False) 252 | 253 | elif event.inaxes == self.axes2: # cursor in right plot (brow) 254 | self.findVoxInHist(event) 255 | else: 256 | return 257 | elif event.button == 3: # right mouse button 258 | if event.inaxes == self.axes: # cursor in left plot (hist) 259 | xbin = int(np.floor(event.xdata)) 260 | ybin = int(np.floor(event.ydata)) 261 | val = self.volHistMask[ybin][xbin] 262 | # fetch the slider value to get label nr 263 | self.volHistMask[self.volHistMask == val] = \ 264 | np.copy(self.labelNr) 265 | self.remapMsks() 266 | self.updatePanels(update_slice=False, update_rotation=True, 267 | update_extent=False) 268 | 269 | def on_motion(self, event): 270 | """Determine what happens if mouse button moves.""" 271 | if self.segmType == 'main': 272 | # ... button is pressed 273 | if self.press is None: 274 | return 275 | # ... cursor is in left plot 276 | if event.inaxes != self.axes: 277 | return 278 | # get former sector centre x and y positions, 279 | # cursor x and y positions 280 | x0, y0, xpress, ypress = self.press 281 | # calculate difference betw cursor pos on click 282 | # and new pos dur motion 283 | # switch x0 & y0 cause volHistMask not Cart 284 | dy = event.xdata - xpress 285 | dx = event.ydata - ypress 286 | # update x and y position of sector, 287 | # based on past motion of cursor 288 | self.sectorObj.set_x(x0 + dx) 289 | self.sectorObj.set_y(y0 + dy) 290 | # update masks 291 | self.remapMsks() 292 | self.updatePanels(update_slice=False, update_rotation=True, 293 | update_extent=False) 294 | else: 295 | return 296 | 297 | def on_release(self, event): 298 | """Determine what happens if mouse button is released.""" 299 | self.press = None 300 | # Remove highlight circle 301 | if self.highlights[1]: 302 | self.highlights[1][-1].set_visible(False) 303 | self.figure.canvas.draw() 304 | 305 | def disconnect(self): 306 | """Make the object unresponsive.""" 307 | 
self.figure.canvas.mpl_disconnect(self.cidpress) 308 | self.figure.canvas.mpl_disconnect(self.cidrelease) 309 | self.figure.canvas.mpl_disconnect(self.cidmotion) 310 | 311 | def updateColorBar(self, val): 312 | """Update slider for scaling log colorbar in 2D hist.""" 313 | histVMax = np.power(10, self.sHistC.val) 314 | plt.clim(vmax=histVMax) 315 | 316 | def updateSliceNr(self): 317 | """Update slice number and the selected slice.""" 318 | self.sliceNr = int(self.sSliceNr.val*self.orig.shape[2]) 319 | self.imaSlc = self.orig[:, :, self.sliceNr] 320 | 321 | def updateImaBrowser(self, val): 322 | """Update image browse.""" 323 | # scale slider value [0,1) to dimension index 324 | self.updateSliceNr() 325 | self.remapMsks() 326 | self.updatePanels(update_slice=True, update_rotation=True, 327 | update_extent=True) 328 | 329 | def updateImaExtent(self): 330 | """Update both image and mask extent in image browser.""" 331 | self.imaSlcH.set_extent((0, self.imaSlc.shape[1], 332 | self.imaSlc.shape[0], 0)) 333 | self.imaSlcMskH.set_extent((0, self.imaSlc.shape[1], 334 | self.imaSlc.shape[0], 0)) 335 | 336 | def cycleView(self, event): 337 | """Cycle through views.""" 338 | self.cycleCount = (self.cycleCount + 1) % 3 339 | # transpose data 340 | self.orig = np.transpose(self.orig, (2, 0, 1)) 341 | # transpose ima2volHistMap 342 | self.invHistVolume = np.transpose(self.invHistVolume, (2, 0, 1)) 343 | # updates 344 | self.updateSliceNr() 345 | self.remapMsks() 346 | self.updatePanels(update_slice=True, update_rotation=True, 347 | update_extent=True) 348 | 349 | def rotateIma90(self, axes=(0, 1)): 350 | """Rotate image slice 90 degrees.""" 351 | self.imaSlc = np.rot90(self.imaSlc, axes=axes) 352 | self.imaSlcMsk = np.rot90(self.imaSlcMsk, axes=axes) 353 | 354 | def changeRotation(self, event): 355 | """Change rotation of image after clicking the button.""" 356 | self.cycRotHistory[self.cycleCount][1] += 1 357 | self.cycRotHistory[self.cycleCount][1] %= 4 358 | 
self.rotateIma90() 359 | self.updatePanels(update_slice=True, update_rotation=False, 360 | update_extent=True) 361 | 362 | def checkRotation(self): 363 | """Check rotation update if changed.""" 364 | cyc_rot = self.cycRotHistory[self.cycleCount][1] 365 | if cyc_rot == 1: # 90 366 | self.rotateIma90(axes=(0, 1)) 367 | elif cyc_rot == 2: # 180 368 | self.imaSlc = self.imaSlc[::-1, ::-1] 369 | self.imaSlcMsk = self.imaSlcMsk[::-1, ::-1] 370 | elif cyc_rot == 3: # 270 371 | self.rotateIma90(axes=(1, 0)) 372 | 373 | def exportNifti(self, event): 374 | """Export labels in the image browser as a nifti file.""" 375 | print(" Exporting nifti file...") 376 | # put the permuted indices back to their original format 377 | cycBackPerm = (self.cycleCount, (self.cycleCount+1) % 3, 378 | (self.cycleCount+2) % 3) 379 | # assing unique integers (for ncut labels) 380 | out_volHistMask = np.copy(self.volHistMask) 381 | labels = np.unique(self.volHistMask) 382 | intLabels = [i for i in range(labels.size)] 383 | for label, newLabel in zip(labels, intLabels): 384 | out_volHistMask[out_volHistMask == label] = intLabels[newLabel] 385 | # get 3D brain mask 386 | volume_image = np.transpose(self.invHistVolume, cycBackPerm) 387 | if cfg.discard_zeros: 388 | zmask = volume_image != 0 389 | temp_labeled_image = map_2D_hist_to_ima(volume_image[zmask], 390 | out_volHistMask) 391 | out_nii = np.zeros(volume_image.shape) 392 | out_nii[zmask] = temp_labeled_image # put back flat labels 393 | else: 394 | out_nii = map_2D_hist_to_ima(volume_image.flatten(), 395 | out_volHistMask) 396 | out_nii = out_nii.reshape(volume_image.shape) 397 | # save mask image as nii 398 | new_image = Nifti1Image(out_nii, header=self.nii.header, 399 | affine=self.nii.affine) 400 | # get new flex file name and check for overwriting 401 | labels_out = '{}_labels_{}.nii.gz'.format( 402 | self.basename, self.nrExports) 403 | while os.path.isfile(labels_out): 404 | self.nrExports += 1 405 | labels_out = 
'{}_labels_{}.nii.gz'.format( 406 | self.basename, self.nrExports) 407 | save(new_image, labels_out) 408 | print(" Saved as: {}".format(labels_out)) 409 | 410 | def clearOverlays(self): 411 | """Clear overlaid items such as circle highlights.""" 412 | if self.highlights[0]: 413 | {h.remove() for h in self.highlights[0]} 414 | {h.remove() for h in self.highlights[1]} 415 | self.highlights[0] = [] 416 | 417 | def resetGlobal(self, event): 418 | """Reset stuff.""" 419 | # reset highlights 420 | self.clearOverlays() 421 | # reset color bar 422 | self.sHistC.reset() 423 | # reset transparency 424 | self.TranspVal = 0.5 425 | if self.segmType == 'main': 426 | if self.lassoSwitchCount == 1: # reset only lasso drawing 427 | self.idxLasso = np.zeros(self.nrBins*self.nrBins, dtype=bool) 428 | else: 429 | # reset theta sliders 430 | self.sThetaMin.reset() 431 | self.sThetaMax.reset() 432 | # reset values for mask 433 | self.sectorObj.set_x(cfg.init_centre[0]) 434 | self.sectorObj.set_y(cfg.init_centre[1]) 435 | self.sectorObj.set_r(cfg.init_radius) 436 | self.sectorObj.tmin, self.sectorObj.tmax = np.deg2rad( 437 | cfg.init_theta) 438 | 439 | elif self.segmType == 'ncut': 440 | self.sLabelNr.reset() 441 | # reset ncut labels 442 | self.ima_ncut_labels = np.copy(self.orig_ncut_labels) 443 | # reset values for volHistMask 444 | self.volHistMask = self.ima_ncut_labels[:, :, 0].reshape( 445 | (self.nrBins, self.nrBins)) 446 | # reset counter field 447 | self.counterField = np.zeros((self.nrBins, self.nrBins)) 448 | # reset political borders 449 | self.pltMap = np.zeros((self.nrBins, self.nrBins)) 450 | self.pltMapH.set_data(self.pltMap) 451 | self.updateSliceNr() 452 | self.remapMsks() 453 | self.updatePanels(update_slice=False, update_rotation=True, 454 | update_extent=False) 455 | 456 | def updateThetaMin(self, val): 457 | """Update theta (min) in volume histogram mask.""" 458 | if self.segmType == 'main': 459 | theta_val = self.sThetaMin.val # get theta value from slider 460 | 
self.sectorObj.theta_min(theta_val) 461 | self.remapMsks() 462 | self.updatePanels(update_slice=False, update_rotation=True, 463 | update_extent=False) 464 | else: 465 | return 466 | 467 | def updateThetaMax(self, val): 468 | """Update theta(max) in volume histogram mask.""" 469 | if self.segmType == 'main': 470 | theta_val = self.sThetaMax.val # get theta value from slider 471 | self.sectorObj.theta_max(theta_val) 472 | self.remapMsks() 473 | self.updatePanels(update_slice=False, update_rotation=True, 474 | update_extent=False) 475 | else: 476 | return 477 | 478 | def exportNyp(self, event): 479 | """Export histogram counts as a numpy array.""" 480 | print(" Exporting numpy file...") 481 | outFileName = '{}_identifier_pcMax{}_pcMin{}_sc{}'.format( 482 | self.basename, cfg.perc_max, cfg.perc_min, int(cfg.scale)) 483 | if self.segmType == 'ncut': 484 | outFileName = outFileName.replace('identifier', 'volHistLabels') 485 | out_data = self.volHistMask 486 | elif self.segmType == 'main': 487 | outFileName = outFileName.replace('identifier', 'volHist') 488 | out_data = self.counts 489 | outFileName = outFileName.replace('.', 'pt') 490 | np.save(outFileName, out_data) 491 | print(" Saved as: {}{}".format(outFileName, '.npy')) 492 | 493 | def updateLabels(self, val): 494 | """Update labels in volume histogram with slider.""" 495 | if self.segmType == 'ncut': 496 | self.labelNr = self.sLabelNr.val 497 | else: # NOTE: might be used in the future 498 | return 499 | 500 | def imaSlcMskIncr(self, incr): 501 | """Update transparency of image mask by increment.""" 502 | if (self.TranspVal + incr >= 0) & (self.TranspVal + incr <= 1): 503 | self.TranspVal += incr 504 | self.imaSlcMskH.set_alpha(self.TranspVal) 505 | self.figure.canvas.draw() 506 | 507 | def imaSlcMskTransSwitch(self): 508 | """Update transparency of image mask to toggle transparency.""" 509 | self.imaSlcMskSwitch = (self.imaSlcMskSwitch+1) % 2 510 | if self.imaSlcMskSwitch == 1: # set imaSlcMsk transp 511 | 
self.imaSlcMskH.set_alpha(0) 512 | else: # set imaSlcMsk opaque 513 | self.imaSlcMskH.set_alpha(self.TranspVal) 514 | self.figure.canvas.draw() 515 | 516 | def volHistHighlightTransSwitch(self): 517 | """Update transparency of highlights to toggle transparency.""" 518 | self.volHistHighlightSwitch = (self.volHistHighlightSwitch+1) % 2 519 | if self.volHistHighlightSwitch == 1 and self.highlights[0]: 520 | if self.highlights[0]: 521 | {h.set_visible(False) for h in self.highlights[0]} 522 | elif self.volHistHighlightSwitch == 0 and self.highlights[0]: 523 | {h.set_visible(True) for h in self.highlights[0]} 524 | self.figure.canvas.draw() 525 | 526 | def updateLabelsRadio(self, val): 527 | """Update labels with radio buttons.""" 528 | labelScale = self.lMax / 6. # nr of non-zero radio buttons 529 | self.labelNr = int(float(val) * labelScale) 530 | 531 | def labelContours(self): 532 | """Plot political borders used in ncut version.""" 533 | grad = np.gradient(self.volHistMask) 534 | self.pltMap = np.greater(np.sqrt(np.power(grad[0], 2) + 535 | np.power(grad[1], 2)), 0) 536 | self.pltMapH.set_data(self.pltMap) 537 | self.pltMapH.set_extent((0, self.nrBins, self.nrBins, 0)) 538 | 539 | def lassoArr(self, array, indices): 540 | """Update lasso volume histogram mask.""" 541 | lin = np.arange(array.size) 542 | newArray = array.flatten() 543 | newArray[lin[indices]] = True 544 | return newArray.reshape(array.shape) 545 | 546 | def calcImaMaskBrd(self): 547 | """Calculate borders of image mask slice.""" 548 | return self.imaSlcMsk - binary_erosion(self.imaSlcMsk) 549 | 550 | 551 | class sector_mask: 552 | """A pacman-like shape with useful parameters. 
553 | 554 | Disclaimer 555 | ---------- 556 | This script is adapted from a stackoverflow post by user ali_m: 557 | [1] http://stackoverflow.com/questions/18352973/mask-a-circular-sector-in-a-numpy-array 558 | 559 | """ 560 | 561 | def __init__(self, shape, centre, radius, angle_range): 562 | """Initialize variables used acros functions here.""" 563 | self.radius, self.shape = radius, shape 564 | self.x, self.y = np.ogrid[:shape[0], :shape[1]] 565 | self.cx, self.cy = centre 566 | self.tmin, self.tmax = np.deg2rad(angle_range) 567 | # ensure stop angle > start angle 568 | if self.tmax < self.tmin: 569 | self.tmax += 2 * np.pi 570 | # convert cartesian to polar coordinates 571 | self.r2 = (self.x - self.cx) * (self.x - self.cx) + ( 572 | self.y-self.cy) * (self.y - self.cy) 573 | self.theta = np.arctan2(self.x - self.cx, self.y - self.cy) - self.tmin 574 | # wrap angles between 0 and 2*pi 575 | self.theta %= 2 * np.pi 576 | 577 | def set_polCrd(self): 578 | """Convert cartesian to polar coordinates.""" 579 | self.r2 = (self.x-self.cx)*(self.x-self.cx) + ( 580 | self.y-self.cy)*(self.y-self.cy) 581 | self.theta = np.arctan2(self.x-self.cx, self.y-self.cy) - self.tmin 582 | # wrap angles between 0 and 2*pi 583 | self.theta %= (2*np.pi) 584 | 585 | def set_x(self, x): 586 | """Set x axis value.""" 587 | self.cx = x 588 | self.set_polCrd() # update polar coordinates 589 | 590 | def set_y(self, y): 591 | """Set y axis value.""" 592 | self.cy = y 593 | self.set_polCrd() # update polar coordinates 594 | 595 | def set_r(self, radius): 596 | """Set radius of the circle.""" 597 | self.radius = radius 598 | 599 | def scale_r(self, scale): 600 | """Scale (multiply) the radius.""" 601 | self.radius = self.radius * scale 602 | 603 | def rotate(self, degree): 604 | """Rotate shape.""" 605 | rad = np.deg2rad(degree) 606 | self.tmin += rad 607 | self.tmax += rad 608 | self.set_polCrd() # update polar coordinates 609 | 610 | def theta_min(self, degree): 611 | """Angle to determine 
one the cut out piece in circular mask.""" 612 | rad = np.deg2rad(degree) 613 | self.tmin = rad 614 | # ensure stop angle > start angle 615 | if self.tmax <= self.tmin: 616 | self.tmax += 2*np.pi 617 | # ensure stop angle- 2*np.pi NOT > start angle 618 | if self.tmax - 2*np.pi >= self.tmin: 619 | self.tmax -= 2*np.pi 620 | # update polar coordinates 621 | self.set_polCrd() 622 | 623 | def theta_max(self, degree): 624 | """Angle to determine one the cut out piece in circular mask.""" 625 | rad = np.deg2rad(degree) 626 | self.tmax = rad 627 | # ensure stop angle > start angle 628 | if self.tmax <= self.tmin: 629 | self.tmax += 2*np.pi 630 | # ensure stop angle- 2*np.pi NOT > start angle 631 | if self.tmax - 2*np.pi >= self.tmin: 632 | self.tmax -= 2*np.pi 633 | # update polar coordinates 634 | self.set_polCrd() 635 | 636 | def binaryMask(self): 637 | """Return a boolean mask for a circular sector.""" 638 | # circular mask 639 | self.circmask = self.r2 <= self.radius*self.radius 640 | # angular mask 641 | self.anglemask = self.theta <= (self.tmax-self.tmin) 642 | # return binary mask 643 | return self.circmask*self.anglemask 644 | 645 | def contains(self, event): 646 | """Check if a cursor pointer is inside the sector mask.""" 647 | xbin = np.floor(event.xdata) 648 | ybin = np.floor(event.ydata) 649 | Mask = self.binaryMask() 650 | # the next line doesn't follow pep 8 (otherwise it fails) 651 | if Mask[ybin][xbin] is True: # switch x and ybin, volHistMask not Cart 652 | return True 653 | else: 654 | return False 655 | 656 | def draw(self, ax, cmap='Reds', alpha=0.2, vmin=0.1, zorder=0, 657 | interpolation='nearest', origin='lower', extent=[0, 100, 0, 100]): 658 | """Draw sector mask.""" 659 | BinMask = self.binaryMask() 660 | FigObj = ax.imshow(BinMask, cmap=cmap, alpha=alpha, vmin=vmin, 661 | interpolation=interpolation, origin=origin, 662 | extent=extent, zorder=zorder) 663 | return (FigObj, BinMask) 664 | 
-------------------------------------------------------------------------------- /segmentator/hist2d_counts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Save 2D histogram image without displaying GUI.""" 3 | 4 | from __future__ import print_function 5 | import os 6 | import numpy as np 7 | import segmentator.config as cfg 8 | from segmentator.utils import truncate_range, scale_range, check_data 9 | from segmentator.utils import set_gradient_magnitude, prep_2D_hist 10 | from nibabel import load 11 | 12 | # load data 13 | nii = load(cfg.filename) 14 | basename = nii.get_filename().split(os.extsep, 1)[0] 15 | 16 | # data processing 17 | orig, _ = check_data(nii.get_fdata(), cfg.force_original_precision) 18 | orig, _, _ = truncate_range(orig, percMin=cfg.perc_min, percMax=cfg.perc_max) 19 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001) 20 | gra = set_gradient_magnitude(orig, cfg.gramag) 21 | 22 | # reshape ima (a bit more intuitive for voxel-wise operations) 23 | ima = np.ndarray.flatten(orig) 24 | gra = np.ndarray.flatten(gra) 25 | 26 | counts, _, _, _, _, _ = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros) 27 | outName = '{}_volHist_pcMax{}_pcMin{}_sc{}'.format( 28 | basename, cfg.perc_max, cfg.perc_min, int(cfg.scale)) 29 | outName = outName.replace('.', 'pt') 30 | np.save(outName, counts) 31 | print(' Image saved as:\n {}'.format(outName)) 32 | -------------------------------------------------------------------------------- /segmentator/ncut_prepare.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Normalized graph cuts for segmentator (experimental). 3 | 4 | TODO: Replacing the functionality using scikit-learn? 
5 | """ 6 | 7 | import os 8 | import numpy as np 9 | from matplotlib import animation 10 | from matplotlib import pyplot as plt 11 | from skimage.future import graph 12 | from skimage.segmentation import slic 13 | import segmentator.config as cfg 14 | 15 | 16 | def norm_grap_cut(image, max_edge=10000000, max_rec=4, compactness=2, 17 | nrSupPix=2000): 18 | """Normalized graph cut wrapper for 2D numpy arrays. 19 | 20 | Parameters 21 | ---------- 22 | image: np.ndarray (2D) 23 | Volume histogram. 24 | max_edge: float 25 | The maximum possible value of an edge in the RAG. This corresponds 26 | to an edge between identical regions. This is used to put self 27 | edges in the RAG. 28 | compactness: float 29 | From skimage slic_superpixels.py slic function: 30 | Balances color proximity and space proximity. Higher values give 31 | more weight to space proximity, making superpixel shapes more 32 | square/cubic. This parameter depends strongly on image contrast and 33 | on the shapes of objects in the image. 34 | nrSupPix: int, positive 35 | The (approximate) number of superpixels in the region adjacency 36 | graph. 37 | 38 | Returns 39 | ------- 40 | labels2, labels1: np.ndarray (2D) 41 | Segmented volume histogram mask image. Each label has a unique 42 | identifier. 
43 | 44 | """ 45 | # scale for uint8 conversion 46 | image = np.round(255 / image.max() * image) 47 | image = image.astype('uint8') 48 | 49 | # scikit implementation expects rgb format (shape: NxMx3) 50 | image = np.tile(image, (3, 1, 1)) 51 | image = np.transpose(image, (1, 2, 0)) 52 | 53 | labels1 = slic(image, compactness=compactness, n_segments=nrSupPix, 54 | sigma=2) 55 | # region adjacency graph (rag) 56 | g = graph.rag_mean_color(img, labels1, mode='similarity_and_proximity') 57 | labels2 = graph.cut_normalized(labels1, g, max_edge=max_edge, 58 | num_cuts=1000, max_rec=max_rec) 59 | return labels2, labels1 60 | 61 | 62 | path = cfg.filename 63 | basename = path.split(os.extsep, 1)[0] 64 | 65 | # load data 66 | img = np.load(path) 67 | # take logarithm of every count to make it similar to what is seen in gui 68 | img = np.log10(img+1.) 69 | 70 | # truncate very high values 71 | img_max = cfg.cbar_init 72 | img[img > img_max] = img_max 73 | 74 | max_recursion = cfg.max_rec 75 | ncut = np.zeros((img.shape[0], img.shape[1], max_recursion + 1)) 76 | for i in range(0, max_recursion + 1): 77 | msk, regions = norm_grap_cut(img, max_rec=i, 78 | nrSupPix=cfg.nr_sup_pix, 79 | compactness=cfg.compactness) 80 | ncut[:, :, i] = msk 81 | 82 | # plots 83 | if cfg.ncut_figs: 84 | fig = plt.figure() 85 | ax1 = fig.add_subplot(121) 86 | ax2 = fig.add_subplot(122) 87 | 88 | # ax1.imshow(img.T, origin="lower", cmap=plt.cm.inferno) 89 | ax1.imshow(regions.T, origin="lower", cmap=plt.cm.inferno) 90 | ax2.imshow(msk.T, origin="lower", cmap=plt.cm.nipy_spectral) 91 | 92 | ax1.set_title('Source') 93 | ax2.set_title('Ncut') 94 | 95 | plt.show() 96 | 97 | fig = plt.figure() 98 | unq = np.unique(msk) 99 | idx = -1 100 | 101 | im = plt.imshow(msk.T, origin="lower", cmap=plt.cm.flag, 102 | animated=True) 103 | 104 | def updatefig(*args): 105 | """Animate the plot.""" 106 | global unq, msk, idx, tmp 107 | idx += 1 108 | idx = idx % ncut.shape[2] 109 | tmp = np.copy(ncut[:, :, idx]) 110 | 
im.set_array(tmp.T) 111 | return im, 112 | 113 | ani = animation.FuncAnimation(fig, updatefig, interval=750, blit=True) 114 | plt.show() 115 | 116 | # save output 117 | outName = '{}_ncut_sp{}_c{}'.format(basename, cfg.nr_sup_pix, cfg.compactness) 118 | outName = outName.replace('.', 'pt') 119 | np.save(outName, ncut) 120 | print(" Saved as: {}{}".format(outName, '.npy')) 121 | -------------------------------------------------------------------------------- /segmentator/segmentator_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Processing input and plotting.""" 3 | 4 | from __future__ import division, print_function 5 | import numpy as np 6 | import segmentator.config as cfg 7 | import matplotlib 8 | matplotlib.use(cfg.matplotlib_backend) 9 | print("Matplotlib backend: {}".format(matplotlib.rcParams['backend'])) 10 | import matplotlib.pyplot as plt 11 | from matplotlib.colors import LogNorm 12 | from matplotlib.widgets import Slider, Button, LassoSelector 13 | from matplotlib import path 14 | from nibabel import load 15 | from segmentator.utils import map_ima_to_2D_hist, prep_2D_hist 16 | from segmentator.utils import truncate_range, scale_range, check_data 17 | from segmentator.utils import set_gradient_magnitude 18 | from segmentator.utils import export_gradient_magnitude_image 19 | from segmentator.gui_utils import sector_mask, responsiveObj 20 | from segmentator.config_gui import palette, axcolor, hovcolor 21 | 22 | # 23 | """Data Processing""" 24 | nii = load(cfg.filename) 25 | orig, dims = check_data(nii.get_fdata(), cfg.force_original_precision) 26 | # Save min and max truncation thresholds to be used in axis labels 27 | if np.isnan(cfg.valmin) or np.isnan(cfg.valmax): 28 | orig, pMin, pMax = truncate_range(orig, percMin=cfg.perc_min, 29 | percMax=cfg.perc_max) 30 | else: # TODO: integrate this into truncate range function 31 | orig[orig < cfg.valmin] = cfg.valmin 32 | orig[orig > 
cfg.valmax] = cfg.valmax 33 | pMin, pMax = cfg.valmin, cfg.valmax 34 | 35 | # Continue with scaling the original truncated image and recomputing gradient 36 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001) 37 | gra = set_gradient_magnitude(orig, cfg.gramag) 38 | if cfg.export_gramag: 39 | export_gradient_magnitude_image(gra, nii.get_filename(), cfg.gramag, 40 | nii.affine) 41 | # Reshape for voxel-wise operations 42 | ima = np.copy(orig.flatten()) 43 | gra = gra.flatten() 44 | 45 | # 46 | """Plots""" 47 | print("Preparing GUI...") 48 | # Plot 2D histogram 49 | fig = plt.figure(facecolor='0.775') 50 | ax = fig.add_subplot(121) 51 | 52 | counts, volHistH, d_min, d_max, nr_bins, bin_edges \ 53 | = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros) 54 | 55 | # Set x-y axis range to the same (x-axis range) 56 | ax.set_xlim(d_min, d_max) 57 | ax.set_ylim(d_min, d_max) 58 | ax.set_xlabel("Intensity f(x)") 59 | ax.set_ylabel("Gradient Magnitude f'(x)") 60 | ax.set_title("2D Histogram") 61 | 62 | # Plot colorbar for 2D hist 63 | volHistH.set_norm(LogNorm(vmax=np.power(10, cfg.cbar_init))) 64 | fig.colorbar(volHistH, fraction=0.046, pad=0.04) # magical scaling 65 | 66 | # Plot 3D ima by default 67 | ax2 = fig.add_subplot(122) 68 | sliceNr = int(0.5*dims[2]) 69 | imaSlcH = ax2.imshow(orig[:, :, sliceNr], cmap=plt.cm.gray, vmin=ima.min(), 70 | vmax=ima.max(), interpolation='none', 71 | extent=[0, dims[1], dims[0], 0], zorder=0) 72 | 73 | imaSlcMsk = np.ones(dims[0:2]) 74 | imaSlcMskH = ax2.imshow(imaSlcMsk, cmap=palette, vmin=0.1, 75 | interpolation='none', alpha=0.5, 76 | extent=[0, dims[1], dims[0], 0], zorder=1) 77 | 78 | # Adjust subplots on figure 79 | bottom = 0.30 80 | fig.subplots_adjust(bottom=bottom) 81 | fig.canvas.manager.set_window_title(nii.get_filename()) 82 | plt.axis('off') 83 | 84 | # 85 | """Initialisation""" 86 | # Create first instance of sector mask 87 | sectorObj = sector_mask((nr_bins, nr_bins), cfg.init_centre, cfg.init_radius, 88 
| cfg.init_theta) 89 | 90 | # Draw sector mask for the first time 91 | volHistMaskH, volHistMask = sectorObj.draw(ax, cmap=palette, alpha=0.2, 92 | vmin=0.1, interpolation='nearest', 93 | origin='lower', zorder=1, 94 | extent=[0, nr_bins, 0, nr_bins]) 95 | 96 | # Initiate a flexible figure object, pass to it useful properties 97 | idxLasso = np.zeros(nr_bins*nr_bins, dtype=bool) 98 | lassoSwitchCount = 0 99 | lassoErase = 1 # 1 for drawing, 0 for erasing 100 | flexFig = responsiveObj(figure=ax.figure, axes=ax.axes, axes2=ax2.axes, 101 | segmType='main', orig=orig, nii=nii, 102 | sectorObj=sectorObj, 103 | nrBins=nr_bins, 104 | sliceNr=sliceNr, 105 | imaSlcH=imaSlcH, 106 | imaSlcMsk=imaSlcMsk, imaSlcMskH=imaSlcMskH, 107 | volHistMask=volHistMask, volHistMaskH=volHistMaskH, 108 | contains=volHistMaskH.contains, 109 | counts=counts, 110 | idxLasso=idxLasso, 111 | lassoSwitchCount=lassoSwitchCount, 112 | lassoErase=lassoErase) 113 | 114 | # Make the figure responsive to clicks 115 | flexFig.connect() 116 | ima2volHistMap = map_ima_to_2D_hist(xinput=ima, yinput=gra, bins_arr=bin_edges) 117 | flexFig.invHistVolume = np.reshape(ima2volHistMap, dims) 118 | ima, gra = None, None 119 | 120 | # 121 | """Sliders and Buttons""" 122 | # Colorbar slider 123 | axHistC = plt.axes([0.15, bottom-0.20, 0.25, 0.025], facecolor=axcolor) 124 | flexFig.sHistC = Slider(axHistC, 'Colorbar', 1, cfg.cbar_max, 125 | valinit=cfg.cbar_init, valfmt='%0.1f') 126 | 127 | # Image browser slider 128 | axSliceNr = plt.axes([0.6, bottom-0.15, 0.25, 0.025], facecolor=axcolor) 129 | flexFig.sSliceNr = Slider(axSliceNr, 'Slice', 0, 0.999, valinit=0.5, 130 | valfmt='%0.2f') 131 | 132 | # Theta sliders 133 | aThetaMin = plt.axes([0.15, bottom-0.10, 0.25, 0.025], facecolor=axcolor) 134 | flexFig.sThetaMin = Slider(aThetaMin, 'ThetaMin', 0, 359.9, 135 | valinit=cfg.init_theta[0], valfmt='%0.1f') 136 | aThetaMax = plt.axes([0.15, bottom-0.15, 0.25, 0.025], facecolor=axcolor) 137 | flexFig.sThetaMax = 
Slider(aThetaMax, 'ThetaMax', 0, 359.9, 138 | valinit=cfg.init_theta[1]-0.1, valfmt='%0.1f') 139 | 140 | # Cycle button 141 | cycleax = plt.axes([0.55, bottom-0.2475, 0.075, 0.0375]) 142 | flexFig.bCycle = Button(cycleax, 'Cycle', 143 | color=axcolor, hovercolor=hovcolor) 144 | 145 | # Rotate button 146 | rotateax = plt.axes([0.55, bottom-0.285, 0.075, 0.0375]) 147 | flexFig.bRotate = Button(rotateax, 'Rotate', 148 | color=axcolor, hovercolor=hovcolor) 149 | 150 | # Reset button 151 | resetax = plt.axes([0.65, bottom-0.285, 0.075, 0.075]) 152 | flexFig.bReset = Button(resetax, 'Reset', color=axcolor, hovercolor=hovcolor) 153 | 154 | # Export nii button 155 | exportax = plt.axes([0.75, bottom-0.285, 0.075, 0.075]) 156 | flexFig.bExport = Button(exportax, 'Export\nNifti', 157 | color=axcolor, hovercolor=hovcolor) 158 | 159 | # Export nyp button 160 | exportax = plt.axes([0.85, bottom-0.285, 0.075, 0.075]) 161 | flexFig.bExportNyp = Button(exportax, 'Export\nHist', 162 | color=axcolor, hovercolor=hovcolor) 163 | 164 | # 165 | """Updates""" 166 | flexFig.sHistC.on_changed(flexFig.updateColorBar) 167 | flexFig.sSliceNr.on_changed(flexFig.updateImaBrowser) 168 | flexFig.sThetaMin.on_changed(flexFig.updateThetaMin) 169 | flexFig.sThetaMax.on_changed(flexFig.updateThetaMax) 170 | flexFig.bCycle.on_clicked(flexFig.cycleView) 171 | flexFig.bRotate.on_clicked(flexFig.changeRotation) 172 | flexFig.bExport.on_clicked(flexFig.exportNifti) 173 | flexFig.bExportNyp.on_clicked(flexFig.exportNyp) 174 | flexFig.bReset.on_clicked(flexFig.resetGlobal) 175 | 176 | 177 | # TODO: Temporary solution for displaying original x-y axis labels 178 | def update_axis_labels(event): 179 | """Swap histogram bin indices with original values.""" 180 | xlabels = [item.get_text() for item in ax.get_xticklabels()] 181 | orig_range_labels = np.linspace(pMin, pMax, len(xlabels)) 182 | 183 | # Adjust displayed decimals based on data range 184 | data_range = pMax - pMin 185 | if data_range > 200: # 
arbitrary value 186 | xlabels = [('%i' % i) for i in orig_range_labels] 187 | elif data_range > 20: 188 | xlabels = [('%.1f' % i) for i in orig_range_labels] 189 | elif data_range > 2: 190 | xlabels = [('%.2f' % i) for i in orig_range_labels] 191 | else: 192 | xlabels = [('%.3f' % i) for i in orig_range_labels] 193 | 194 | ax.set_xticklabels(xlabels) 195 | ax.set_yticklabels(xlabels) # limits of y axis assumed to be the same as x 196 | 197 | 198 | fig.canvas.mpl_connect('resize_event', update_axis_labels) 199 | 200 | # 201 | """Lasso selection""" 202 | # Lasso button 203 | lassoax = plt.axes([0.15, bottom-0.285, 0.075, 0.075]) 204 | bLasso = Button(lassoax, 'Lasso\nOff', color=axcolor, hovercolor=hovcolor) 205 | 206 | # Lasso draw/erase 207 | lassoEraseAx = plt.axes([0.25, bottom-0.285, 0.075, 0.075]) 208 | bLassoErase = Button(lassoEraseAx, 'Erase\nOff', color=axcolor, 209 | hovercolor=hovcolor) 210 | bLassoErase.ax.patch.set_visible(False) 211 | bLassoErase.label.set_visible(False) 212 | bLassoErase.ax.axis('off') 213 | 214 | 215 | def lassoSwitch(event): 216 | """Enable disable lasso tool.""" 217 | global lasso 218 | lasso = [] 219 | flexFig.lassoSwitchCount = (flexFig.lassoSwitchCount+1) % 2 220 | if flexFig.lassoSwitchCount == 1: # enable lasso 221 | flexFig.disconnect() # disable drag function of sector mask 222 | lasso = LassoSelector(ax, onselect) 223 | bLasso.label.set_text("Lasso\nOn") 224 | # Make erase button appear on in lasso mode 225 | bLassoErase.ax.patch.set_visible(True) 226 | bLassoErase.label.set_visible(True) 227 | bLassoErase.ax.axis('on') 228 | 229 | else: # disable lasso 230 | flexFig.connect() # enable drag function of sector mask 231 | bLasso.label.set_text("Lasso\nOff") 232 | # Make erase button disappear 233 | bLassoErase.ax.patch.set_visible(False) 234 | bLassoErase.label.set_visible(False) 235 | bLassoErase.ax.axis('off') 236 | 237 | # Pixel coordinates 238 | pix = np.arange(nr_bins) 239 | xv, yv = np.meshgrid(pix, pix) 240 | pix = 
np.vstack((xv.flatten(), yv.flatten())).T 241 | 242 | 243 | def onselect(verts): 244 | """Lasso related.""" 245 | global pix 246 | p = path.Path(verts) 247 | newLasIdx = p.contains_points(pix, radius=1.5) # New lasso indices 248 | flexFig.idxLasso[newLasIdx] = flexFig.lassoErase # Update lasso indices 249 | flexFig.remapMsks() # Update volume histogram mask 250 | flexFig.updatePanels(update_slice=False, update_rotation=True, 251 | update_extent=True) 252 | 253 | 254 | def lassoEraseSwitch(event): 255 | """Enable disable lasso erase function.""" 256 | flexFig.lassoErase = (flexFig.lassoErase + 1) % 2 257 | if flexFig.lassoErase == 1: 258 | bLassoErase.label.set_text("Erase\nOff") 259 | elif flexFig.lassoErase == 0: 260 | bLassoErase.label.set_text("Erase\nOn") 261 | 262 | 263 | bLasso.on_clicked(lassoSwitch) # lasso on/off 264 | bLassoErase.on_clicked(lassoEraseSwitch) # lasso erase on/off 265 | flexFig.remapMsks() 266 | flexFig.updatePanels(update_slice=True, update_rotation=False, 267 | update_extent=False) 268 | 269 | print("GUI is ready.") 270 | plt.show() 271 | -------------------------------------------------------------------------------- /segmentator/segmentator_ncut.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Processing input and plotting, for experimental ncut feature. 3 | 4 | TODO: Lots of code repetition, will be integrated better in the future. 
5 | """ 6 | 7 | from __future__ import division, print_function 8 | import numpy as np 9 | import segmentator.config as cfg 10 | import matplotlib 11 | matplotlib.use('TkAgg') 12 | import matplotlib.pyplot as plt 13 | from matplotlib.colors import LogNorm, ListedColormap, BoundaryNorm 14 | from matplotlib.widgets import Slider, Button, RadioButtons 15 | from nibabel import load 16 | from segmentator.utils import map_ima_to_2D_hist, prep_2D_hist 17 | from segmentator.utils import truncate_range, scale_range, check_data 18 | from segmentator.utils import set_gradient_magnitude 19 | from segmentator.utils import export_gradient_magnitude_image 20 | from segmentator.gui_utils import responsiveObj 21 | from segmentator.config_gui import palette, axcolor, hovcolor 22 | 23 | # 24 | """Load Data""" 25 | nii = load(cfg.filename) 26 | ncut_labels = np.load(cfg.ncut) 27 | 28 | # transpose the labels 29 | ncut_labels = np.transpose(ncut_labels, (1, 0, 2)) 30 | nrTotal_labels = sum([2**x for x in range(ncut_labels.shape[2])]) 31 | total_labels = np.arange(nrTotal_labels) 32 | total_labels[1::2] = total_labels[-2:0:-2] 33 | 34 | # relabel the labels from ncut, assign ascending integers starting with counter 35 | counter = 0 36 | for ind in np.arange(ncut_labels.shape[2]): 37 | tmp = ncut_labels[:, :, ind] 38 | uniqueVals = np.unique(tmp) 39 | nrUniqueVals = len(uniqueVals) 40 | newInd = np.arange(nrUniqueVals) + counter 41 | newVals = total_labels[newInd] 42 | tmp2 = np.zeros((tmp.shape)) 43 | for ind2, val in enumerate(uniqueVals): 44 | tmp2[tmp == val] = newVals[ind2] 45 | counter = counter + nrUniqueVals 46 | ncut_labels[:, :, ind] = tmp2 47 | lMax = np.max(ncut_labels) 48 | 49 | orig_ncut_labels = ncut_labels.copy() 50 | ima_ncut_labels = ncut_labels.copy() 51 | 52 | 53 | # 54 | """Data Processing""" 55 | orig, dims = check_data(nii.get_fdata(), cfg.force_original_precision) 56 | # Save min and max truncation thresholds to be used in axis labels 57 | orig, pMin, pMax = 
truncate_range(orig, percMin=cfg.perc_min, 58 | percMax=cfg.perc_max) 59 | # Continue with scaling the original truncated image and recomputing gradient 60 | orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001) 61 | gra = set_gradient_magnitude(orig, cfg.gramag) 62 | if cfg.export_gramag: 63 | export_gradient_magnitude_image(gra, nii.get_filename(), nii.affine) 64 | 65 | # Reshape ima (more intuitive for voxel-wise operations) 66 | ima = np.ndarray.flatten(orig) 67 | gra = np.ndarray.flatten(gra) 68 | 69 | # 70 | """Plots""" 71 | print("Preparing GUI...") 72 | # Plot 2D histogram 73 | fig = plt.figure(facecolor='0.775') 74 | ax = fig.add_subplot(121) 75 | 76 | counts, volHistH, d_min, d_max, nr_bins, bin_edges \ 77 | = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros) 78 | 79 | ax.set_xlim(d_min, d_max) 80 | ax.set_ylim(d_min, d_max) 81 | ax.set_xlabel("Intensity f(x)") 82 | ax.set_ylabel("Gradient Magnitude f'(x)") 83 | ax.set_title("2D Histogram") 84 | 85 | # Plot map for poltical borders 86 | pltMap = np.zeros((nr_bins, nr_bins, 1)).repeat(4, 2) 87 | cmapPltMap = ListedColormap([[1, 1, 1, 0], # transparent zeros 88 | [0, 0, 0, 0.75], # political borders 89 | [1, 0, 0, 0.5], # other colors for future use 90 | [0, 0, 1, 0.5]]) 91 | boundsPltMap = [0, 1, 2, 3, 4] 92 | normPltMap = BoundaryNorm(boundsPltMap, cmapPltMap.N) 93 | pltMapH = ax.imshow(pltMap, cmap=cmapPltMap, norm=normPltMap, 94 | vmin=boundsPltMap[1], vmax=boundsPltMap[-1], 95 | extent=[0, nr_bins, nr_bins, 0], interpolation='none') 96 | 97 | # Plot colorbar for 2d hist 98 | volHistH.set_norm(LogNorm(vmax=np.power(10, cfg.cbar_init))) 99 | fig.colorbar(volHistH, fraction=0.046, pad=0.04) # magical perfect scaling 100 | 101 | # Set up a colormap for ncut labels 102 | ncut_palette = plt.cm.gist_rainbow 103 | ncut_palette.set_under('w', 0) 104 | 105 | # Plot hist mask (with ncut labels) 106 | volHistMask = np.squeeze(ncut_labels[:, :, 0]) 107 | volHistMaskH = ax.imshow(volHistMask, 
interpolation='none', 108 | alpha=0.2, cmap=ncut_palette, 109 | vmin=np.min(ncut_labels)+1, # to make 0 transparent 110 | vmax=lMax, 111 | extent=[0, nr_bins, nr_bins, 0]) 112 | 113 | # Plot 3D ima by default 114 | ax2 = fig.add_subplot(122) 115 | sliceNr = int(0.5*dims[2]) 116 | imaSlcH = ax2.imshow(orig[:, :, sliceNr], cmap=plt.cm.gray, 117 | vmin=ima.min(), vmax=ima.max(), interpolation='none', 118 | extent=[0, dims[1], dims[0], 0]) 119 | imaSlcMsk = np.zeros(dims[0:2])*total_labels[1] 120 | imaSlcMskH = ax2.imshow(imaSlcMsk, interpolation='none', alpha=0.5, 121 | cmap=ncut_palette, vmin=np.min(ncut_labels)+1, 122 | vmax=lMax, 123 | extent=[0, dims[1], dims[0], 0]) 124 | 125 | # Adjust subplots on figure 126 | bottom = 0.30 127 | fig.subplots_adjust(bottom=bottom) 128 | fig.canvas.set_window_title(nii.get_filename()) 129 | plt.axis('off') 130 | 131 | 132 | # %% 133 | """Initialisation""" 134 | # Initiate a flexible figure object, pass to it usefull properties 135 | flexFig = responsiveObj(figure=ax.figure, axes=ax.axes, axes2=ax2.axes, 136 | segmType='ncut', orig=orig, nii=nii, ima=ima, 137 | nrBins=nr_bins, 138 | sliceNr=sliceNr, 139 | imaSlcH=imaSlcH, 140 | imaSlcMsk=imaSlcMsk, imaSlcMskH=imaSlcMskH, 141 | volHistMask=volHistMask, 142 | volHistMaskH=volHistMaskH, 143 | pltMap=pltMap, pltMapH=pltMapH, 144 | counterField=np.zeros((nr_bins, nr_bins)), 145 | orig_ncut_labels=orig_ncut_labels, 146 | ima_ncut_labels=ima_ncut_labels, 147 | lMax=lMax) 148 | 149 | # Make the figure responsive to clicks 150 | flexFig.connect() 151 | # Get mapping from image slice to volume histogram 152 | ima2volHistMap = map_ima_to_2D_hist(xinput=ima, yinput=gra, bins_arr=bin_edges) 153 | flexFig.invHistVolume = np.reshape(ima2volHistMap, dims) 154 | 155 | # %% 156 | """Sliders and Buttons""" 157 | axcolor, hovcolor = '0.875', '0.975' 158 | 159 | # Radio buttons (ugly but good enough for now) 160 | rax = plt.axes([0.91, 0.35, 0.08, 0.5], facecolor=(0.75, 0.75, 0.75)) 161 | 
flexFig.radio = RadioButtons(rax, [str(i) for i in range(7)],
                             activecolor=(0.25, 0.25, 0.25))

# Colorbar slider
axHistC = plt.axes([0.15, bottom-0.230, 0.25, 0.025], facecolor=axcolor)
flexFig.sHistC = Slider(axHistC, 'Colorbar', 1, cfg.cbar_max,
                        valinit=cfg.cbar_init, valfmt='%0.1f')

# Label slider
axLabels = plt.axes([0.15, bottom-0.270, 0.25, 0.025], facecolor=axcolor)
flexFig.sLabelNr = Slider(axLabels, 'Labels', 0, lMax,
                          valinit=lMax, valfmt='%i')

# Image browser slider
axSliceNr = plt.axes([0.6, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
flexFig.sSliceNr = Slider(axSliceNr, 'Slice', 0, 0.999,
                          valinit=0.5, valfmt='%0.3f')

# Cycle and Rotate buttons share the column at x=0.55 and must not overlap:
# otherwise the later-created Rotate axes covers part of the Cycle button
# (hiding its label and taking its clicks). Cycle gets the upper half of
# the column, Rotate the lower half.

# Cycle button (upper half)
cycleax = plt.axes([0.55, bottom-0.2475, 0.075, 0.0375])
flexFig.bCycle = Button(cycleax, 'Cycle',
                        color=axcolor, hovercolor=hovcolor)
flexFig.cycleCount = 0

# Rotate button (lower half)
rotateax = plt.axes([0.55, bottom-0.285, 0.075, 0.0375])
flexFig.bRotate = Button(rotateax, 'Rotate',
                         color=axcolor, hovercolor=hovcolor)
flexFig.rotationCount = 0

# Export nii button
exportax = plt.axes([0.75, bottom-0.285, 0.075, 0.075])
flexFig.bExport = Button(exportax, 'Export\nNifti',
                         color=axcolor, hovercolor=hovcolor)

# Export histogram button (wired to flexFig.exportNyp below)
exporthistax = plt.axes([0.85, bottom-0.285, 0.075, 0.075])
flexFig.bExportNyp = Button(exporthistax, 'Export\nHist',
                            color=axcolor, hovercolor=hovcolor)

# Reset button
resetax = plt.axes([0.65, bottom-0.285, 0.075, 0.075])
flexFig.bReset = Button(resetax, 'Reset', color=axcolor, hovercolor=hovcolor)


# %%
"""Updates"""
flexFig.sHistC.on_changed(flexFig.updateColorBar)
flexFig.sSliceNr.on_changed(flexFig.updateImaBrowser)
flexFig.sLabelNr.on_changed(flexFig.updateLabels)
flexFig.bCycle.on_clicked(flexFig.cycleView)
flexFig.bRotate.on_clicked(flexFig.changeRotation) 213 | flexFig.bExport.on_clicked(flexFig.exportNifti) 214 | flexFig.bExportNyp.on_clicked(flexFig.exportNyp) 215 | flexFig.bReset.on_clicked(flexFig.resetGlobal) 216 | flexFig.radio.on_clicked(flexFig.updateLabelsRadio) 217 | 218 | 219 | # TODO: Temporary solution for displaying original x-y axis labels 220 | def update_axis_labels(event): 221 | """Swap histogram bin indices with original values.""" 222 | xlabels = [item.get_text() for item in ax.get_xticklabels()] 223 | orig_range_labels = np.linspace(pMin, pMax, len(xlabels)) 224 | 225 | # Adjust displayed decimals based on data range 226 | data_range = pMax - pMin 227 | if data_range > 200: # arbitrary value 228 | xlabels = [('%i' % i) for i in orig_range_labels] 229 | elif data_range > 20: 230 | xlabels = [('%.1f' % i) for i in orig_range_labels] 231 | elif data_range > 2: 232 | xlabels = [('%.2f' % i) for i in orig_range_labels] 233 | else: 234 | xlabels = [('%.3f' % i) for i in orig_range_labels] 235 | 236 | ax.set_xticklabels(xlabels) 237 | ax.set_yticklabels(xlabels) # limits of y axis assumed to be the same as x 238 | 239 | 240 | fig.canvas.mpl_connect('resize_event', update_axis_labels) 241 | 242 | plt.show() 243 | -------------------------------------------------------------------------------- /segmentator/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Test utility functions.""" 2 | 3 | import numpy as np 4 | from segmentator.utils import truncate_range, scale_range 5 | 6 | 7 | def test_truncate_range(): 8 | """Test range truncation.""" 9 | # Given 10 | data = np.random.random(100) 11 | data.ravel()[np.random.choice(data.size, 10, replace=False)] = 0 12 | data.ravel()[np.random.choice(data.size, 5, replace=False)] = np.nan 13 | p_min, p_max = 2.5, 97.5 14 | expected = np.nanpercentile(data, [p_min, p_max]) 15 | # When 16 | output, _, _ = truncate_range(data, percMin=p_min, percMax=p_max, 17 | 
discard_zeros=False) 18 | # Then 19 | assert all(np.nanpercentile(output, [0, 100]) == expected) 20 | 21 | 22 | def test_scale_range(): 23 | """Test range scaling.""" 24 | # Given 25 | data = np.random.random(100) - 0.5 26 | data.ravel()[np.random.choice(data.size, 10, replace=False)] = 0. 27 | data.ravel()[np.random.choice(data.size, 5, replace=False)] = np.nan 28 | s = 42. # scaling factor 29 | expected = [0., s] # min and max 30 | # When 31 | output = scale_range(data, scale_factor=s, delta=0.01, discard_zeros=False) 32 | # Then 33 | assert all([np.nanmin(output) >= expected[0], 34 | np.nanmax(output) < expected[1]]) 35 | -------------------------------------------------------------------------------- /segmentator/tests/wip_test_arcweld.py: -------------------------------------------------------------------------------- 1 | """Test and demonstsrate arcweld classification. 2 | 3 | TODO: Turn this into unit tests for arcweld. 4 | 5 | """ 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | # create toy data 11 | res = 500 # resolution 12 | data = np.mgrid[0:res, 0:res].astype('float') 13 | ima, gra = data[0, :, :].flatten(), data[1, :, :].flatten() 14 | dims = data.shape 15 | 16 | # have 3 arbitrary classes (standing for classes csf, gm, wm) 17 | classes = np.array([res/7, res/7*4, res/7*6]) 18 | 19 | # find arc anchor 20 | arc_center = (data.max() + data.min()) / 2. 21 | arc_radius = (data.max() - data.min()) / 2. 
22 | arc_weight = (gra / arc_radius)**-1 23 | classes = np.hstack([classes, arc_center]) 24 | 25 | # find euclidean distances to classes 26 | soft = [] 27 | data = data.reshape(dims[0], dims[1]*dims[2]) 28 | for i, a in enumerate(classes): 29 | c = np.array([a, 0]) 30 | # euclidean distance 31 | edist = np.sqrt(np.sum((data - c[:, None])**2., axis=0)) 32 | soft.append(edist) 33 | soft = np.asarray(soft) 34 | 35 | # arc translation 36 | soft[-1, :] = soft[-1, :] - arc_radius 37 | soft[-1, :] = arc_weight * np.abs(soft[-1, :]) 38 | 39 | # arbitrary weights 40 | soft[0, :] = soft[0, :] * 0.66 # csf 41 | soft[-1, :] = soft[-1, :] * 0.5 # arc 42 | 43 | # hard class membership maps 44 | hard = np.argmin(soft, axis=0) 45 | hard = hard.reshape(dims[1], dims[2]) 46 | 47 | plt.imshow(hard.T, origin="lower") 48 | plt.show() 49 | -------------------------------------------------------------------------------- /segmentator/tests/wip_test_gradient_magnitude.py: -------------------------------------------------------------------------------- 1 | """Experiment with different gradient magnitude calculations. 2 | 3 | TODO: turn this into unit tests. 
4 | 5 | """ 6 | 7 | import os 8 | from nibabel import load, Nifti1Image, save 9 | from segmentator.utils import compute_gradient_magnitude 10 | 11 | # load 12 | nii = load('/home/faruk/gdrive/Segmentator/data/faruk/gramag_test/mprage_S02_restore.nii.gz') 13 | ima = nii.get_fdata() 14 | basename = nii.get_filename().split(os.extsep, 1)[0] 15 | 16 | # 3D Scharr gradient magnitude 17 | gra_mag = compute_gradient_magnitude(ima, method='scharr') 18 | out = Nifti1Image(gra_mag, affine=nii.affine) 19 | save(out, basename + '_scharr.nii.gz') 20 | 21 | # 3D Sobel gradient magnitude 22 | gra_mag = compute_gradient_magnitude(ima, method='sobel') 23 | out = Nifti1Image(gra_mag, affine=nii.affine) 24 | save(out, basename + '_sobel.nii.gz') 25 | 26 | # 3D Prewitt gradient magnitude 27 | gra_mag = compute_gradient_magnitude(ima, method='prewitt') 28 | out = Nifti1Image(gra_mag, affine=nii.affine) 29 | save(out, basename + '_prewitt.nii.gz') 30 | 31 | # numpy gradient magnitude 32 | gra_mag = compute_gradient_magnitude(ima, method='numpy') 33 | out = Nifti1Image(gra_mag, affine=nii.affine) 34 | save(out, basename + '_numpy_gradient.nii.gz') 35 | -------------------------------------------------------------------------------- /segmentator/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Some utility functions.""" 3 | 4 | from __future__ import division, print_function 5 | import os 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import segmentator.config as cfg 9 | from nibabel import load, Nifti1Image, save 10 | from scipy.ndimage import convolve 11 | from time import time 12 | 13 | 14 | def sub2ind(array_shape, rows, cols): 15 | """Pixel to voxel mapping (similar to matlab's function). 
16 | 17 | Parameters 18 | ---------- 19 | array_shape : TODO 20 | rows : TODO 21 | cols : TODO 22 | 23 | Returns 24 | ------- 25 | TODO 26 | 27 | """ 28 | # return (rows*array_shape + cols) 29 | return (cols*array_shape + rows) 30 | 31 | 32 | def map_ima_to_2D_hist(xinput, yinput, bins_arr): 33 | """Image to volume histogram mapping (kind of inverse histogram). 34 | 35 | Parameters 36 | ---------- 37 | xinput : TODO 38 | First image, which is often the intensity image (eg. T1w). 39 | yinput : TODO 40 | Second image, which is often the gradient magnitude image 41 | derived from the first image. 42 | bins_arr : TODO 43 | Array of bins. 44 | 45 | Returns 46 | ------- 47 | vox2pixMap : TODO 48 | Voxel to pixel mapping. 49 | 50 | """ 51 | dgtzData = np.digitize(xinput, bins_arr)-1 52 | dgtzGra = np.digitize(yinput, bins_arr)-1 53 | nr_bins = len(bins_arr)-1 # subtract 1 (more borders than containers) 54 | vox2pixMap = sub2ind(nr_bins, dgtzData, dgtzGra) # 1D 55 | return vox2pixMap 56 | 57 | 58 | def map_2D_hist_to_ima(imaSlc2volHistMap, volHistMask): 59 | """Volume histogram to image mapping for slices (uses np.ind1). 60 | 61 | Parameters 62 | ---------- 63 | imaSlc2volHistMap : 1D numpy array 64 | Flattened image slice. 65 | volHistMask : 1D numpy array 66 | Flattened volume histogram mask. 67 | 68 | Returns 69 | ------- 70 | imaSlcMask : 1D numpy array 71 | Flat image slice mask based on labeled pixels in volume histogram. 72 | 73 | """ 74 | imaSlcMask = np.zeros(imaSlc2volHistMap.shape) 75 | idxUnique = np.unique(volHistMask) 76 | for idx in idxUnique: 77 | linIndices = np.where(volHistMask.flatten() == idx)[0] 78 | # return logical array with length equal to nr of voxels 79 | voxMask = np.in1d(imaSlc2volHistMap, linIndices) 80 | # reset mask and apply logical indexing 81 | imaSlcMask[voxMask] = idx 82 | return imaSlcMask 83 | 84 | 85 | def truncate_range(data, percMin=0.25, percMax=99.75, discard_zeros=True): 86 | """Truncate too low and too high values. 
87 | 88 | Parameters 89 | ---------- 90 | data : np.ndarray 91 | Image to be truncated. 92 | percMin : float 93 | Percentile minimum. 94 | percMax : float 95 | Percentile maximum. 96 | discard_zeros : bool 97 | Discard voxels with value 0 from truncation. 98 | 99 | Returns 100 | ------- 101 | data : np.ndarray 102 | Truncated data. 103 | pMin : float 104 | Minimum truncation threshold which is used. 105 | pMax : float 106 | Maximum truncation threshold which is used. 107 | 108 | """ 109 | if discard_zeros: 110 | msk = ~np.isclose(data, 0.) 111 | pMin, pMax = np.nanpercentile(data[msk], [percMin, percMax]) 112 | else: 113 | pMin, pMax = np.nanpercentile(data, [percMin, percMax]) 114 | temp = data[~np.isnan(data)] 115 | temp[temp < pMin], temp[temp > pMax] = pMin, pMax # truncate min and max 116 | data[~np.isnan(data)] = temp 117 | if discard_zeros: 118 | data[~msk] = 0 # put back masked out voxels 119 | return data, pMin, pMax 120 | 121 | 122 | def scale_range(data, scale_factor=500, delta=0, discard_zeros=True): 123 | """Scale values as a preprocessing step. 124 | 125 | Parameters 126 | ---------- 127 | data : np.ndarray 128 | Image to be scaled. 129 | scale_factor : float 130 | Lower scaleFactors provides faster interface due to loweing the 131 | resolution of 2D histogram ( 500 seems fast enough). 132 | delta : float 133 | Delta ensures that the max data points fall inside the last bin 134 | when this function is used with histograms. 135 | discard_zeros : bool 136 | Discard voxels with value 0 from truncation. 137 | 138 | Returns 139 | ------- 140 | data: np.ndarray 141 | Scaled image. 
142 | 143 | """ 144 | if discard_zeros: 145 | msk = ~np.isclose(data, 0) 146 | else: 147 | msk = np.ones(data.shape, dtype=bool) 148 | scale_factor = scale_factor - delta 149 | data[msk] = data[msk] - np.nanmin(data[msk]) 150 | data[msk] = scale_factor / np.nanmax(data[msk]) * data[msk] 151 | if discard_zeros: 152 | data[~msk] = 0 # put back masked out voxels 153 | return data 154 | 155 | 156 | def check_data(data, force_original_precision=True): 157 | """Do type casting here.""" 158 | data = np.squeeze(data) # to prevent singular dimension error 159 | dims = data.shape 160 | print('Input image data type is {}.'.format(data.dtype.name)) 161 | if force_original_precision: 162 | pass 163 | elif data.dtype != 'float32': 164 | data = data.astype('float32') 165 | print(' Data type is casted to {}.'.format(data.dtype.name)) 166 | return data, dims 167 | 168 | 169 | def prep_2D_hist(ima, gra, discard_zeros=True): 170 | """Prepare 2D histogram related variables. 171 | 172 | Parameters 173 | ---------- 174 | ima : np.ndarray 175 | First image, which is often the intensity image (eg. T1w). 176 | gra : np.ndarray 177 | Second image, which is often the gradient magnitude image 178 | derived from the first image. 179 | 180 | Returns 181 | ------- 182 | counts : integer 183 | volHistH : TODO 184 | d_min : float 185 | Minimum of the first image. 186 | d_max : float 187 | Maximum of the first image. 188 | nr_bins : integer 189 | Number of one dimensional bins (not the pixels). 190 | bin_edges : TODO 191 | 192 | Notes 193 | ----- 194 | This function is modularized to be called from the terminal. 
195 | 196 | """ 197 | if discard_zeros: 198 | gra = gra[~np.isclose(ima, 0)] 199 | ima = ima[~np.isclose(ima, 0)] 200 | d_min, d_max = np.round(np.nanpercentile(ima, [0, 100])) 201 | nr_bins = int(d_max - d_min) 202 | bin_edges = np.arange(d_min, d_max+1) 203 | counts, _, _, volHistH = plt.hist2d(ima, gra, bins=bin_edges, cmap='Greys') 204 | return counts, volHistH, d_min, d_max, nr_bins, bin_edges 205 | 206 | 207 | def create_3D_kernel(operator='scharr'): 208 | """Create various 3D kernels. 209 | 210 | Parameters 211 | ---------- 212 | operator : np.ndarray, shape=(n, n, 3) 213 | Input can be 'sobel', 'prewitt' or any 3D numpy array. 214 | 215 | Returns 216 | ------- 217 | kernel : np.ndarray, shape(6, n, n, 3) 218 | 219 | """ 220 | if operator == 'sobel': 221 | operator = np.array([[[1, 2, 1], [2, 4, 2], [1, 2, 1]], 222 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 223 | [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]]], 224 | dtype='float32') 225 | elif operator == 'prewitt': 226 | operator = np.array([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], 227 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 228 | [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]]], 229 | dtype='float32') 230 | elif operator == 'scharr': 231 | operator = np.array([[[9, 30, 9], [30, 100, 30], [9, 30, 9]], 232 | [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 233 | [[-9, -30, -9], [-30, -100, -30], [-9, -30, -9]]], 234 | dtype='float32') 235 | scale_normalization_factor = np.sum(np.abs(operator)) 236 | operator = np.divide(operator, scale_normalization_factor) 237 | 238 | # create permutations operator that will be used in gradient computation 239 | kernel = np.zeros([3, 3, 3, 3]) 240 | kernel[0, ...] = operator 241 | kernel[1, ...] = np.transpose(kernel[0, ...], [2, 0, 1]) 242 | kernel[2, ...] = np.transpose(kernel[0, ...], [1, 2, 0]) 243 | return kernel 244 | 245 | 246 | def compute_gradient_magnitude(ima, method='scharr'): 247 | """Compute gradient magnitude of images. 
248 | 249 | Parameters 250 | ---------- 251 | ima : np.ndarray 252 | First image, which is often the intensity image (eg. T1w). 253 | method : string 254 | Gradient computation method. Available options are 'scharr', 255 | 'sobel', 'prewitt', 'numpy'. 256 | Returns 257 | ------- 258 | gra_mag : np.ndarray 259 | Second image, which is often the gradient magnitude image 260 | derived from the first image 261 | 262 | """ 263 | start = time() 264 | print(' Computing gradients...') 265 | if method.lower() == 'sobel': # magnitude scale is similar to numpy method 266 | kernel = create_3D_kernel(operator=method) 267 | gra = np.zeros(ima.shape + (kernel.shape[0],)) 268 | for d in range(kernel.shape[0]): 269 | gra[..., d] = convolve(ima, kernel[d, ...]) 270 | # compute generic gradient magnitude with normalization 271 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.) 272 | elif method.lower() == 'prewitt': 273 | kernel = create_3D_kernel(operator=method) 274 | gra = np.zeros(ima.shape + (kernel.shape[0],)) 275 | for d in range(kernel.shape[0]): 276 | gra[..., d] = convolve(ima, kernel[d, ...]) 277 | # compute generic gradient magnitude with normalization 278 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.) 279 | elif method.lower() == 'scharr': 280 | kernel = create_3D_kernel(operator=method) 281 | gra = np.zeros(ima.shape + (kernel.shape[0],)) 282 | for d in range(kernel.shape[0]): 283 | gra[..., d] = convolve(ima, kernel[d, ...]) 284 | # compute generic gradient magnitude with normalization 285 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=-1) * 2.) 
286 | elif method.lower() == 'numpy': 287 | gra = np.asarray(np.gradient(ima)) 288 | gra_mag = np.sqrt(np.sum(np.power(gra, 2.), axis=0)) 289 | elif method.lower() == 'deriche': 290 | from segmentator.deriche_prepare import Deriche_Gradient_Magnitude 291 | alpha = cfg.deriche_alpha 292 | print(' Selected alpha: {}'.format(alpha)) 293 | ima = np.ascontiguousarray(ima, dtype=np.float32) 294 | gra_mag = Deriche_Gradient_Magnitude(ima, alpha, normalize=True) 295 | else: 296 | print(' Gradient magnitude method is invalid!') 297 | end = time() 298 | print(" Gradient magnitude computed in: " + str(int(end-start)) 299 | + " seconds.") 300 | return gra_mag 301 | 302 | 303 | def set_gradient_magnitude(image, gramag_option): 304 | """Set gradient magnitude based on the command line flag. 305 | 306 | Parameters 307 | ---------- 308 | image : np.ndarray 309 | First image, which is often the intensity image (eg. T1w). 310 | gramag_option : string 311 | A keyword string or a path to a nifti file. 312 | 313 | Returns 314 | ------- 315 | gra_mag : np.ndarray 316 | Gradient magnitude image, which has the same shape as image. 
317 | 318 | """ 319 | if gramag_option not in cfg.gramag_options: 320 | print("Selected gradient magnitude method is not available," 321 | + " interpreting as a file path...") 322 | gra_mag_nii = load(gramag_option) 323 | gra_mag = np.squeeze(gra_mag_nii.get_fdata()) 324 | gra_mag, _ = check_data(gra_mag, cfg.force_original_precision) 325 | gra_mag, _, _ = truncate_range(gra_mag, percMin=cfg.perc_min, 326 | percMax=cfg.perc_max) 327 | gra_mag = scale_range(gra_mag, scale_factor=cfg.scale, delta=0.0001) 328 | 329 | else: 330 | print('{} gradient method is selected.'.format(gramag_option.title())) 331 | gra_mag = compute_gradient_magnitude(image, method=gramag_option) 332 | return gra_mag 333 | 334 | 335 | def export_gradient_magnitude_image(img, filename, filtername, affine): 336 | """Export computed gradient magnitude image as a nifti file.""" 337 | basename = filename.split(os.extsep, 1)[0] 338 | out_img = Nifti1Image(img, affine=affine) 339 | if filtername == 'deriche': # add alpha as suffix for extra information 340 | filtername = '{}_alpha{}'.format(filtername.title(), 341 | cfg.deriche_alpha) 342 | filtername = filtername.replace('.', 'pt') 343 | else: 344 | filtername = filtername.title() 345 | out_path = '{}_GraMag{}.nii.gz'.format(basename, filtername) 346 | print("Exporting gradient magnitude image...") 347 | save(out_img, out_path) 348 | print(' Gradient magnitude image exported in this path:\n ' + out_path) 349 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Segmentator setup. 
2 | 3 | To install for development, using the commandline do: 4 | pip install -e /path/to/segmentator 5 | 6 | """ 7 | 8 | from setuptools import setup 9 | from setuptools.extension import Extension 10 | from os.path import join 11 | import numpy 12 | 13 | ext_modules = [Extension( 14 | "segmentator.deriche_3D", [join('segmentator', 'cython', 'deriche_3D.c')], 15 | include_dirs=[numpy.get_include()]) 16 | ] 17 | 18 | setup(name='segmentator', 19 | version='1.6.1', 20 | description=('Multi-dimensional data exploration and segmentation for 3D \ 21 | images.'), 22 | url='https://github.com/ofgulban/segmentator', 23 | author='Omer Faruk Gulban', 24 | author_email='faruk.gulban@maastrichtuniversity.nl', 25 | license='BSD-3-clause', 26 | packages=['segmentator'], 27 | install_requires=['numpy>=1.17', 'matplotlib>=3.1', 'scipy>=1.3', 'compoda>=0.3'], 28 | keywords=['mri', 'segmentation', 'image', 'voxel'], 29 | zip_safe=True, 30 | entry_points={ 31 | 'console_scripts': [ 32 | 'segmentator = segmentator.__main__:main', 33 | 'segmentator_filters = segmentator.filters_ui:main', 34 | ]}, 35 | ext_modules=ext_modules, 36 | ) 37 | -------------------------------------------------------------------------------- /visuals/animation_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/visuals/animation_01.gif -------------------------------------------------------------------------------- /visuals/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofgulban/segmentator/01f46b36152734d67495417528fcc384eb567732/visuals/logo.png -------------------------------------------------------------------------------- /visuals/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 24 | 26 | 28 | 32 | 36 | 40 | 44 | 45 | 47 | 51 | 55 | 56 | 58 
| 62 | 66 | 67 | 69 | 73 | 77 | 78 | 80 | 84 | 88 | 89 | 92 | 99 | 100 | 109 | 119 | 123 | 124 | 134 | 138 | 139 | 149 | 158 | 160 | 164 | 168 | 172 | 173 | 183 | 192 | 201 | 210 | 219 | 228 | 237 | 240 | 245 | 246 | 247 | 270 | 272 | 273 | 275 | image/svg+xml 276 | 278 | 279 | 280 | 281 | 282 | 288 | 293 | 300 | 305 | 319 | 324 | 331 | 336 | 341 | 346 | 351 | 352 | 359 | 365 | 371 | 377 | 378 | 384 | 390 | 395 | 396 | 403 | 407 | 413 | 416 | 421 | 426 | 427 | 433 | 434 | 435 | 442 | 447 | 457 | 467 | 477 | 481 | 485 | 491 | 497 | 503 | 504 | 507 | 512 | 516 | 522 | 528 | 534 | 535 | 536 | 539 | 544 | 548 | 554 | 560 | 566 | 567 | 568 | 571 | 577 | 581 | 587 | 593 | 599 | 600 | 601 | 606 | 612 | 618 | 622 | 628 | 634 | 640 | 646 | 652 | 658 | 664 | 670 | 675 | 680 | 685 | 690 | 691 | 694 | 699 | 704 | 710 | 716 | 717 | 723 | 729 | 734 | 740 | 741 | 742 | 749 | 753 | 759 | 765 | 771 | 777 | 783 | 789 | 795 | 801 | 807 | 813 | 814 | 815 | 821 | 828 | 829 | 830 | --------------------------------------------------------------------------------