├── requirements ├── test.txt ├── gpu.txt ├── extras.txt ├── prod.txt └── dev.txt ├── docs ├── authors.rst ├── history.rst ├── readme.rst ├── contributing.rst ├── index.rst ├── installation.rst ├── usage.rst ├── Makefile ├── make.bat └── conf.py ├── tests ├── __init__.py └── test_diluvian.py ├── requirements.txt ├── diluvian ├── __init__.py ├── conf │ ├── cremi_test_datasets.toml │ ├── cremi_datasets.toml │ └── default.toml ├── postprocessing.py ├── preprocessing.py ├── util.py ├── network.py ├── octrees.py ├── __main__.py ├── config.py ├── diluvian.py └── training.py ├── .editorconfig ├── MANIFEST.in ├── setup.cfg ├── tox.ini ├── LICENSE ├── .travis.yml ├── .gitignore ├── HISTORY.rst ├── CONTRIBUTING.rst ├── AUTHORS.rst ├── setup.py ├── Makefile ├── scripts └── create_dataset_toml.py └── README.rst /requirements/test.txt: -------------------------------------------------------------------------------- 1 | pytest==3.0.6 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /requirements/gpu.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.2.1 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/prod.txt 2 | 3 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /requirements/extras.txt: -------------------------------------------------------------------------------- 1 | cremi==git+https://github.com/cremi/cremi_python.git 2 | mayavi==4.5.0 3 | munkres==1.0.12 4 | -------------------------------------------------------------------------------- /diluvian/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Andrew S. Champion""" 4 | __email__ = 'andrew.champion@gmail.com' 5 | __version__ = '0.0.6' 6 | -------------------------------------------------------------------------------- /requirements/prod.txt: -------------------------------------------------------------------------------- 1 | h5py>=2.6.0 2 | Keras==2.1.6 3 | matplotlib==2.0.0 4 | networkx==1.11 5 | neuroglancer==0.0.8 6 | numpy==1.14.3 7 | Pillow==4.0.0 8 | pytoml==0.1.11 9 | requests==2.13.0 10 | scikit-image==0.13.0 11 | scipy==0.19.1 12 | six==1.11.0 13 | tensorflow==1.8.0 14 | tqdm==4.19.1 15 | pyn5==0.1.0 16 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | -r prod.txt 2 | pip==9.0.1 3 | bumpversion==0.5.3 4 | wheel==0.29.0 5 | watchdog==0.8.3 6 | flake8==3.2.1 7 | tox==2.3.1 8 | coverage==4.1 9 | numpydoc==0.6.0 10 | Sphinx==1.5.1 11 | sphinx-argparse==0.1.16 12 | cryptography==1.7 13 | PyYAML==3.12 14 | pytest==3.0.7 15 | pytest-cov==2.5.1 16 | pytest-runner==2.11.1 17 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | diluvian - flood filling networks 2 | ================================= 3 | 4 | Contents: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | readme 10 | installation 11 | usage 12 | contributing 13 | authors 14 | history 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | 2 | include AUTHORS.rst 3 | 4 | include CONTRIBUTING.rst 5 | include HISTORY.rst 6 | include LICENSE 7 | include README.rst 8 | 9 | recursive-include tests * 10 | recursive-exclude * __pycache__ 11 | recursive-exclude * *.py[co] 12 | 13 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 14 | 15 | recursive-include diluvian *.py *.txt 16 | recursive-include diluvian/conf *.toml 17 | include requirements.txt 18 | recursive-include requirements *.txt 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.6 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:diluvian/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs,diluvian/third_party 19 | max-line-length = 120 20 | 21 | [aliases] 22 | test=pytest 23 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py35, py36 3 | 4 | [testenv] 5 | setenv = 6 | PYTHONPATH = {toxinidir}:{toxinidir}/diluvian 7 | COVERAGE_FILE = .coverage.{envname} 8 | deps = 9 | -r{toxinidir}/requirements/dev.txt 10 | whitelist_externals = 11 | make 12 | commands = 13 | pip install -U pip 14 | py.test --basetemp={envtmpdir} --cov=diluvian --cov-report= 15 | make lint 16 | 17 | 18 | ; If you want to make tox run the tests with the same versions, create a 19 | ; requirements.txt with the pinned versions and uncomment the following lines: 20 | ; deps = 21 | ; -r{toxinidir}/requirements.txt 22 | -------------------------------------------------------------------------------- /diluvian/conf/cremi_test_datasets.toml: -------------------------------------------------------------------------------- 1 | # Cropped test datasets from the CREMI MICCAI Challenge on 2 | # Circuit Reconstruction from Electron Microscopy Images 3 | # 4 | # https://cremi.org/ 5 | 6 | [[dataset]] 7 | name = "Sample A+" 8 | hdf5_file = "sample_A+_20160601.hdf" 9 | image_dataset = "volumes/raw" 10 | use_keras_cache = true 11 | download_url = "https://cremi.org/static/data/sample_A%2B_20160601.hdf" 12 | download_md5 = "5b77b53d56333a261f80c7c3bc2168be" 13 | 14 | [[dataset]] 15 | name = "Sample B+" 16 | hdf5_file = "sample_B+_20160601.hdf" 17 | image_dataset = "volumes/raw" 18 | use_keras_cache = true 19 | download_url = "https://cremi.org/static/data/sample_B%2B_20160601.hdf" 20 | download_md5 = "d4d53f207b00978c83e4d853b150b1d7" 21 | 22 | [[dataset]] 23 | name = "Sample C+" 24 | hdf5_file = "sample_C+_20160601.hdf" 25 | image_dataset = "volumes/raw" 26 | use_keras_cache = true 27 | download_url = "https://cremi.org/static/data/sample_C%2B_20160601.hdf" 28 | download_md5 = "69c7677c952f847c550375e516166dfa" 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2017, Andrew S. Champion 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 9 | 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 11 | 12 | -------------------------------------------------------------------------------- /diluvian/conf/cremi_datasets.toml: -------------------------------------------------------------------------------- 1 | # Cropped datasets from the CREMI MICCAI Challenge on 2 | # Circuit Reconstruction from Electron Microscopy Images 3 | # 4 | # https://cremi.org/ 5 | 6 | [[dataset]] 7 | name = "Sample A" 8 | hdf5_file = "sample_A_20160501.hdf" 9 | image_dataset = "volumes/raw" 10 | label_dataset = "volumes/labels/neuron_ids" 11 | use_keras_cache = true 12 | download_url = "https://cremi.org/static/data/sample_A_20160501.hdf" 13 | download_md5 = "6fc1a45835b57e44afac3d6217c609d8" 14 | 15 | [[dataset]] 16 | name = "Sample B" 17 | hdf5_file = "sample_B_20160501.hdf" 18 | image_dataset = "volumes/raw" 19 | label_dataset = "volumes/labels/neuron_ids" 20 | use_keras_cache = true 21 | download_url = "https://cremi.org/static/data/sample_B_20160501.hdf" 22 | download_md5 = "16397ec1f7e0ba324b506303b0dc5034" 23 | 24 | [[dataset]] 25 | name = "Sample C" 26 | hdf5_file = "sample_C_20160501.hdf" 27 | image_dataset = "volumes/raw" 28 | label_dataset = "volumes/labels/neuron_ids" 29 | use_keras_cache = true 30 | download_url = "https://cremi.org/static/data/sample_C_20160501.hdf" 31 | download_md5 = "2b5ea85255330640d43b1784283b712d" 32 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: python 4 | jobs: 5 | include: 6 | - python: 3.6 7 | - python: 3.5 8 | notifications: 9 | email: false 10 | install: 11 | - pip install -U pip 12 | - pip install tox-travis coverage==4.3.4 13 | script: tox 14 | deploy: 15 | on: 16 | repo: aschampion/diluvian 17 | python: 3.5 18 | tags: true 19 | distributions: sdist bdist_wheel 20 | password: 21 | secure: ahez/10Y7uxpSuJyKO6QPelGeGkKpysvT7Ahe5WkhX+tcZA2HaA2RPXRaZj2A/aIz9ncDq/S7G8WOwdWCwrUjmXgS+H1Wrt8M+b2Rpui58q4IR+mCOa26+zPQwckMLhp/ZHwGj8amK7bZXg3YBMLSqk1WHSkp0lK45zgP02+m1VddaeG9NhhexScFXH96P3Mmvi4TypPOgsyQ6Rf1Z/VAJbujPV8Z8UiqVmu5ovgsGqVK6VlZ4gcuoQVT7nbiKcH73jczaClt0gd6ZFQj6afsraVLQEUVNPdUZtT6MaTcTWNejujppD3GDlLI4xO+m+bb0tTlMqmGaSLicMeibmq7w9CnoGaFF4ZJSFMeMzOyA6lGMVMBLumKfbgivzTYeG8ctKGgeBOR/4KYyDeLVD6z4inoJ40NfdHeGVUWqHRUcd21le/i4ChZNMbEwRNj82Sdq4vPZtNjP+jh68+hjQ1msEkiStVWQz66Cigv/fTVAZwCtW+H3A+YBfwSbgFTROk76jh9ywVDlLQCsu7eaHbgIJBW4lm2pxNp+is0Vb6oA+wbDKwbmJAzIAlUmpaE/pnIBScuN4AUXwOBvOCublH6M++TX12MjXAUghw2VDj3O4GuaNYUGMYmR0luoph9jp0X5HS10sST+KeDwWcjXwKH8ghATqQiUcWizFmkOpuppI= 22 | provider: pypi 23 | user: aschampion 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | docs/diluvian.rst 58 | docs/diluvian.third_party.rst 59 | docs/modules.rst 60 | 61 | # PyBuilder 62 | target/ 63 | 64 | # pyenv python configuration file 65 | .python-version 66 | 67 | # vim 68 | *.swp 69 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.0.6 (2018-02-13) 6 | ------------------ 7 | 8 | * Add CREMI evaluation command. 9 | * Add 3D region filling animation. 10 | * Fix region filling animations. 11 | * F_0.5 validation metrics. 12 | * Fix pip install. 13 | * Many other fixes and tweaks (see git log). 14 | 15 | 16 | 0.0.5 (2017-10-03) 17 | ------------------ 18 | 19 | * Fix bug creating U-net with far too few channels. 20 | * Fix bug causing revisit of seed position. 21 | * Fix bug breaking sparse fill. 22 | 23 | 24 | 0.0.4 (2017-10-02) 25 | ------------------ 26 | 27 | * Much faster, more reliable training and validation. 28 | * U-net supports valid padding mode and other features from original 29 | specification. 30 | * Add artifact augmentation. 31 | * More efficient subvolume sampling. 32 | * Many other changes. 33 | 34 | 35 | 0.0.3 (2017-06-04) 36 | ------------------ 37 | 38 | * Training now works in Python 3. 39 | * Multi-GPU filling: filling will now use the same number of processes and 40 | GPUs specified by ``training.num_gpus``. 41 | 42 | 43 | 0.0.2 (2017-05-22) 44 | ------------------ 45 | 46 | * Attempt to fix PyPI configuration file packaging. 47 | 48 | 49 | 0.0.1 (2017-05-22) 50 | ------------------ 51 | 52 | * First release on PyPI. 53 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every 8 | little bit helps, and credit will always be given. 9 | 10 | Development 11 | ----------- 12 | 13 | Here's how to set up `diluvian` for local development. 14 | 15 | 1. Fork the `diluvian` repo on GitHub. 16 | 2. Clone your fork locally:: 17 | 18 | $ git clone git@github.com:your_name_here/diluvian.git 19 | 20 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 21 | 22 | $ mkvirtualenv diluvian 23 | $ cd diluvian/ 24 | $ python setup.py develop 25 | 26 | 4. Create a branch for local development:: 27 | 28 | $ git checkout -b name-of-your-bugfix-or-feature 29 | 30 | Now you can make your changes locally. 31 | 32 | 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: 33 | 34 | $ flake8 diluvian tests 35 | $ python setup.py test 36 | $ tox 37 | 38 | To get flake8 and tox, just pip install them into your virtualenv. 39 | 40 | 6. Commit your changes and push your branch to GitHub:: 41 | 42 | $ git add . 43 | $ git commit -m "Your detailed description of your changes." 44 | $ git push origin name-of-your-bugfix-or-feature 45 | 46 | 7. Submit a pull request through the GitHub website. 47 | 48 | -------------------------------------------------------------------------------- /diluvian/conf/default.toml: -------------------------------------------------------------------------------- 1 | random_seed = 1 2 | 3 | [volume] 4 | resolution = [40, 16, 16] 5 | 6 | [model] 7 | input_fov_shape = [13, 33, 33] 8 | output_fov_shape = [13, 33, 33] 9 | output_fov_move_fraction = 4 10 | v_true = 0.95 11 | v_false = 0.05 12 | t_move = 0.9 13 | move_check_thickness = 1 14 | move_recheck = true 15 | 16 | [network] 17 | factory = 'diluvian.network.make_flood_fill_unet' 18 | num_modules = 8 19 | convolution_dim = [3, 3, 3] 20 | convolution_filters = 32 21 | output_activation = "sigmoid" 22 | initialization = "glorot_uniform" 23 | dropout_probability = 0.05 24 | unet_num_layers = 4 25 | unet_downsample_rate = [0, 1, 1] 26 | 27 | [optimizer] 28 | klass = "SGD" 29 | lr = 0.001 30 | momentum = 0.5 31 | nesterov = true 32 | 33 | [training] 34 | gpu_batch_size = 32 35 | num_gpus = 1 36 | num_workers = 4 37 | training_size = 1024 38 | validation_size = 128 39 | total_epochs = 2 40 | reset_generators = false 41 | # fill_factor_bins = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.075, 42 | # 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 43 | partitions = {".*" = [2, 1, 1]} 44 | training_partition = {".*" = [0, 0, 0]} 45 | validation_partition = {".*" = [1, 0, 0]} 46 | patience = 20 47 | augment_mirrors = [0, 1, 2] 48 | augment_permute_axes = [[0, 2, 1]] 49 | augment_missing_data = [{axis = 0, prob = 0.01}] 50 | augment_noise = [{axis = 0, mul = 0.05, add = 0.05}] 51 | augment_contrast = [{axis = 0, prob = 0.05, scaling_mean = 0.5, scaling_std = 0.1, center_mean = 1.2, center_std = 0.2}] 52 | 53 | [postprocessing] 54 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Andrew S. Champion 6 | 7 | 8 | Acknowledgements 9 | ---------------- 10 | 11 | This library is an implementation and extension of the flood-filling network 12 | algorithm first described in [Januszewski2016]_ and network architectures in 13 | [He2016]_ and [Ronneberger2015]_. 14 | 15 | This library is built on the wonderful 16 | `Keras library `_ by François Chollet and 17 | `TensorFlow `_. 18 | 19 | Skeletonization uses the `skeletopyze `_ 20 | library by Jan Funke, which is an implementation of [Sato2000]_ and 21 | [Bitter2002]_. 22 | 23 | Diluvian uses a packaging and build harness 24 | `cookiecutter template `_. 25 | 26 | .. [Januszewski2016] 27 | Michał Januszewski, Jeremy Maitin-Shepard, Peter Li, Jorgen Kornfeld, 28 | Winfried Denk, and Viren Jain. 29 | Flood-filling networks. *arXiv preprint* 30 | *arXiv:1611.00421*, 2016. 31 | 32 | .. [He2016] 33 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 34 | Identity mappings in deep residual networks. *arXiv preprint* 35 | *arXiv:1603.05027*, 2016. 36 | 37 | .. [Ronneberger2015] 38 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. 39 | U-net: convolutional networks for biomedical image segmentation. 40 | MICCAI 2015. 2015. 41 | 42 | .. [Sato2000] 43 | Mie Sato, Ingmar Bitter, Michael A. Bender, Arie E. Kaufman, 44 | and Masayuki Nakajima. 45 | TEASAR: tree-structure extraction algorithm for accurate and robust 46 | skeletons. 47 | PCCGA 2000. 2000. 48 | 49 | .. [Bitter2002] 50 | Ingmar Bitter, Arie E. Kaufman, and Mie Sato. 51 | Penalized-distance volumetric skeleton algorithm. 52 | IEEE Trans on Visualization and Computer Graphics. 2002. 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import pip 6 | from setuptools import setup 7 | 8 | 9 | with open('README.rst') as readme_file: 10 | readme = readme_file.read() 11 | 12 | with open('HISTORY.rst') as history_file: 13 | history = history_file.read() 14 | 15 | 16 | def parse_requirements(filename): 17 | lines = (line.strip() for line in open(filename)) 18 | return [line for line in lines if line and not line.startswith('#')] 19 | 20 | 21 | parsed_requirements = parse_requirements('requirements/prod.txt') 22 | parsed_test_requirements = parse_requirements('requirements/test.txt') 23 | 24 | 25 | setup( 26 | name='diluvian', 27 | version='0.0.6', 28 | description="Flood filling networks for segmenting electron microscopy of neural tissue.", 29 | long_description=readme + '\n\n' + history, 30 | author="Andrew S. Champion", 31 | author_email='andrew.champion@gmail.com', 32 | url='https://github.com/aschampion/diluvian', 33 | packages=[ 34 | 'diluvian', 35 | ], 36 | package_dir={'diluvian': 37 | 'diluvian'}, 38 | entry_points={ 39 | 'console_scripts': [ 40 | 'diluvian=diluvian.__main__:main' 41 | ] 42 | }, 43 | include_package_data=True, 44 | install_requires=parsed_requirements, 45 | license="MIT license", 46 | zip_safe=False, 47 | keywords='diluvian', 48 | classifiers=[ 49 | 'Development Status :: 2 - Pre-Alpha', 50 | 'Intended Audience :: Science/Research', 51 | 'License :: OSI Approved :: MIT License', 52 | 'Natural Language :: English', 53 | 'Programming Language :: Python :: 2', 54 | 'Programming Language :: Python :: 2.7', 55 | 'Topic :: Scientific/Engineering :: Bio-Informatics', 56 | ], 57 | setup_requires=['pytest-runner',], 58 | test_suite='tests', 59 | tests_require=parsed_test_requirements 60 | ) 61 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | Diluvian requires CUDA. For help installing CUDA, follow the 8 | `TensorFlow installation `_ instructions 9 | for GPU support. 10 | Note that diluvian will only install TensorFlow CPU during setup, so you will 11 | want to install the version of ``tensorflow-gpu`` diluvian requires: 12 | 13 | .. code-block:: console 14 | 15 | pip install 'tensorflow-gpu==1.3.0' 16 | 17 | You should install diluvian 18 | `in a virtualenv `_ 19 | or similar isolated environment. All other documentation here assumes a 20 | a virtualenv has been created and is active. 21 | 22 | The neuroglancer PyPI package release is out-of-date, so to avoid spurious 23 | console output and other issues you may want to 24 | `install from source `_. 25 | 26 | To use skeletonization you must install the 27 | `skeletopyze `_ library into the 28 | virtualenv manually. See its documentation for requirements and instructions. 29 | 30 | 31 | Stable release 32 | -------------- 33 | 34 | To install diluvian, run this command in your terminal: 35 | 36 | .. code-block:: console 37 | 38 | pip install diluvian 39 | 40 | This is the preferred method to install diluvian, as it will always install the most recent stable release. 41 | 42 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 43 | you through the process. 44 | 45 | .. _pip: https://pip.pypa.io 46 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 47 | 48 | 49 | From sources 50 | ------------ 51 | 52 | The sources for diluvian can be downloaded from the `Github repo`_. 53 | 54 | You can either clone the public repository: 55 | 56 | .. code-block:: console 57 | 58 | git clone git://github.com/aschampion/diluvian 59 | 60 | Or download the `tarball`_: 61 | 62 | .. code-block:: console 63 | 64 | curl -OL https://github.com/aschampion/diluvian/tarball/master 65 | 66 | Once you have a copy of the source, you can install it with: 67 | 68 | .. code-block:: console 69 | 70 | python setup.py install 71 | 72 | 73 | .. _Github repo: https://github.com/aschampion/diluvian 74 | .. _tarball: https://github.com/aschampion/diluvian/tarball/master 75 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | try: 6 | from urllib import pathname2url 7 | except: 8 | from urllib.request import pathname2url 9 | 10 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 11 | endef 12 | export BROWSER_PYSCRIPT 13 | 14 | define PRINT_HELP_PYSCRIPT 15 | import re, sys 16 | 17 | for line in sys.stdin: 18 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 19 | if match: 20 | target, help = match.groups() 21 | print("%-20s %s" % (target, help)) 22 | endef 23 | export PRINT_HELP_PYSCRIPT 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | 32 | clean-build: ## remove build artifacts 33 | rm -fr build/ 34 | rm -fr dist/ 35 | rm -fr .eggs/ 36 | find . -name '*.egg-info' -exec rm -fr {} + 37 | find . -name '*.egg' -exec rm -f {} + 38 | 39 | clean-pyc: ## remove Python file artifacts 40 | find . -name '*.pyc' -exec rm -f {} + 41 | find . -name '*.pyo' -exec rm -f {} + 42 | find . -name '*~' -exec rm -f {} + 43 | find . -name '__pycache__' -exec rm -fr {} + 44 | 45 | clean-test: ## remove test and coverage artifacts 46 | rm -fr .tox/ 47 | rm -f .coverage 48 | rm -fr htmlcov/ 49 | 50 | lint: ## check style with flake8 51 | flake8 diluvian tests 52 | 53 | test: ## run tests quickly with the default Python 54 | py.test 55 | 56 | 57 | test-all: ## run tests on every Python version with tox 58 | tox 59 | 60 | coverage: ## check code coverage quickly with the default Python 61 | coverage run --source diluvian -m pytest 62 | 63 | coverage report -m 64 | coverage html 65 | $(BROWSER) htmlcov/index.html 66 | 67 | docs: ## generate Sphinx HTML documentation, including API docs 68 | rm -f docs/diluvian.rst 69 | rm -f docs/modules.rst 70 | $(MAKE) -C docs clean 71 | $(MAKE) -C docs html 72 | $(BROWSER) docs/_build/html/index.html 73 | 74 | servedocs: docs ## compile the docs watching for changes 75 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 76 | 77 | release: clean ## package and upload a release 78 | python setup.py sdist upload 79 | python setup.py bdist_wheel upload 80 | 81 | dist: clean ## builds source and wheel package 82 | python setup.py sdist 83 | python setup.py bdist_wheel 84 | ls -l dist 85 | 86 | install: clean ## install the package to the active Python's site-packages 87 | python setup.py install 88 | -------------------------------------------------------------------------------- /scripts/create_dataset_toml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from __future__ import print_function 6 | 7 | import glob 8 | import os 9 | import re 10 | 11 | import argparse 12 | import h5py 13 | import numpy as np 14 | import pytoml as toml 15 | 16 | from diluvian.util import get_nonzero_aabb 17 | 18 | 19 | def create_dataset_conf_from_files(path, file_pattern, name_regex, name_format, mask_bounds=True): 20 | pathspec = path + file_pattern 21 | name_regex = re.compile(name_regex) 22 | 23 | datasets = [] 24 | 25 | for pathname in glob.iglob(pathspec): 26 | filename = os.path.basename(pathname) 27 | name = name_format.format(*name_regex.match(filename).groups()) 28 | ds = { 29 | 'name': name, 30 | 'hdf5_file': pathname, 31 | 'image_dataset': 'volumes/raw', 32 | 'label_dataset': 'volumes/labels/neuron_ids', 33 | 'mask_dataset': 'volumes/labels/mask', 34 | 'resolution': [40, 4, 4], 35 | } 36 | 37 | if mask_bounds: 38 | print('Finding mask bounds for {}'.format(filename)) 39 | f = h5py.File(pathname, 'r') 40 | d = f[ds['mask_dataset']] 41 | mask_data = d[:] 42 | mask_min, mask_max = get_nonzero_aabb(mask_data) 43 | 44 | ds['mask_bounds'] = [mask_min, mask_max] 45 | f.close() 46 | 47 | datasets.append(ds) 48 | 49 | return {'dataset': datasets} 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='Create a dataset TOML from a directory of HDF5 files.') 54 | 55 | parser.add_argument( 56 | '--file-pattern', dest='file_pattern', default='sample_[ABC]*hdf', 57 | help='Glob for HDF5 volume filenames.') 58 | parser.add_argument( 59 | '--name-regex', dest='name_regex', default=r'sample_([ABC])(.*).hdf', 60 | help='Regex for extracting volume name from filenames.') 61 | parser.add_argument( 62 | '--name-format', dest='name_format', default='Sample {} ({})', 63 | help='Format string for creating volume names from name regex matches.') 64 | parser.add_argument( 65 | 'path', default=None, 66 | help='Path to the HDF5 volume files.') 67 | parser.add_argument( 68 | 'dataset_file', default=None, 69 | help='Name for the TOML dataset file that will be created.') 70 | 71 | args = parser.parse_args() 72 | 73 | conf = create_dataset_conf_from_files(args.path, args.file_pattern, args.name_regex, args.name_format) 74 | print('Found {} datasets.'.format(len(conf['dataset']))) 75 | 76 | with open(args.dataset_file, 'wb') as tomlfile: 77 | tomlfile.write(toml.dumps(conf)) 78 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Usage 3 | ===== 4 | 5 | Basic Usage 6 | =========== 7 | 8 | Arguments for the ``diluvian`` command line interface are available via help: 9 | 10 | .. code-block:: console 11 | 12 | diluvian -h 13 | diluvian train -h 14 | diluvian fill -h 15 | diluvian sparse-fill -h 16 | diluvian view -h 17 | ... 18 | 19 | and also :ref:`in the section below `. 20 | 21 | 22 | Configuration Files 23 | ------------------- 24 | 25 | Configuration files control most of the behavior of the model, network, and 26 | training. To create a configuration file: 27 | 28 | .. code-block:: console 29 | 30 | diluvian check-config > myconfig.toml 31 | 32 | This will output the current default configuration state into a new file. 33 | Settings for configuration files are documented in the 34 | :mod:`config module documentation`. 35 | Each section in the configuration file, 36 | like ``[training]`` (known in TOML as a *table*), corresponds with a different 37 | configuration class: 38 | 39 | * :class:`Volume` 40 | * :class:`Model` 41 | * :class:`Network` 42 | * :class:`Optimizer` 43 | * :class:`Training` 44 | * :class:`Postprocessing` 45 | 46 | To run diluvian using a custom config, use the ``-c`` command line argument: 47 | 48 | .. code-block:: console 49 | 50 | diluvian train -c myconfig.toml 51 | 52 | If multiple config files are provided, each will be applied on top of the 53 | previous state in the order provided, only overriding the settings that are 54 | specified in each file: 55 | 56 | .. code-block:: console 57 | 58 | diluvian train -c myconfig1.toml -c myconfig2.toml -c myconfig3.toml 59 | 60 | This allows easy compositing of multiple configurations, for example when 61 | running a grid search. 62 | 63 | 64 | Dataset Files 65 | ------------- 66 | 67 | Volume datasets are expected to be in HDF5 files. Dataset configuration 68 | is provided by TOML files that give the paths to these files and the HDF5 69 | group paths to the relevant data within them. 70 | 71 | Each dataset is a TOML array entry in the datasets table: 72 | 73 | .. code-block:: toml 74 | 75 | [[dataset]] 76 | name = "Sample A" 77 | hdf5_file = "sample_A_20160501.hdf" 78 | image_dataset = "volumes/raw" 79 | label_dataset = "volumes/labels/neuron_ids" 80 | 81 | ``hdf5_file`` should include the full path to the file. 82 | 83 | Multiple datasets can be included by providing multiple ``[[dataset]]`` 84 | sections. 85 | 86 | To run diluvian using a dataset configuration file, use the ``-v`` 87 | command line argument: 88 | 89 | .. code-block:: console 90 | 91 | diluvian train -v mydataset.toml 92 | 93 | 94 | As a Python Library 95 | =================== 96 | 97 | To use diluvian in a project:: 98 | 99 | import diluvian 100 | 101 | If you are using diluvian via Python, it most likely is because you have data 102 | in a custom format you need to import. 103 | The easiest way to do so is by constructing or extending the 104 | :class:`Volume class `. 105 | For out-of-memory datasets, construct a volume class backed by block-sparse 106 | data structures (:class:`diluvian.octrees.OctreeVolume`). 107 | See :class:`ImageStackVolume` for an example. 108 | 109 | Once data is available as a volume, normal training and filling operations can 110 | be called. See :meth:`diluvian.training.train_network` or 111 | :meth:`diluvian.diluvian.fill_region_with_model`. 112 | 113 | 114 | .. _command-line-interface: 115 | 116 | Command Line Interface 117 | ====================== 118 | 119 | .. argparse:: 120 | :module: diluvian.__main__ 121 | :func: _make_main_parser 122 | :prog: diluvian 123 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | =============================== 2 | diluvian 3 | =============================== 4 | 5 | 6 | Flood filling networks for segmenting electron microscopy of neural tissue. 7 | 8 | ============== =============== 9 | PyPI Release |pypi_badge| 10 | Documentation |docs_badge| 11 | License |license_badge| 12 | Build Status |travis_badge| 13 | ============== =============== 14 | 15 | Diluvian is an implementation and extension of the flood-filling network (FFN) 16 | algorithm first described in [Januszewski2016]_. Flood-filling works by 17 | starting at a seed location known to lie inside a region of interest, using a 18 | convolutional network to predict the extent of the region within a small 19 | field of view around that seed location, and queuing up new field of view 20 | locations along the boundary of the current field of view that are confidently 21 | inside the region. This process is repeated until the region has been fully 22 | explored. 23 | 24 | As of December 2017 the original paper's authors have released `their implementation `_. 25 | 26 | 27 | Quick Start 28 | ----------- 29 | 30 | This assumes you already have CUDA installed and have created a fresh 31 | virtualenv. See `installation documentation `_ 32 | for detailed instructions. 33 | 34 | Install diluvian and its dependencies into your virtualenv: 35 | 36 | .. code-block:: console 37 | 38 | pip install diluvian 39 | 40 | For compatibility diluvian only requires TensorFlow CPU by default, but you 41 | will want to use TensorFlow GPU if you have installed CUDA: 42 | 43 | .. code-block:: console 44 | 45 | pip install 'tensorflow-gpu==1.3.0' 46 | 47 | To test that everything works train diluvian on three volumes from the 48 | `CREMI challenge `_: 49 | 50 | .. code-block:: console 51 | 52 | diluvian train 53 | 54 | This will automatically download the CREMI datasets to your Keras cache. Only 55 | two epochs will run with a small sample set, so the trained model is not useful 56 | but will verify Tensorflow is working correctly. 57 | 58 | To train for longer, generate a diluvian config file: 59 | 60 | .. code-block:: console 61 | 62 | diluvian check-config > myconfig.toml 63 | 64 | Now edit settings in the ``[training]`` section of ``myconfig.toml`` to your 65 | liking and begin the training again: 66 | 67 | .. code-block:: console 68 | 69 | diluvian train -c myconfig.toml 70 | 71 | For detailed command line instructions and usage from Python, see the 72 | `usage documentation `_. 73 | 74 | 75 | Limitations, Differences, and Caveats 76 | ------------------------------------- 77 | 78 | Diluvian may differ from the original FFN algorithm or make implementation 79 | choices in ways pertinent to your use: 80 | 81 | * By default diluvian uses a U-Net architecture rather than stacked convolution 82 | modules with skip links. The authors of the original FFN paper also now use 83 | both architectures (personal communication). To use a different architecture, 84 | change the ``factory`` setting in the ``[network]`` section of your config 85 | file. 86 | * Rather than resampling training data based on the filling fraction 87 | :math:`f_a`, sample loss is (optionally) weighted based on the filling 88 | fraction. 89 | * A FOV center's priority in the move queue is determined by the checking 90 | plane mask probability of the first move to queue it, rather than the 91 | highest mask probability with which it is added to the queue. 92 | * Currently only processing of each FOV is done on the GPU, with movement 93 | being processed on the CPU and requiring copying of FOV data to host and 94 | back for each move. 95 | 96 | .. [Januszewski2016] 97 | Michał Januszewski, Jeremy Maitin-Shepard, Peter Li, Jorgen Kornfeld, 98 | Winfried Denk, and Viren Jain. 99 | Flood-filling networks. *arXiv preprint* 100 | *arXiv:1611.00421*, 2016. 101 | 102 | .. |pypi_badge| 103 | image:: https://img.shields.io/pypi/v/diluvian.svg 104 | :target: https://pypi.python.org/pypi/diluvian 105 | :alt: PyPI Package Version 106 | 107 | .. |travis_badge| 108 | image:: https://img.shields.io/travis/aschampion/diluvian.svg 109 | :target: https://travis-ci.org/aschampion/diluvian 110 | :alt: Continuous Integration Status 111 | 112 | .. |docs_badge| 113 | image:: https://readthedocs.org/projects/diluvian/badge/?version=latest 114 | :target: https://diluvian.readthedocs.io/en/latest/?badge=latest 115 | :alt: Documentation Status 116 | 117 | .. |license_badge| 118 | image:: https://img.shields.io/badge/License-MIT-blue.svg 119 | :target: https://opensource.org/licenses/MIT 120 | :alt: License: MIT 121 | -------------------------------------------------------------------------------- /diluvian/postprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Segmentation processing and skeletonization after flood filling.""" 3 | 4 | 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import csv 9 | import logging 10 | 11 | import numpy as np 12 | from scipy import ndimage 13 | 14 | from .config import CONFIG 15 | from .octrees import OctreeVolume 16 | from .util import get_nonzero_aabb 17 | 18 | 19 | class Body(object): 20 | def __init__(self, mask, seed): 21 | self.mask = mask 22 | self.seed = seed 23 | 24 | def is_seed_in_mask(self): 25 | return self.mask[tuple(self.seed)] 26 | 27 | def _get_bounded_mask(self, closing_shape=None): 28 | if isinstance(self.mask, OctreeVolume): 29 | # If this is a sparse volume, materialize it to memory. 30 | bounds = self.mask.get_leaf_bounds() 31 | mask = self.mask[list(map(slice, bounds[0], bounds[1]))] 32 | # Crop the mask and bounds to nonzero region of the mask. 33 | mask_min, mask_max = get_nonzero_aabb(mask) 34 | bounds[0] += mask_min 35 | bounds[1] -= np.array(mask.shape) - mask_max 36 | mask = mask[list(map(slice, mask_min, mask_max))] 37 | assert mask.shape == tuple(bounds[1] - bounds[0]), \ 38 | 'Bounds shape ({}) and mask shape ({}) differ.'.format(bounds[1] - bounds[0], mask.shape) 39 | else: 40 | bounds = (np.zeros(3, dtype=np.int64), np.array(self.mask.shape)) 41 | mask = self.mask 42 | 43 | if closing_shape is not None: 44 | # Use grey closing rather than binary closing because it uses 45 | # a mode at the boundary that prevents erosion. 46 | mask = ndimage.grey_closing(mask, structure=np.ones(closing_shape), mode='nearest') 47 | 48 | return mask, bounds 49 | 50 | def get_largest_component(self, closing_shape=None): 51 | mask, bounds = self._get_bounded_mask(closing_shape) 52 | 53 | label_im, num_labels = ndimage.label(mask) 54 | label_sizes = ndimage.sum(mask, label_im, range(num_labels + 1)) 55 | label_im[(label_sizes < label_sizes.max())[label_im]] = 0 56 | label_im = np.minimum(label_im, 1) 57 | 58 | if label_im[tuple(self.seed - bounds[0])] == 0: 59 | logging.warning('Seed voxel ({}) is not in connected component.'.format(np.array_str(self.seed))) 60 | 61 | return label_im, bounds 62 | 63 | def get_seeded_component(self, closing_shape=None): 64 | mask, bounds = self._get_bounded_mask(closing_shape) 65 | 66 | label_im, _ = ndimage.label(mask) 67 | seed_label = label_im[tuple(self.seed - bounds[0])] 68 | if seed_label == 0: 69 | raise ValueError('Seed voxel (%s) is not in body.', np.array_str(self.seed)) 70 | label_im[label_im != seed_label] = 0 71 | label_im[label_im == seed_label] = 1 72 | 73 | return label_im, bounds 74 | 75 | def to_swc(self, filename): 76 | component, bounds = self.get_largest_component(closing_shape=CONFIG.postprocessing.closing_shape) 77 | print('Skeleton is within {}, {}'.format(np.array_str(bounds[0]), np.array_str(bounds[1]))) 78 | skel = skeletonize_component(component) 79 | swc = skeleton_to_swc(skel, bounds[0], CONFIG.volume.resolution) 80 | with open(filename, 'w') as swcfile: 81 | writer = csv.writer(swcfile, delimiter=' ', quoting=csv.QUOTE_NONE) 82 | writer.writerows(swc) 83 | 84 | 85 | def skeletonize_component(component): 86 | import skeletopyze 87 | 88 | params = skeletopyze.Parameters() 89 | res = skeletopyze.point_f3() 90 | for i in range(3): 91 | res[i] = CONFIG.volume.resolution[i] 92 | 93 | print('Skeletonizing...') 94 | skel = skeletopyze.get_skeleton_graph(component.astype(np.int32), params, res) 95 | 96 | return skel 97 | 98 | 99 | def skeleton_to_swc(skeleton, offset, resolution): 100 | import networkx as nx 101 | 102 | g = nx.Graph() 103 | g.add_nodes_from(skeleton.nodes()) 104 | g.add_edges_from((e.u, e.v) for e in skeleton.edges()) 105 | 106 | # Find a directed tree for mapping to a skeleton. 107 | if nx.number_of_nodes(g) > 1: 108 | # This discards cyclic edges in the graph. 109 | t = nx.bfs_tree(nx.minimum_spanning_tree(g), g.nodes()[0]) 110 | else: 111 | t = nx.DiGraph() 112 | t.add_nodes_from(g) 113 | # Copy node attributes 114 | for n in t.nodes_iter(): 115 | loc = skeleton.locations(n) 116 | # skeletopyze is z, y, x (as it should be). 117 | loc = np.array(loc) 118 | loc = np.multiply(loc + offset, resolution) 119 | t.node[n].update({'x': loc[0], 120 | 'y': loc[1], 121 | 'z': loc[2], 122 | 'radius': skeleton.diameters(n) / 2.0}) 123 | 124 | # Set parent node ID 125 | for n, nbrs in t.adjacency_iter(): 126 | for nbr in nbrs: 127 | t.node[nbr]['parent_id'] = n 128 | if 'radius' not in t.node[nbr]: 129 | t.node[nbr]['radius'] = -1 130 | 131 | return [[ 132 | node_id, 133 | 0, 134 | n['x'], n['y'], n['z'], 135 | n['radius'], 136 | n.get('parent_id', -1)] for node_id, n in t.nodes(data=True)] 137 | -------------------------------------------------------------------------------- /diluvian/preprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Volume preprocessing for seed generation and data augmentation.""" 3 | 4 | 5 | from __future__ import division 6 | 7 | import logging 8 | 9 | import numpy as np 10 | from scipy import ndimage 11 | from six.moves import range as xrange 12 | 13 | from .config import CONFIG 14 | from .util import ( 15 | get_color_shader, 16 | WrappedViewer, 17 | ) 18 | 19 | 20 | def make_prewitt(size): 21 | """Construct a separable Prewitt gradient convolution of a given size. 22 | 23 | Adapted from SciPy's ndimage ``prewitt``. 24 | 25 | Parameters 26 | ---------- 27 | size : int 28 | 1-D size of the filter (should be odd). 29 | """ 30 | def prewitt(input, axis=-1, output=None, mode='reflect', cval=0.0): 31 | input = np.asarray(input) 32 | if axis < 0: 33 | axis += input.ndim 34 | if type(output) is not np.ndarray: 35 | output = np.zeros_like(input) 36 | 37 | kernel = list(range(1, size // 2 + 1)) 38 | kernel = [-x for x in reversed(kernel)] + [0] + kernel 39 | smooth = np.ones(size, dtype=np.int32) 40 | smooth = smooth / np.abs(kernel).sum() 41 | smooth = list(smooth) 42 | 43 | ndimage.correlate1d(input, kernel, axis, output, mode, cval, 0) 44 | axes = [ii for ii in range(input.ndim) if ii != axis] 45 | for ii in axes: 46 | ndimage.correlate1d(output, smooth, ii, output, mode, cval, 0) 47 | return output 48 | 49 | return prewitt 50 | 51 | 52 | def intensity_distance_seeds(image_data, resolution, axis=0, erosion_radius=16, min_sep=24, visualize=False): 53 | """Create seed locations maximally distant from a Sobel filter. 54 | 55 | Parameters 56 | ---------- 57 | image_data : ndarray 58 | resolution : ndarray 59 | axis : int, optional 60 | Axis along which to slices volume to generate seeds in 2D. If 61 | None volume is processed in 3D. 62 | erosion_radius : int, optional 63 | L_infinity norm radius of the structuring element for eroding 64 | components. 65 | min_sep : int, optional 66 | L_infinity minimum separation of seeds in nanometers. 67 | 68 | Returns 69 | ------- 70 | list of ndarray 71 | """ 72 | # Late import as this is the only function using Scikit. 73 | from skimage import morphology 74 | 75 | structure = np.ones(np.floor_divide(erosion_radius, resolution) * 2 + 1) 76 | 77 | if axis is None: 78 | def slices(): 79 | yield [slice(None), slice(None), slice(None)] 80 | else: 81 | structure = structure[axis] 82 | 83 | def slices(): 84 | for i in xrange(image_data.shape[axis]): 85 | s = list(map(slice, [None] * 3)) 86 | s[axis] = i 87 | yield s 88 | 89 | sobel = np.zeros_like(image_data) 90 | thresh = np.zeros_like(image_data) 91 | transform = np.zeros_like(image_data) 92 | skmax = np.zeros_like(image_data) 93 | for s in slices(): 94 | image_slice = image_data[s] 95 | if axis is not None and not np.any(image_slice): 96 | logging.debug('Skipping blank slice.') 97 | continue 98 | logging.debug('Running Sobel filter on image shape %s', image_data.shape) 99 | sobel[s] = ndimage.generic_gradient_magnitude(image_slice, make_prewitt(int((24 / resolution).max() * 2 + 1))) 100 | # sobel = ndimage.grey_dilation(sobel, size=(5,5,3)) 101 | logging.debug('Running distance transform on image shape %s', image_data.shape) 102 | 103 | # For low res images the sobel histogram is unimodal. For now just 104 | # threshold the histogram at the mean. 105 | thresh[s] = sobel[s] < np.mean(sobel[s]) 106 | thresh[s] = ndimage.binary_erosion(thresh[s], structure=structure) 107 | transform[s] = ndimage.distance_transform_cdt(thresh[s]) 108 | # Remove missing sections from distance transform. 109 | transform[s][image_slice == 0] = 0 110 | logging.debug('Finding local maxima of image shape %s', image_data.shape) 111 | skmax[s] = morphology.thin(morphology.extrema.local_maxima(transform[s])) 112 | 113 | if visualize: 114 | viewer = WrappedViewer() 115 | viewer.add(image_data, name='Image') 116 | viewer.add(sobel, name='Filtered') 117 | viewer.add(thresh.astype(np.float), name='Thresholded') 118 | viewer.add(transform.astype(np.float), name='Distance') 119 | viewer.add(skmax, name='Seeds', shader=get_color_shader(0, normalized=False)) 120 | viewer.print_view_prompt() 121 | 122 | mask = np.zeros(np.floor_divide(min_sep, resolution) + 1) 123 | mask[0, 0, 0] = 1 124 | seeds = np.transpose(np.nonzero(skmax)) 125 | for seed in seeds: 126 | if skmax[tuple(seed)]: 127 | lim = np.minimum(mask.shape, skmax.shape - seed) 128 | skmax[list(map(slice, seed, seed + lim))] = mask[list(map(slice, lim))] 129 | 130 | seeds = np.transpose(np.nonzero(skmax)) 131 | 132 | return seeds 133 | 134 | 135 | def grid_seeds(image_data, _, grid_step_spacing=1): 136 | """Create seed locations in a volume on a uniform grid. 137 | 138 | Parameters 139 | ---------- 140 | image_data : ndarray 141 | 142 | Returns 143 | ------- 144 | list of ndarray 145 | """ 146 | seeds = [] 147 | shape = image_data.shape 148 | grid_size = CONFIG.model.move_step * grid_step_spacing 149 | for x in range(grid_size[0], shape[0], grid_size[0]): 150 | for y in range(grid_size[1], shape[1], grid_size[1]): 151 | for z in range(grid_size[2], shape[2], grid_size[2]): 152 | seeds.append(np.array([x, y, z], dtype=np.int32)) 153 | 154 | return seeds 155 | 156 | 157 | # Note that these must be added separately to the CLI. 158 | SEED_GENERATORS = { 159 | 'grid': grid_seeds, 160 | 'sobel': intensity_distance_seeds, 161 | } 162 | -------------------------------------------------------------------------------- /diluvian/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | import csv 9 | import importlib 10 | import itertools 11 | import logging 12 | import sys 13 | import webbrowser 14 | 15 | import neuroglancer 16 | import numpy as np 17 | import six 18 | from six.moves import input as raw_input 19 | 20 | 21 | class WrappedViewer(neuroglancer.Viewer): 22 | def __init__(self, voxel_coordinates=None, **kwargs): 23 | super(WrappedViewer, self).__init__(**kwargs) 24 | self.voxel_coordinates = voxel_coordinates 25 | 26 | def get_json_state(self): 27 | state = super(WrappedViewer, self).get_json_state() 28 | if self.voxel_coordinates is not None: 29 | if 'navigation' not in state: 30 | state['navigation'] = collections.OrderedDict() 31 | if 'pose' not in state['navigation']: 32 | state['navigation']['pose'] = collections.OrderedDict() 33 | if 'position' not in state['navigation']['pose']: 34 | state['navigation']['pose']['position'] = collections.OrderedDict() 35 | state['navigation']['pose']['position']['voxelCoordinates'] = list(map(int, list(self.voxel_coordinates))) 36 | return state 37 | 38 | def open_in_browser(self): 39 | webbrowser.open_new_tab(str(self)) 40 | 41 | def print_view_prompt(self): 42 | print(self) 43 | 44 | while True: 45 | s = raw_input('Press v, enter to open in browser, or enter to close...') 46 | if s == 'v': 47 | self.open_in_browser() 48 | else: 49 | break 50 | 51 | 52 | def extend_keras_history(a, b): 53 | a.epoch.extend(b.epoch) 54 | for k, v in b.history.items(): 55 | a.history.setdefault(k, []).extend(v) 56 | 57 | 58 | def write_keras_history_to_csv(history, filename): 59 | """Write Keras history to a CSV file. 60 | 61 | If the file already exists it will be overwritten. 62 | 63 | Parameters 64 | ---------- 65 | history : keras.callbacks.History 66 | filename : str 67 | """ 68 | if sys.version_info[0] < 3: 69 | args, kwargs = (['wb', ], {}) 70 | else: 71 | args, kwargs = (['w', ], {'newline': '', 'encoding': 'utf8', }) 72 | with open(filename, *args, **kwargs) as csvfile: 73 | writer = csv.writer(csvfile) 74 | metric_cols = history.history.keys() 75 | indices = [i[0] for i in sorted(enumerate(metric_cols), key=lambda x: x[1])] 76 | metric_cols = sorted(metric_cols) 77 | cols = ['epoch'] + metric_cols 78 | sorted_metrics = list(history.history.values()) 79 | sorted_metrics = [sorted_metrics[i] for i in indices] 80 | writer.writerow(cols) 81 | for row in zip(history.epoch, *sorted_metrics): 82 | writer.writerow(row) 83 | 84 | 85 | def get_function(name): 86 | mod_name, func_name = name.rsplit('.', 1) 87 | mod = importlib.import_module(mod_name) 88 | func = getattr(mod, func_name) 89 | 90 | return func 91 | 92 | 93 | def get_color_shader(channel, normalized=True): 94 | xform = 'toNormalized' if normalized else 'toRaw' 95 | value_str = '{}(getDataValue(0))'.format(xform) 96 | channels = ['0', '0', '0', value_str] 97 | channels[channel] = '1' 98 | shader = """ 99 | void main() {{ 100 | emitRGBA(vec4({})); 101 | }} 102 | """.format(', '.join(channels)) 103 | return shader 104 | 105 | 106 | def pad_dims(x): 107 | """Add single-dimensions to the beginning and end of an array.""" 108 | return np.expand_dims(np.expand_dims(x, x.ndim), 0) 109 | 110 | 111 | def get_nonzero_aabb(a): 112 | """Get the axis-aligned bounding box of nonzero elements of a 3D array. 113 | 114 | Parameters 115 | ---------- 116 | a : ndarray 117 | A 3D NumPpy array. 118 | 119 | Returns 120 | ------- 121 | tuple of ndarray 122 | """ 123 | mask_min = [] 124 | mask_max = [] 125 | 126 | for axes in [(1, 2), (0, 2), (0, 1)]: 127 | proj = np.any(a, axis=axes) 128 | w = np.where(proj)[0] 129 | if w.size: 130 | amin, amax = w[[0, -1]] 131 | amax += 1 132 | else: 133 | amin, amax = 0, 0 134 | 135 | mask_min.append(amin) 136 | mask_max.append(amax) 137 | 138 | mask_min = np.array(mask_min, dtype=np.int64) 139 | mask_max = np.array(mask_max, dtype=np.int64) 140 | 141 | return mask_min, mask_max 142 | 143 | 144 | def binary_confusion_matrix(y, y_pred): 145 | cm = np.bincount(2 * y + y_pred, minlength=4).reshape(2, 2) 146 | 147 | return cm 148 | 149 | 150 | def binary_f_score(y, y_pred, beta=1.0): 151 | cm = binary_confusion_matrix(y.flatten(), y_pred.flatten()) 152 | return confusion_f_score(cm, beta) 153 | 154 | 155 | def binary_crossentropy(y, y_pred, eps=1e-15): 156 | y_pred = np.clip(y_pred, eps, 1 - eps) 157 | 158 | loss = y * np.log(y_pred) + (1.0 - y) * np.log(1.0 - y_pred) 159 | return - np.sum(loss) / np.prod(y.shape) 160 | 161 | 162 | def confusion_f_score(cm, beta): 163 | return (1.0 + beta) * cm[1, 1] / ((1.0 + beta) * cm[1, 1] + (beta ** 2) * cm[1, 0] + cm[0, 1]) 164 | 165 | 166 | class Roundrobin(six.Iterator): 167 | """Iterate over a collection of iterables, pulling one item from each in 168 | a cycle. 169 | 170 | Based on a generator function recipe credited to George Sakkis on the 171 | python docs itertools recipes. 172 | 173 | Examples 174 | -------- 175 | >>> list(Roundrobin('ABC', 'D', 'EF')) 176 | ['A', 'D', 'E', 'B', 'F', 'C'] 177 | """ 178 | 179 | def __init__(self, *iterables, **kwargs): 180 | self.iterables = iterables 181 | self.pending = len(self.iterables) 182 | self.nexts = itertools.cycle(self.iterables) 183 | self.name = kwargs.get('name', 'Unknown') 184 | 185 | def __iter__(self): 186 | return self 187 | 188 | def reset(self): 189 | logging.debug('Resetting generator: %s', self.name) 190 | for it in self.iterables: 191 | iter(it).reset() 192 | self.pending = len(self.iterables) 193 | self.nexts = itertools.cycle(self.iterables) 194 | 195 | def __next__(self): 196 | while self.pending: 197 | try: 198 | for nextgen in self.nexts: 199 | return six.next(nextgen) 200 | except StopIteration: 201 | self.pending -= 1 202 | self.nexts = itertools.cycle(itertools.islice(self.nexts, self.pending)) 203 | raise StopIteration() 204 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/diluvian.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/diluvian.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/diluvian" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/diluvian" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\diluvian.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\diluvian.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # diluvian documentation build configuration file, created by 5 | # sphinx-quickstart on Tue Jul 9 22:26:36 2013. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | import sys 17 | import os 18 | 19 | # If extensions (or modules to document with autodoc) are in another 20 | # directory, add these directories to sys.path here. If the directory is 21 | # relative to the documentation root, use os.path.abspath to make it 22 | # absolute, like shown here. 23 | #sys.path.insert(0, os.path.abspath('.')) 24 | 25 | # Get the project root dir, which is the parent dir of this 26 | cwd = os.getcwd() 27 | project_root = os.path.dirname(cwd) 28 | 29 | # Insert the project root dir as the first element in the PYTHONPATH. 30 | # This lets us ensure that the source package is imported, and that its 31 | # version is used. 32 | sys.path.insert(0, project_root) 33 | 34 | import diluvian 35 | 36 | # -- General configuration --------------------------------------------- 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | #needs_sphinx = '1.0' 40 | 41 | # Add any Sphinx extension module names here, as strings. They can be 42 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 43 | extensions = ['sphinx.ext.autodoc', 44 | 'sphinx.ext.autosummary', 45 | 'sphinx.ext.mathjax', 46 | 'sphinx.ext.viewcode', 47 | 'sphinxarg.ext', 48 | 'numpydoc',] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # The suffix of source filenames. 54 | source_suffix = '.rst' 55 | 56 | # The encoding of source files. 57 | #source_encoding = 'utf-8-sig' 58 | 59 | # The master toctree document. 60 | master_doc = 'index' 61 | 62 | # General information about the project. 63 | project = u'diluvian' 64 | copyright = u"2017, Andrew S. Champion" 65 | 66 | # The version info for the project you're documenting, acts as replacement 67 | # for |version| and |release|, also used in various other places throughout 68 | # the built documents. 69 | # 70 | # The short X.Y version. 71 | version = diluvian.__version__ 72 | # The full version, including alpha/beta/rc tags. 73 | release = diluvian.__version__ 74 | 75 | # The language for content autogenerated by Sphinx. Refer to documentation 76 | # for a list of supported languages. 77 | #language = None 78 | 79 | # There are two options for replacing |today|: either, you set today to 80 | # some non-false value, then it is used: 81 | #today = '' 82 | # Else, today_fmt is used as the format for a strftime call. 83 | #today_fmt = '%B %d, %Y' 84 | 85 | # List of patterns, relative to source directory, that match files and 86 | # directories to ignore when looking for source files. 87 | exclude_patterns = ['_build'] 88 | 89 | # The reST default role (used for this markup: `text`) to use for all 90 | # documents. 91 | #default_role = None 92 | 93 | # If true, '()' will be appended to :func: etc. cross-reference text. 94 | #add_function_parentheses = True 95 | 96 | # If true, the current module name will be prepended to all description 97 | # unit titles (such as .. function::). 98 | #add_module_names = True 99 | 100 | # If true, sectionauthor and moduleauthor directives will be shown in the 101 | # output. They are ignored by default. 102 | #show_authors = False 103 | 104 | # The name of the Pygments (syntax highlighting) style to use. 105 | pygments_style = 'sphinx' 106 | 107 | # A list of ignored prefixes for module index sorting. 108 | #modindex_common_prefix = [] 109 | 110 | # If true, keep warnings as "system message" paragraphs in the built 111 | # documents. 112 | #keep_warnings = False 113 | 114 | 115 | # -- Options for HTML output ------------------------------------------- 116 | 117 | # The theme to use for HTML and HTML Help pages. See the documentation for 118 | # a list of builtin themes. 119 | html_theme = 'default' 120 | 121 | # Theme options are theme-specific and customize the look and feel of a 122 | # theme further. For a list of options available for each theme, see the 123 | # documentation. 124 | #html_theme_options = {} 125 | 126 | # Add any paths that contain custom themes here, relative to this directory. 127 | #html_theme_path = [] 128 | 129 | # The name for this set of Sphinx documents. If None, it defaults to 130 | # " v documentation". 131 | #html_title = None 132 | 133 | # A shorter title for the navigation bar. Default is the same as 134 | # html_title. 135 | #html_short_title = None 136 | 137 | # The name of an image file (relative to this directory) to place at the 138 | # top of the sidebar. 139 | #html_logo = None 140 | 141 | # The name of an image file (within the static path) to use as favicon 142 | # of the docs. This file should be a Windows icon file (.ico) being 143 | # 16x16 or 32x32 pixels large. 144 | #html_favicon = None 145 | 146 | # Add any paths that contain custom static files (such as style sheets) 147 | # here, relative to this directory. They are copied after the builtin 148 | # static files, so a file named "default.css" will overwrite the builtin 149 | # "default.css". 150 | html_static_path = ['_static'] 151 | 152 | # If not '', a 'Last updated on:' timestamp is inserted at every page 153 | # bottom, using the given strftime format. 154 | #html_last_updated_fmt = '%b %d, %Y' 155 | 156 | # If true, SmartyPants will be used to convert quotes and dashes to 157 | # typographically correct entities. 158 | #html_use_smartypants = True 159 | 160 | # Custom sidebar templates, maps document names to template names. 161 | #html_sidebars = {} 162 | 163 | # Additional templates that should be rendered to pages, maps page names 164 | # to template names. 165 | #html_additional_pages = {} 166 | 167 | # If false, no module index is generated. 168 | #html_domain_indices = True 169 | 170 | # If false, no index is generated. 171 | #html_use_index = True 172 | 173 | # If true, the index is split into individual pages for each letter. 174 | #html_split_index = False 175 | 176 | # If true, links to the reST sources are added to the pages. 177 | #html_show_sourcelink = True 178 | 179 | # If true, "Created using Sphinx" is shown in the HTML footer. 180 | # Default is True. 181 | #html_show_sphinx = True 182 | 183 | # If true, "(C) Copyright ..." is shown in the HTML footer. 184 | # Default is True. 185 | #html_show_copyright = True 186 | 187 | # If true, an OpenSearch description file will be output, and all pages 188 | # will contain a tag referring to it. The value of this option 189 | # must be the base URL from which the finished HTML is served. 190 | #html_use_opensearch = '' 191 | 192 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 193 | #html_file_suffix = None 194 | 195 | # Output file base name for HTML help builder. 196 | htmlhelp_basename = 'diluviandoc' 197 | 198 | 199 | # -- Options for LaTeX output ------------------------------------------ 200 | 201 | latex_elements = { 202 | # The paper size ('letterpaper' or 'a4paper'). 203 | #'papersize': 'letterpaper', 204 | 205 | # The font size ('10pt', '11pt' or '12pt'). 206 | #'pointsize': '10pt', 207 | 208 | # Additional stuff for the LaTeX preamble. 209 | #'preamble': '', 210 | } 211 | 212 | # Grouping the document tree into LaTeX files. List of tuples 213 | # (source start file, target name, title, author, documentclass 214 | # [howto/manual]). 215 | latex_documents = [ 216 | ('index', 'diluvian.tex', 217 | u'diluvian Documentation', 218 | u'Andrew S. Champion', 'manual'), 219 | ] 220 | 221 | # The name of an image file (relative to this directory) to place at 222 | # the top of the title page. 223 | #latex_logo = None 224 | 225 | # For "manual" documents, if this is true, then toplevel headings 226 | # are parts, not chapters. 227 | #latex_use_parts = False 228 | 229 | # If true, show page references after internal links. 230 | #latex_show_pagerefs = False 231 | 232 | # If true, show URL addresses after external links. 233 | #latex_show_urls = False 234 | 235 | # Documents to append as an appendix to all manuals. 236 | #latex_appendices = [] 237 | 238 | # If false, no module index is generated. 239 | #latex_domain_indices = True 240 | 241 | 242 | # -- Options for manual page output ------------------------------------ 243 | 244 | # One entry per manual page. List of tuples 245 | # (source start file, name, description, authors, manual section). 246 | man_pages = [ 247 | ('index', 'diluvian', 248 | u'diluvian Documentation', 249 | [u'Andrew S. Champion'], 1) 250 | ] 251 | 252 | # If true, show URL addresses after external links. 253 | #man_show_urls = False 254 | 255 | 256 | # -- Options for Texinfo output ---------------------------------------- 257 | 258 | # Grouping the document tree into Texinfo files. List of tuples 259 | # (source start file, target name, title, author, 260 | # dir menu entry, description, category) 261 | texinfo_documents = [ 262 | ('index', 'diluvian', 263 | u'diluvian Documentation', 264 | u'Andrew S. Champion', 265 | 'diluvian', 266 | 'Flood filling networks for segmenting electron microscopy of neural tissue.', 267 | 'Miscellaneous'), 268 | ] 269 | 270 | # Documents to append as an appendix to all manuals. 271 | #texinfo_appendices = [] 272 | 273 | # If false, no module index is generated. 274 | #texinfo_domain_indices = True 275 | 276 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 277 | #texinfo_show_urls = 'footnote' 278 | 279 | # If true, do not generate a @detailmenu in the "Top" node's menu. 280 | #texinfo_no_detailmenu = False 281 | 282 | # Fix numpydoc and autosummary compatibility. 283 | # See: https://github.com/numpy/numpydoc/pull/6 284 | numpydoc_class_members_toctree = False 285 | 286 | # Run autodoc here, rather than in a Makefile, so that it is also 287 | # executed by readthedocs.org. 288 | def run_apidoc(_): 289 | from sphinx.apidoc import main 290 | import os 291 | import sys 292 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 293 | cur_dir = os.path.abspath(os.path.dirname(__file__)) 294 | module = os.path.join('..', project) 295 | main(['-e', '-o', cur_dir, module, '--force']) 296 | 297 | def setup(app): 298 | app.connect('builder-inited', run_apidoc) 299 | -------------------------------------------------------------------------------- /tests/test_diluvian.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | test_diluvian 6 | ---------------------------------- 7 | 8 | Tests for `diluvian` module. 9 | """ 10 | 11 | 12 | from __future__ import division 13 | 14 | import numpy as np 15 | from pathlib import Path 16 | import shutil 17 | import pyn5 18 | 19 | from diluvian import octrees 20 | from diluvian import regions 21 | from diluvian import volumes 22 | from diluvian.config import CONFIG 23 | from diluvian.util import ( 24 | binary_confusion_matrix, 25 | confusion_f_score, 26 | get_nonzero_aabb, 27 | ) 28 | 29 | 30 | def test_octree_bounds(): 31 | clip_bounds = (np.zeros(3), np.array([11, 6, 5])) 32 | ot = octrees.OctreeVolume([5, 5, 5], clip_bounds, np.uint8) 33 | ot[clip_bounds[0][0]:clip_bounds[1][0], 34 | clip_bounds[0][1]:clip_bounds[1][1], 35 | clip_bounds[0][2]:clip_bounds[1][2]] = 6 36 | assert isinstance(ot.root_node, octrees.UniformNode), "Constant assignment should make root uniform." 37 | 38 | ot[8, 5, 4] = 5 39 | expected_mat = np.array([[[6], [6]], [[6], [5]]], dtype=np.uint8) 40 | assert np.array_equal(ot[7:9, 4:6, 4], expected_mat), "Assignment should break uniformity." 41 | 42 | expected_types = [[[octrees.BranchNode, None], [None, None]], 43 | [[octrees.UniformBranchNode, None], [None, None]]] 44 | for i, col in enumerate(expected_types): 45 | for j, row in enumerate(col): 46 | for k, expected_type in enumerate(row): 47 | if expected_type is None: 48 | assert ot.root_node.children[i][j][k] is None, "Clip bounds should make most nodes empty." 49 | else: 50 | assert isinstance(ot.root_node.children[i][j][k], expected_type), "Nodes are wrong type." 51 | 52 | np.testing.assert_almost_equal(ot.fullness(), 2.0/3.0, err_msg='Octree fullness should be relative to clip bounds.') 53 | 54 | ot[10, 5, 4] = 5 # Break the remaining top-level uniform branch node. 55 | np.testing.assert_almost_equal(ot.fullness(), 1.0, err_msg='Octree fullness should be relative to clip bounds.') 56 | 57 | np.testing.assert_array_equal(ot.get_leaf_bounds()[1], clip_bounds[1], 58 | err_msg='Leaf bounds should be clipped to clip bounds') 59 | 60 | 61 | def test_octree_map_copy(): 62 | clip_bounds = (np.zeros(3), np.array([11, 6, 5])) 63 | ot = octrees.OctreeVolume([5, 5, 5], clip_bounds, np.uint8) 64 | ot[clip_bounds[0][0]:clip_bounds[1][0], 65 | clip_bounds[0][1]:clip_bounds[1][1], 66 | clip_bounds[0][2]:clip_bounds[1][2]] = 6 67 | 68 | ot[8, 5, 4] = 5 69 | 70 | def leaf_map(a): 71 | return a * -1 72 | 73 | def uniform_map(v): 74 | return v * 1.5 75 | 76 | cot = ot.map_copy(np.float32, leaf_map, uniform_map) 77 | for orig, copy in zip(ot.iter_leaves(), cot.iter_leaves()): 78 | np.testing.assert_almost_equal(copy.bounds[0], orig.bounds[0], err_msg='Copy leaves should have same bounds.') 79 | np.testing.assert_almost_equal(copy.bounds[1], orig.bounds[1], err_msg='Copy leaves should have same bounds.') 80 | np.testing.assert_almost_equal(copy.data, leaf_map(orig.data), err_msg='Copy leaves should be mapped.') 81 | expected_mat = np.array([[[9.], [-6.]], [[9.], [-5.]]], dtype=np.float32) 82 | assert np.array_equal(cot[7:9, 4:6, 4], expected_mat), 'Copy should have same uniformity.' 83 | 84 | 85 | def test_region_moves(): 86 | mock_image = np.zeros(tuple(CONFIG.model.training_subv_shape), dtype=np.float32) 87 | region = regions.Region(mock_image) 88 | mock_mask = np.zeros(tuple(CONFIG.model.output_fov_shape), dtype=np.float32) 89 | ctr = np.array(mock_mask.shape) // 2 90 | expected_moves = {} 91 | for i, move in enumerate(map(np.array, [(1, 0, 0), (-1, 0, 0), 92 | (0, 1, 0), (0, -1, 0), 93 | (0, 0, 1), (0, 0, -1)])): 94 | val = 0.1 * (i + 1) 95 | coord = ctr + (region.MOVE_DELTA * move) + np.array([2, 2, 2]) * (np.ones(3) - np.abs(move)) 96 | mock_mask[tuple(coord.astype(np.int64))] = val 97 | expected_moves[tuple(move)] = val 98 | 99 | moves = region.get_moves(mock_mask) 100 | for move in moves: 101 | np.testing.assert_allclose(expected_moves[tuple(move['move'])], move['v']) 102 | 103 | # Test thick move check planes. 104 | mock_mask[:] = 0 105 | for i, move in enumerate(map(np.array, [(1, 0, 0), (-1, 0, 0), 106 | (0, 1, 0), (0, -1, 0), 107 | (0, 0, 1), (0, 0, -1)])): 108 | val = 0.15 * (i + 1) 109 | coord = ctr + ((region.MOVE_DELTA + 1) * move) + np.array([2, 2, 2]) * (np.ones(3) - np.abs(move)) 110 | mock_mask[tuple(coord.astype(np.int64))] = val 111 | expected_moves[tuple(move)] = val 112 | 113 | region.move_check_thickness = 2 114 | moves = region.get_moves(mock_mask) 115 | for move in moves: 116 | np.testing.assert_allclose(expected_moves[tuple(move['move'])], move['v']) 117 | 118 | 119 | def test_volume_transforms(): 120 | mock_image = np.arange(64 * 64 * 64, dtype=np.uint8).reshape((64, 64, 64)) 121 | mock_label = np.zeros((64, 64, 64), dtype=np.int64) 122 | 123 | v = volumes.Volume((1, 1, 1), image_data=mock_image, label_data=mock_label) 124 | pv = v.partition([1, 1, 2], [0, 0, 1]) 125 | dpv = pv.downsample((4, 4, 1)) 126 | 127 | np.testing.assert_array_equal(dpv.local_coord_to_world(np.array([2, 2, 2])), np.array([8, 8, 34])) 128 | np.testing.assert_array_equal(dpv.world_coord_to_local(np.array([8, 8, 34])), np.array([2, 2, 2])) 129 | 130 | svb = volumes.SubvolumeBounds(np.array((0, 0, 32), dtype=np.int64), 131 | np.array((4, 4, 33), dtype=np.int64)) 132 | sv = v.get_subvolume(svb) 133 | 134 | dpsvb = volumes.SubvolumeBounds(np.array((0, 0, 0), dtype=np.int64), 135 | np.array((1, 1, 1), dtype=np.int64)) 136 | dpsv = dpv.get_subvolume(dpsvb) 137 | 138 | np.testing.assert_array_equal(dpsv.image, sv.image.reshape((1, 4, 1, 4, 1, 1)).mean(5).mean(3).mean(1)) 139 | 140 | 141 | def test_volume_transforms_image_stacks(): 142 | # stack info 143 | si = { 144 | "bounds": [28128, 31840, 4841], 145 | "resolution": [3.8, 3.8, 50], 146 | "tile_width": 512, 147 | "tile_height": 512, 148 | "translation": [0, 0, 0], 149 | } 150 | # tile stack parameters 151 | tsp = { 152 | "source_base_url": "https://neurocean.janelia.org/ssd-tiles-no-cache/0111-8/", 153 | "file_extension": "jpg", 154 | "tile_width": 512, 155 | "tile_height": 512, 156 | "tile_source_type": 4, 157 | } 158 | v = volumes.ImageStackVolume.from_catmaid_stack(si, tsp) 159 | pv = v.partition( 160 | [2, 1, 1], [1, 0, 0] 161 | ) # Note axes are flipped after volume initialization 162 | dpv = pv.downsample((50, 15.2, 15.2)) 163 | 164 | np.testing.assert_array_equal( 165 | dpv.local_coord_to_world(np.array([2, 2, 2])), np.array([2422, 8, 8]) 166 | ) 167 | np.testing.assert_array_equal( 168 | dpv.world_coord_to_local(np.array([2422, 8, 8])), np.array([2, 2, 2]) 169 | ) 170 | 171 | svb = volumes.SubvolumeBounds( 172 | np.array((2420, 0, 0), dtype=np.int64), 173 | np.array((2421, 4, 4), dtype=np.int64), 174 | ) 175 | sv = v.get_subvolume(svb) 176 | 177 | dpsvb = volumes.SubvolumeBounds( 178 | np.array((0, 0, 0), dtype=np.int64), np.array((1, 1, 1), dtype=np.int64) 179 | ) 180 | dpsv = dpv.get_subvolume(dpsvb) 181 | 182 | np.testing.assert_array_equal( 183 | dpsv.image, sv.image.reshape((1, 4, 1, 4, 1, 1)).mean(5).mean(3).mean(1) 184 | ) 185 | 186 | 187 | def test_volume_transforms_n5_volume(): 188 | # Create test n5 dataset 189 | test_dataset_path = Path("test.n5") 190 | if test_dataset_path.is_dir(): 191 | shutil.rmtree(str(test_dataset_path.absolute())) 192 | pyn5.create_dataset("test.n5", "test", [10, 10, 10], [2, 2, 2], "UINT8") 193 | test_dataset = pyn5.open("test.n5", "test") 194 | 195 | test_data = np.zeros([10, 10, 10]).astype(int) 196 | x = np.linspace(0, 9, 10).reshape([10, 1, 1]).astype(int) 197 | test_data = test_data + x + x.transpose([1, 2, 0]) + x.transpose([2, 0, 1]) 198 | 199 | block_starts = [(i % 5, i // 5 % 5, i // 25 % 5) for i in range(5 ** 3)] 200 | for block_start in block_starts: 201 | current_bound = list( 202 | map(slice, [2 * x for x in block_start], [2 * x + 2 for x in block_start]) 203 | ) 204 | flattened = test_data[current_bound].reshape(-1) 205 | try: 206 | test_dataset.write_block(block_start, flattened) 207 | except Exception as e: 208 | raise AssertionError("Writing to n5 failed! Could not create test dataset.\nError: {}".format(e)) 209 | 210 | v = volumes.N5Volume("test.n5", 211 | {"image": {"path": "test", "dtype": "UINT8"}}, 212 | bounds=[10, 10, 10], 213 | resolution=[1, 1, 1]) 214 | pv = v.partition( 215 | [2, 1, 1], [1, 0, 0] 216 | ) # Note axes are flipped after volume initialization 217 | dpv = pv.downsample((2, 2, 2)) 218 | 219 | np.testing.assert_array_equal( 220 | dpv.local_coord_to_world(np.array([2, 2, 2])), np.array([9, 4, 4]) 221 | ) 222 | np.testing.assert_array_equal( 223 | dpv.world_coord_to_local(np.array([9, 4, 4])), np.array([2, 2, 2]) 224 | ) 225 | 226 | svb = volumes.SubvolumeBounds( 227 | np.array((5, 0, 0), dtype=np.int64), np.array((7, 2, 2), dtype=np.int64) 228 | ) 229 | sv = v.get_subvolume(svb) 230 | 231 | dpsvb = volumes.SubvolumeBounds( 232 | np.array((0, 0, 0), dtype=np.int64), np.array((1, 1, 1), dtype=np.int64) 233 | ) 234 | dpsv = dpv.get_subvolume(dpsvb) 235 | 236 | np.testing.assert_array_equal( 237 | dpsv.image, sv.image.reshape((1, 2, 1, 2, 1, 2)).mean(5).mean(3).mean(1) 238 | ) 239 | 240 | # sanity check that test.n5 contains varying data 241 | svb2 = volumes.SubvolumeBounds( 242 | np.array((5, 0, 1), dtype=np.int64), np.array((7, 2, 3), dtype=np.int64) 243 | ) 244 | sv2 = v.get_subvolume(svb2) 245 | assert not all(sv.image.flatten() == sv2.image.flatten()) 246 | 247 | if test_dataset_path.is_dir(): 248 | shutil.rmtree(str(test_dataset_path.absolute())) 249 | 250 | 251 | def test_volume_identity_downsample_returns_self(): 252 | resolution = (27, 185, 90) 253 | v = volumes.Volume(resolution, image_data=np.zeros((1, 1, 1)), label_data=np.zeros((1, 1, 1))) 254 | dv = v.downsample(resolution) 255 | 256 | assert v == dv 257 | 258 | 259 | def test_nonzero_aabb(): 260 | a = np.zeros([10, 10, 10], dtype=np.int32) 261 | a[8, 7, 6] = 1 262 | 263 | amin, amax = get_nonzero_aabb(a) 264 | np.testing.assert_array_equal(amin, [8, 7, 6]) 265 | np.testing.assert_array_equal(amax, [9, 8, 7]) 266 | 267 | a[6, 7, 8] = 1 268 | amin, amax = get_nonzero_aabb(a) 269 | np.testing.assert_array_equal(amin, [6, 7, 6]) 270 | np.testing.assert_array_equal(amax, [9, 8, 9]) 271 | 272 | 273 | def test_confusion_matrix(): 274 | a = np.zeros([3, 3, 3], dtype=np.bool) 275 | a[2, 2, :] = True 276 | b = np.ones([3, 3, 3], dtype=np.bool) 277 | b[:, 2, 2] = False 278 | 279 | cm = np.array([[2, 22], [1, 2]]) 280 | np.testing.assert_array_equal(binary_confusion_matrix(a.flatten(), b.flatten()), cm) 281 | 282 | 283 | def test_f1_score(): 284 | a = np.array([[375695, 6409], [31208, 67419]]) 285 | 286 | np.testing.assert_almost_equal(confusion_f_score(a, 1.0), 0.782, decimal=3) 287 | assert confusion_f_score(np.eye(2), 1.0) == 1.0 288 | -------------------------------------------------------------------------------- /diluvian/network.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Flood-fill network creation and compilation using Keras.""" 3 | 4 | 5 | from __future__ import division 6 | 7 | import inspect 8 | 9 | import numpy as np 10 | import six 11 | 12 | from keras.layers import ( 13 | BatchNormalization, 14 | Conv3D, 15 | Conv3DTranspose, 16 | Cropping3D, 17 | Dropout, 18 | Input, 19 | Lambda, 20 | Permute, 21 | ) 22 | from keras.layers.merge import ( 23 | add, 24 | concatenate, 25 | ) 26 | from keras.layers.core import Activation 27 | from keras.models import load_model as keras_load_model, Model 28 | from keras.utils import multi_gpu_model 29 | import keras.optimizers 30 | 31 | 32 | def make_flood_fill_network(input_fov_shape, output_fov_shape, network_config): 33 | """Construct a stacked convolution module flood filling network. 34 | """ 35 | if network_config.convolution_padding != 'same': 36 | raise ValueError('ResNet implementation only supports same padding.') 37 | 38 | image_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='image_input') 39 | if network_config.rescale_image: 40 | ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input) 41 | else: 42 | ffn = image_input 43 | mask_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='mask_input') 44 | ffn = concatenate([ffn, mask_input]) 45 | 46 | # Convolve and activate before beginning the skip connection modules, 47 | # as discussed in the Appendix of He et al 2016. 48 | ffn = Conv3D( 49 | network_config.convolution_filters, 50 | tuple(network_config.convolution_dim), 51 | kernel_initializer=network_config.initialization, 52 | activation=network_config.convolution_activation, 53 | padding='same')(ffn) 54 | if network_config.batch_normalization: 55 | ffn = BatchNormalization()(ffn) 56 | 57 | contraction = (input_fov_shape - output_fov_shape) // 2 58 | if np.any(np.less(contraction, 0)): 59 | raise ValueError('Output FOV shape can not be larger than input FOV shape.') 60 | contraction_cumu = np.zeros(3, dtype=np.int32) 61 | contraction_step = np.divide(contraction, float(network_config.num_modules)) 62 | 63 | for i in range(0, network_config.num_modules): 64 | ffn = add_convolution_module(ffn, network_config) 65 | contraction_dims = np.floor(i * contraction_step - contraction_cumu).astype(np.int32) 66 | if np.count_nonzero(contraction_dims): 67 | ffn = Cropping3D(zip(list(contraction_dims), list(contraction_dims)))(ffn) 68 | contraction_cumu += contraction_dims 69 | 70 | if np.any(np.less(contraction_cumu, contraction)): 71 | remainder = contraction - contraction_cumu 72 | ffn = Cropping3D(zip(list(remainder), list(remainder)))(ffn) 73 | 74 | mask_output = Conv3D( 75 | 1, 76 | tuple(network_config.convolution_dim), 77 | kernel_initializer=network_config.initialization, 78 | padding='same', 79 | name='mask_output', 80 | activation=network_config.output_activation)(ffn) 81 | ffn = Model(inputs=[image_input, mask_input], outputs=[mask_output]) 82 | 83 | return ffn 84 | 85 | 86 | def add_convolution_module(model, network_config): 87 | model2 = model 88 | 89 | for _ in range(network_config.num_layers_per_module): 90 | model2 = Conv3D( 91 | network_config.convolution_filters, 92 | tuple(network_config.convolution_dim), 93 | kernel_initializer=network_config.initialization, 94 | activation=network_config.convolution_activation, 95 | padding='same')(model2) 96 | if network_config.batch_normalization: 97 | model2 = BatchNormalization()(model2) 98 | 99 | model = add([model, model2]) 100 | # Note that the activation here differs from He et al 2016, as that 101 | # activation is not on the skip connection path. However, this is not 102 | # likely to be important, see: 103 | # http://torch.ch/blog/2016/02/04/resnets.html 104 | # https://github.com/gcr/torch-residual-networks 105 | model = Activation(network_config.convolution_activation)(model) 106 | if network_config.batch_normalization: 107 | model = BatchNormalization()(model) 108 | if network_config.dropout_probability > 0.0: 109 | model = Dropout(network_config.dropout_probability)(model) 110 | 111 | return model 112 | 113 | 114 | def make_flood_fill_unet(input_fov_shape, output_fov_shape, network_config): 115 | """Construct a U-net flood filling network. 116 | """ 117 | image_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='image_input') 118 | if network_config.rescale_image: 119 | ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input) 120 | else: 121 | ffn = image_input 122 | mask_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='mask_input') 123 | ffn = concatenate([ffn, mask_input]) 124 | 125 | # Note that since the Keras 2 upgrade strangely models with depth > 3 are 126 | # rejected by TF. 127 | ffn = add_unet_layer(ffn, network_config, network_config.unet_depth - 1, output_fov_shape, 128 | n_channels=network_config.convolution_filters) 129 | 130 | mask_output = Conv3D( 131 | 1, 132 | (1, 1, 1), 133 | kernel_initializer=network_config.initialization, 134 | padding=network_config.convolution_padding, 135 | name='mask_output', 136 | activation=network_config.output_activation)(ffn) 137 | ffn = Model(inputs=[image_input, mask_input], outputs=[mask_output]) 138 | 139 | return ffn 140 | 141 | 142 | def add_unet_layer(model, network_config, remaining_layers, output_shape, n_channels=None, resolution=None): 143 | if n_channels is None: 144 | n_channels = model.get_shape().as_list()[-1] 145 | 146 | if network_config.unet_downsample_mode == "fixed_rate": 147 | downsample = np.array([x != 0 and remaining_layers % x == 0 for x in network_config.unet_downsample_rate]) 148 | else: 149 | resolution = resolution if resolution is not None else network_config.resolution 150 | min_res = np.min(resolution) 151 | # x < min_res * sqrt(2) because: 152 | # if a > sqrt(2)b, then a/b > sqrt(2) and 2b/a < sqrt(2) 153 | # if sqrt(2)b > a > b, then 2a/2b < sqrt(2) and 2b/a > sqrt(2) 154 | downsample = np.array([x < min_res * (2 ** .5) for x in resolution]) 155 | 156 | if network_config.convolution_padding == 'same': 157 | conv_contract = np.zeros(3, dtype=np.int32) 158 | else: 159 | conv_contract = network_config.convolution_dim - 1 160 | 161 | # First U convolution module. 162 | for i in range(network_config.num_layers_per_module): 163 | if i == network_config.num_layers_per_module - 1: 164 | # Increase the number of channels before downsampling to avoid 165 | # bottleneck (identical to 3D U-Net paper). 166 | n_channels = 2 * n_channels 167 | model = Conv3D( 168 | n_channels, 169 | tuple(network_config.convolution_dim), 170 | kernel_initializer=network_config.initialization, 171 | activation=network_config.convolution_activation, 172 | padding=network_config.convolution_padding)(model) 173 | if network_config.batch_normalization: 174 | model = BatchNormalization()(model) 175 | 176 | # Crop and pass forward to upsampling. 177 | if remaining_layers > 0: 178 | forward_link_shape = output_shape + network_config.num_layers_per_module * conv_contract 179 | else: 180 | forward_link_shape = output_shape 181 | contraction = (np.array(model.get_shape().as_list()[1:4]) - forward_link_shape) // 2 182 | forward = Cropping3D(list(zip(list(contraction), list(contraction))))(model) 183 | if network_config.dropout_probability > 0.0: 184 | forward = Dropout(network_config.dropout_probability)(forward) 185 | 186 | # Terminal layer of the U. 187 | if remaining_layers <= 0: 188 | return forward 189 | 190 | # Downsample and recurse. 191 | model = Conv3D( 192 | n_channels, 193 | tuple(network_config.convolution_dim), 194 | strides=list(downsample + 1), 195 | kernel_initializer=network_config.initialization, 196 | activation=network_config.convolution_activation, 197 | padding='same')(model) 198 | if network_config.batch_normalization: 199 | model = BatchNormalization()(model) 200 | next_output_shape = np.ceil(np.divide(forward_link_shape, downsample.astype(np.float32) + 1.0)).astype(np.int32) 201 | if network_config.unet_downsample_mode == "fixed_rate": 202 | model = add_unet_layer(model, 203 | network_config, 204 | remaining_layers - 1, 205 | next_output_shape.astype(np.int32)) 206 | else: 207 | model = add_unet_layer(model, 208 | network_config, 209 | remaining_layers - 1, 210 | next_output_shape.astype(np.int32), 211 | None, 212 | resolution * (downsample + 1)) 213 | 214 | # Upsample output of previous layer and merge with forward link. 215 | model = Conv3DTranspose( 216 | n_channels * 2, 217 | tuple(network_config.convolution_dim), 218 | strides=list(downsample + 1), 219 | kernel_initializer=network_config.initialization, 220 | activation=network_config.convolution_activation, 221 | padding='same')(model) 222 | if network_config.batch_normalization: 223 | model = BatchNormalization()(model) 224 | # Must crop output because Keras wrongly pads the output shape for odd array sizes. 225 | stride_pad = (network_config.convolution_dim // 2) * np.array(downsample) + (1 - np.mod(forward_link_shape, 2)) 226 | tf_pad_start = stride_pad // 2 # Tensorflow puts odd padding at end. 227 | model = Cropping3D(list(zip(list(tf_pad_start), list(stride_pad - tf_pad_start))))(model) 228 | 229 | model = concatenate([forward, model]) 230 | 231 | # Second U convolution module. 232 | for _ in range(network_config.num_layers_per_module): 233 | model = Conv3D( 234 | n_channels, 235 | tuple(network_config.convolution_dim), 236 | kernel_initializer=network_config.initialization, 237 | activation=network_config.convolution_activation, 238 | padding=network_config.convolution_padding)(model) 239 | if network_config.batch_normalization: 240 | model = BatchNormalization()(model) 241 | 242 | return model 243 | 244 | 245 | def compile_network(model, optimizer_config): 246 | optimizer_klass = getattr(keras.optimizers, optimizer_config.klass) 247 | optimizer_kwargs = inspect.getargspec(optimizer_klass.__init__)[0] 248 | optimizer_kwargs = {k: v for k, v in six.iteritems(optimizer_config.__dict__) if k in optimizer_kwargs} 249 | optimizer = optimizer_klass(**optimizer_kwargs) 250 | model.compile(loss=optimizer_config.loss, 251 | optimizer=optimizer) 252 | 253 | 254 | def load_model(model_file, network_config): 255 | model = keras_load_model(model_file) 256 | 257 | # If necessary, wrap the loaded model to transpose the axes for both 258 | # inputs and outputs. 259 | if network_config.transpose: 260 | inputs = [] 261 | perms = [] 262 | for old_input in model.input_layers: 263 | input_shape = np.asarray(old_input.input_shape)[[3, 2, 1, 4]] 264 | new_input = Input(shape=tuple(input_shape), dtype=old_input.input_dtype, name=old_input.name) 265 | perm = Permute((3, 2, 1, 4), input_shape=tuple(input_shape))(new_input) 266 | inputs.append(new_input) 267 | perms.append(perm) 268 | 269 | old_outputs = model(perms) 270 | if not isinstance(old_outputs, list): 271 | old_outputs = [old_outputs] 272 | 273 | outputs = [] 274 | for old_output in old_outputs: 275 | new_output = Permute((3, 2, 1, 4))(old_output) 276 | outputs.append(new_output) 277 | 278 | new_model = Model(input=inputs, output=outputs) 279 | 280 | # Monkeypatch the save to save just the underlying model. 281 | func_type = type(model.save) 282 | 283 | old_model = model 284 | 285 | def new_save(_, *args, **kwargs): 286 | old_model.save(*args, **kwargs) 287 | new_model.save = func_type(new_save, new_model) 288 | 289 | model = new_model 290 | 291 | return model 292 | 293 | 294 | def make_parallel(model, gpus=None): 295 | new_model = multi_gpu_model(model, gpus) 296 | func_type = type(model.save) 297 | 298 | # monkeypatch the save to save just the underlying model 299 | def new_save(_, *args, **kwargs): 300 | model.save(*args, **kwargs) 301 | new_model.save = func_type(new_save, new_model) 302 | 303 | return new_model 304 | -------------------------------------------------------------------------------- /diluvian/octrees.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Simple octree data structures for block sparse 3D arrays.""" 3 | 4 | 5 | from __future__ import division 6 | 7 | import numpy as np 8 | 9 | 10 | class OctreeVolume(object): 11 | """Octree-backed block sparse 3D array. 12 | 13 | This is a trivial implementation of an octree with NumPy ndarray leaves for 14 | block sparse volume access. This allows oblivious in-memory access of 15 | dense regions spanning out-of-memory volumes by providing read leaves 16 | via a populator. For writing, the octree supports uniform value terminal 17 | nodes at every level, so that only non-uniform data must be written to 18 | leaf level. 19 | 20 | Parameters 21 | ---------- 22 | leaf_shape : tuple of int or ndarray 23 | Shape of tree leaves in voxels. 24 | bounds : tuple of tuple of int or ndarray 25 | The lower and upper coordinate bounds of the volume, in voxels. 26 | dtype : numpy.data-type 27 | populator : function, optional 28 | A function taking a tuple of ndarray bounds for the coordinates of 29 | the subvolume to populate and returning the data for that subvolume. 30 | """ 31 | 32 | def __init__(self, leaf_shape, bounds, dtype, populator=None): 33 | self.leaf_shape = np.asarray(leaf_shape).astype(np.int64) 34 | self.bounds = (np.asarray(bounds[0], dtype=np.int64), 35 | np.asarray(bounds[1], dtype=np.int64)) 36 | self.dtype = np.dtype(dtype) 37 | self.populator = populator 38 | ceil_bounds = self.leaf_shape * \ 39 | np.exp2(np.ceil(np.log2((self.bounds[1] - self.bounds[0]) / 40 | self.leaf_shape.astype(np.float64)))).astype(np.int64).max() 41 | self.root_node = BranchNode(self, (self.bounds[0], self.bounds[0] + ceil_bounds), clip_bound=self.bounds[1]) 42 | 43 | @property 44 | def shape(self): 45 | return tuple(self.root_node.get_size()) 46 | 47 | def get_checked_np_key(self, key): 48 | # Special exception for [:] for uniform assignment. 49 | if isinstance(key, slice) and key.start is None and key.stop is None: 50 | return self.bounds 51 | 52 | if not hasattr(key, '__len__') or len(key) != 3: 53 | raise IndexError('Octrees may only be indexed in 3 dimensions') 54 | 55 | # Convert keys to two numpy arrays for ease. 56 | npkey = (np.zeros(3, dtype=np.int64), np.zeros(3, dtype=np.int64)) 57 | for i, k in enumerate(key): 58 | if isinstance(k, slice): 59 | if k.step is not None: 60 | raise IndexError('Octrees do not yet support step slicing') 61 | npkey[0][i] = k.start if k.start is not None else self.bounds[0][i] 62 | npkey[1][i] = k.stop if k.stop is not None else self.bounds[1][i] 63 | else: 64 | npkey[0][i] = k 65 | npkey[1][i] = k + 1 66 | 67 | if np.any(np.less(npkey[0], self.bounds[0])) or \ 68 | np.any(np.greater(npkey[1], self.bounds[1])) or \ 69 | np.any(np.greater_equal(npkey[0], npkey[1])): 70 | raise IndexError('Invalid indices: outside bounds or empty interval: ' 71 | '{} (bounds {})'.format(str(key), str(self.bounds))) 72 | 73 | return npkey 74 | 75 | def __getitem__(self, key): 76 | npkey = self.get_checked_np_key(key) 77 | 78 | return self.root_node[npkey] 79 | 80 | def __setitem__(self, key, value): 81 | npkey = self.get_checked_np_key(key) 82 | 83 | self.root_node[npkey] = value 84 | 85 | def iter_leaves(self): 86 | """Iterator over all non-uniform leaf nodes. 87 | 88 | Yields 89 | ------ 90 | LeafNode 91 | """ 92 | for leaf in self.root_node.iter_leaves(): 93 | yield leaf 94 | 95 | def get_leaf_bounds(self): 96 | bounds = [np.array(self.bounds[1]), np.array(self.bounds[0])] 97 | for leaf in self.iter_leaves(): 98 | bounds[0] = np.minimum(bounds[0], leaf.bounds[0]) 99 | bounds[1] = np.maximum(bounds[1], leaf.bounds[1]) 100 | 101 | bounds[0] = np.maximum(bounds[0], self.bounds[0]) 102 | bounds[1] = np.minimum(bounds[1], self.bounds[1]) 103 | 104 | return bounds 105 | 106 | def map_copy(self, dtype, leaf_map, uniform_map): 107 | """Create a copy of this octree by mapping node data. 108 | 109 | Note that because leaves and uniform nodes can have separate mapping, 110 | the ranges of this tree and the copied tree may not be bijective. 111 | 112 | Populators are not copied. 113 | 114 | Parameters 115 | ---------- 116 | dtype : numpy.data-type 117 | Data type for the constructed copy 118 | leaf_map : function 119 | Function mapping leaf node data for the constructed copy. 120 | uniform_map : function 121 | Function mapping uniform node values. 122 | 123 | Returns 124 | ------- 125 | OctreeVolume 126 | Copied octree with the same structure as this octree. 127 | """ 128 | copy = OctreeVolume(self.leaf_shape, self.bounds, dtype) 129 | copy.root_node = self.root_node.map_copy(copy, leaf_map, uniform_map) 130 | return copy 131 | 132 | def fullness(self): 133 | potential_leaves = np.prod(np.ceil(np.true_divide(self.bounds[1] - self.bounds[0], self.leaf_shape))) 134 | return self.root_node.count_leaves() / float(potential_leaves) 135 | 136 | def get_volume(self): 137 | return self 138 | 139 | def replace_child(self, child, replacement): 140 | if child != self.root_node: 141 | raise ValueError('Attempt to replace unknown child') 142 | 143 | self.root_node = replacement 144 | 145 | 146 | class Node(object): 147 | def __init__(self, parent, bounds, clip_bound=None): 148 | self.parent = parent 149 | self.bounds = (bounds[0].copy(), bounds[1].copy()) 150 | self.clip_bound = clip_bound 151 | 152 | def count_leaves(self): 153 | return 0 154 | 155 | def get_intersection(self, key): 156 | return (np.maximum(self.bounds[0], key[0]), 157 | np.minimum(self.bounds[1], key[1])) 158 | 159 | def get_size(self): 160 | if self.clip_bound is not None: 161 | return self.clip_bound - self.bounds[0] 162 | return self.bounds[1] - self.bounds[0] 163 | 164 | def get_volume(self): 165 | return self.parent.get_volume() 166 | 167 | def replace(self, replacement): 168 | self.parent.replace_child(self, replacement) 169 | self.parent = None 170 | 171 | 172 | class BranchNode(Node): 173 | def __init__(self, parent, bounds, **kwargs): 174 | super(BranchNode, self).__init__(parent, bounds, **kwargs) 175 | self.midpoint = (self.bounds[1] + self.bounds[0]) // 2 176 | self.children = [[[None for _ in range(2)] for _ in range(2)] for _ in range(2)] 177 | 178 | def count_leaves(self): 179 | return sum(c.count_leaves() for s in self.children for r in s for c in r if c is not None) 180 | 181 | def iter_leaves(self): 182 | for i in range(2): 183 | for j in range(2): 184 | for k in range(2): 185 | child = self.children[i][j][k] 186 | if child is None or isinstance(child, UniformNode): 187 | continue 188 | for leaf in child.iter_leaves(): 189 | yield leaf 190 | 191 | def map_copy(self, copy_parent, leaf_map, uniform_map): 192 | copy = BranchNode(copy_parent, self.bounds, clip_bound=self.clip_bound) 193 | for i in range(2): 194 | for j in range(2): 195 | for k in range(2): 196 | child = self.children[i][j][k] 197 | if child is None: 198 | copy.children[i][j][k] = None 199 | else: 200 | copy.children[i][j][k] = child.map_copy(copy, leaf_map, uniform_map) 201 | return copy 202 | 203 | def get_children_mask(self, key): 204 | p = (np.less(key[0], self.midpoint), 205 | np.greater(key[1], self.midpoint)) 206 | 207 | # TODO must be some way to do combinatorial ops like this with numpy. 208 | return list(zip(*np.where([[[p[i][0] and p[j][1] and p[k][2] 209 | for k in range(2)] 210 | for j in range(2)] 211 | for i in range(2)]))) 212 | 213 | def get_child_bounds(self, i, j, k): 214 | mins = (self.bounds[0], self.midpoint) 215 | maxs = (self.midpoint, self.bounds[1]) 216 | child_bounds = (np.array((mins[i][0], mins[j][1], mins[k][2])), 217 | np.array((maxs[i][0], maxs[j][1], maxs[k][2]))) 218 | if self.clip_bound is not None: 219 | clip_bound = np.minimum(child_bounds[1], self.clip_bound) 220 | if np.array_equal(clip_bound, child_bounds[1]): 221 | clip_bound = None 222 | else: 223 | clip_bound = None 224 | 225 | return (child_bounds, clip_bound) 226 | 227 | def __getitem__(self, key): 228 | inds = self.get_children_mask(key) 229 | 230 | for i, j, k in inds: 231 | if self.children[i][j][k] is None: 232 | self.populate_child(i, j, k) 233 | 234 | if len(inds) == 1: 235 | i, j, k = inds[0] 236 | return self.children[i][j][k][key] 237 | 238 | chunk = np.empty(tuple(key[1] - key[0]), self.get_volume().dtype) 239 | for i, j, k in inds: 240 | child = self.children[i][j][k] 241 | subchunk = child.get_intersection(key) 242 | ind = (subchunk[0] - key[0], subchunk[1] - key[0]) 243 | chunk[ind[0][0]:ind[1][0], 244 | ind[0][1]:ind[1][1], 245 | ind[0][2]:ind[1][2]] = child[subchunk] 246 | 247 | return chunk 248 | 249 | def __setitem__(self, key, value): 250 | if (not hasattr(value, '__len__') or len(value) == 1) and \ 251 | np.array_equal(key[0], self.bounds[0]) and \ 252 | np.array_equal(key[1], self.clip_bound): 253 | self.replace(UniformBranchNode(self.parent, self.bounds, self.get_volume().dtype, value, 254 | clip_bound=self.clip_bound)) 255 | return 256 | 257 | inds = self.get_children_mask(key) 258 | 259 | for i, j, k in inds: 260 | if self.children[i][j][k] is None: 261 | self.populate_child(i, j, k) 262 | 263 | for i, j, k in inds: 264 | child = self.children[i][j][k] 265 | subchunk = child.get_intersection(key) 266 | ind = (subchunk[0] - key[0], subchunk[1] - key[0]) 267 | if isinstance(value, np.ndarray): 268 | child[subchunk] = value[ind[0][0]:ind[1][0], 269 | ind[0][1]:ind[1][1], 270 | ind[0][2]:ind[1][2]] 271 | else: 272 | child[subchunk] = value 273 | 274 | def populate_child(self, i, j, k): 275 | volume = self.get_volume() 276 | if volume.populator is None: 277 | raise ValueError('Attempt to retrieve unpopulated region without octree populator') 278 | 279 | child_bounds, child_clip_bound = self.get_child_bounds(i, j, k) 280 | child_shape = child_bounds[1] - child_bounds[0] 281 | if np.any(np.less_equal(child_shape, volume.leaf_shape)): 282 | populator_bounds = [child_bounds[0].copy(), child_bounds[1].copy()] 283 | if child_clip_bound is not None: 284 | populator_bounds[1] = np.minimum(populator_bounds[1], child_clip_bound) 285 | data = volume.populator(populator_bounds).astype(volume.dtype) 286 | child = LeafNode(self, child_bounds, data) 287 | else: 288 | child = BranchNode(self, child_bounds, clip_bound=child_clip_bound) 289 | 290 | self.children[i][j][k] = child 291 | 292 | def replace_child(self, child, replacement): 293 | for i in range(2): 294 | for j in range(2): 295 | for k in range(2): 296 | if child == self.children[i][j][k]: 297 | self.children[i][j][k] = replacement 298 | return 299 | 300 | raise ValueError('Attempt to replace unknown child') 301 | 302 | 303 | class LeafNode(Node): 304 | def __init__(self, parent, bounds, data): 305 | super(LeafNode, self).__init__(parent, bounds) 306 | self.data = data.copy() 307 | 308 | def count_leaves(self): 309 | return 1 310 | 311 | def iter_leaves(self): 312 | yield self 313 | 314 | def map_copy(self, copy_parent, leaf_map, uniform_map): 315 | copy = LeafNode(copy_parent, self.bounds, leaf_map(self.data)) 316 | return copy 317 | 318 | def __getitem__(self, key): 319 | ind = (key[0] - self.bounds[0], key[1] - self.bounds[0]) 320 | return self.data[ind[0][0]:ind[1][0], 321 | ind[0][1]:ind[1][1], 322 | ind[0][2]:ind[1][2]] 323 | 324 | def __setitem__(self, key, value): 325 | ind = (key[0] - self.bounds[0], key[1] - self.bounds[0]) 326 | self.data[ind[0][0]:ind[1][0], 327 | ind[0][1]:ind[1][1], 328 | ind[0][2]:ind[1][2]] = value 329 | 330 | 331 | class UniformNode(Node): 332 | def __init__(self, parent, bounds, dtype, value, **kwargs): 333 | super(UniformNode, self).__init__(parent, bounds, **kwargs) 334 | self.value = value 335 | self.dtype = dtype 336 | 337 | def __getitem__(self, key): 338 | return np.full(tuple(key[1] - key[0]), self.value, dtype=self.dtype) 339 | 340 | def map_copy(self, copy_parent, leaf_map, uniform_map): 341 | copy = type(self)(copy_parent, self.bounds, copy_parent.get_volume().dtype, 342 | uniform_map(self.value), clip_bound=self.clip_bound) 343 | return copy 344 | 345 | 346 | class UniformBranchNode(UniformNode): 347 | def __setitem__(self, key, value): 348 | replacement = BranchNode(self.parent, self.bounds, clip_bound=self.clip_bound) 349 | volume = self.get_volume() 350 | for i in range(2): 351 | for j in range(2): 352 | for k in range(2): 353 | child_bounds, child_clip_bound = replacement.get_child_bounds(i, j, k) 354 | # If this child is entirely outside the clip bounds, it will never be accessed 355 | # or populated and thus can be omitted. 356 | if child_clip_bound is not None and np.any(np.greater_equal(child_bounds[0], child_clip_bound)): 357 | continue 358 | child_shape = child_bounds[1] - child_bounds[0] 359 | if np.any(np.less_equal(child_shape, volume.leaf_shape)): 360 | child = UniformLeafNode(replacement, child_bounds, self.dtype, self.value) 361 | else: 362 | child = UniformBranchNode(replacement, child_bounds, self.dtype, self.value, 363 | clip_bound=child_clip_bound) 364 | replacement.children[i][j][k] = child 365 | self.replace(replacement) 366 | replacement[key] = value 367 | 368 | 369 | class UniformLeafNode(UniformNode): 370 | def __setitem__(self, key, value): 371 | replacement = LeafNode(self.parent, self.bounds, self[self.bounds]) 372 | self.replace(replacement) 373 | replacement[key] = value 374 | 375 | def count_leaves(self): 376 | return 1 377 | -------------------------------------------------------------------------------- /diluvian/__main__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Command line interface for diluvian.""" 3 | 4 | 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import logging 9 | import os 10 | import random 11 | import re 12 | 13 | import six 14 | 15 | from .config import CONFIG 16 | 17 | 18 | def _make_main_parser(): 19 | """Construct the argparse parser for the main CLI. 20 | 21 | This exists as a separate function so the parser can be used to 22 | auto-generate CLI documentation in Sphinx. 23 | 24 | Returns 25 | ------- 26 | argparse.ArgumentParser 27 | Parser for the main CLI and all subcommands. 28 | """ 29 | common_parser = argparse.ArgumentParser(add_help=False) 30 | 31 | common_parser.add_argument( 32 | '-c', '--config-file', action='append', dest='config_files', default=[], 33 | help='Configuration files to use. For defaults, see `diluvian/conf/default.toml`. ' 34 | 'Values are overwritten in the order provided.') 35 | common_parser.add_argument( 36 | '-cd', action='append_const', dest='config_files', 37 | const=os.path.join(os.path.dirname(__file__), 'conf', 'default.toml'), 38 | help='Add default configuration file to chain of configuration files.') 39 | common_parser.add_argument( 40 | '-m', '--model-file', dest='model_file', default=None, 41 | help='Existing network model file to use for prediction or continued training.') 42 | common_parser.add_argument( 43 | '-v', '--volume-file', action='append', dest='volume_files', default=[], 44 | help='Volume configuration files. For example, see `diluvian/conf/cremi_datasets.toml`.' 45 | 'Values are overwritten in the order provided.') 46 | common_parser.add_argument( 47 | '--no-in-memory', action='store_false', dest='in_memory', default=True, 48 | help='Do not preload entire volumes into memory.') 49 | common_parser.add_argument( 50 | '-rs', '--random-seed', action='store', dest='random_seed', type=int, 51 | help='Seed for initializing the Python and NumPy random generators. ' 52 | 'Overrides any seed specified in configuration files.') 53 | common_parser.add_argument( 54 | '-l', '--log', dest='log_level', 55 | choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], 56 | help='Set the logging level.') 57 | 58 | parser = argparse.ArgumentParser(description='Train or run flood-filling networks on EM data.') 59 | 60 | commandparsers = parser.add_subparsers(help='Commands', dest='command') 61 | 62 | train_parser = commandparsers.add_parser( 63 | 'train', parents=[common_parser], 64 | help='Train a network from labeled volumes.') 65 | train_parser.add_argument( 66 | '-mo', '--model-output-filebase', dest='model_output_filebase', default=None, 67 | help='Base filename for the best trained model and other output artifacts, ' 68 | 'such as metric plots and configuration state.') 69 | train_parser.add_argument( 70 | '-mc', '--model-checkpoint-file', dest='model_checkpoint_file', default=None, 71 | help='Filename for model checkpoints at every epoch. ' 72 | 'This is different than the model output file; if provided, this HDF5 model ' 73 | 'file is saved every epoch regardless of validation performance.' 74 | 'Can use Keras format arguments: https://keras.io/callbacks/#modelcheckpoint') 75 | train_parser.add_argument( 76 | '--early-restart', action='store_true', dest='early_restart', default=False, 77 | help='If training is aborted early because an early abort metric ' 78 | 'criteria, restart training with a new random seed.') 79 | train_parser.add_argument( 80 | '--tensorboard', action='store_true', dest='tensorboard', default=False, 81 | help='Output tensorboard log files while training (limited to network graph).') 82 | train_parser.add_argument( 83 | '--viewer', action='store_true', dest='viewer', default=False, 84 | help='Create a neuroglancer viewer for a training sample at the end of training.') 85 | train_parser.add_argument( 86 | '--metric-plot', action='store_true', dest='metric_plot', default=False, 87 | help='Plot metric history at the end of training. ' 88 | 'Will be saved as a PNG with the model output base filename.') 89 | 90 | fill_common_parser = argparse.ArgumentParser(add_help=False) 91 | fill_common_parser.add_argument( 92 | '--partition-volumes', action='store_true', dest='partition_volumes', default=False, 93 | help='Partition volumes and only fill the validation partition.') 94 | fill_common_parser.add_argument( 95 | '--no-bias', action='store_false', dest='bias', default=True, 96 | help='Overwrite prediction mask at the end of each field of view inference ' 97 | 'rather than using the anti-merge bias update.') 98 | fill_common_parser.add_argument( 99 | '--move-batch-size', dest='move_batch_size', default=1, type=int, 100 | help='Maximum number of fill moves to process in each prediction batch.') 101 | fill_common_parser.add_argument( 102 | '--max-moves', dest='max_moves', default=None, type=int, 103 | help='Cancel filling after this many moves.') 104 | fill_common_parser.add_argument( 105 | '--remask-interval', dest='remask_interval', default=None, type=int, 106 | help='Interval in moves to reset filling region mask based on ' 107 | 'the seeded connected component.') 108 | 109 | fill_parser = commandparsers.add_parser( 110 | 'fill', parents=[common_parser, fill_common_parser], 111 | help='Use a trained network to densely segment a volume.') 112 | fill_parser.add_argument( 113 | '--seed-generator', dest='seed_generator', default='sobel', nargs='?', 114 | # Would be nice to pull these from .preprocessing.SEED_GENERATORS, 115 | # but want to avoid importing so that CLI is responsive. 116 | choices=['grid', 'sobel'], 117 | help='Method to generate seed locations for flood filling.') 118 | fill_parser.add_argument( 119 | '--ordered-seeds', action='store_false', dest='shuffle_seeds', default=True, 120 | help='Do not shuffle order in which seeds are processed.') 121 | fill_parser.add_argument( 122 | '--ignore-mask', dest='ignore_mask', default=False, 123 | help='Ignore the mask channel when generating seeds.') 124 | fill_parser.add_argument( 125 | '--background-label-id', dest='background_label_id', default=0, type=int, 126 | help='Label ID to output for voxels not belonging to any filled body.') 127 | fill_parser.add_argument( 128 | '--viewer', action='store_true', dest='viewer', default=False, 129 | help='Create a neuroglancer viewer for a each volume after filling.') 130 | fill_parser.add_argument( 131 | '--max-bodies', dest='max_bodies', default=None, type=int, 132 | help='Cancel filling after this many bodies (only useful for ' 133 | 'diagnostics).') 134 | fill_parser.add_argument( 135 | '--reject-early-termination', action='store_true', 136 | dest='reject_early_termination', default=False, 137 | help='Reject seeds that terminate early, e.g., due to maximum ' 138 | 'move limits.') 139 | fill_parser.add_argument( 140 | '--resume-file', dest='resume_filename', default=None, 141 | help='Filename for the TOML configuration file of a segmented ' 142 | 'label volume from which to resume filling. The configuration ' 143 | 'should only contain one dataset.') 144 | fill_parser.add_argument( 145 | 'segmentation_output_file', default=None, 146 | help='Filename for the HDF5 segmentation output, without ' 147 | 'extension. Should contain "{volume}", which will be ' 148 | 'substituted with the volume name for each respective ' 149 | 'volume\'s bounds.') 150 | 151 | bounds_common_parser = argparse.ArgumentParser(add_help=False) 152 | bounds_common_parser.add_argument( 153 | '--bounds-num-moves', dest='bounds_num_moves', default=None, nargs=3, type=int, 154 | help='Number of moves in direction to size the subvolume bounds.') 155 | 156 | sparse_fill_parser = commandparsers.add_parser( 157 | 'sparse-fill', parents=[common_parser, fill_common_parser, bounds_common_parser], 158 | help='Use a trained network to fill random regions in a volume.') 159 | sparse_fill_parser.add_argument( 160 | '--augment', action='store_true', dest='augment', default=False, 161 | help='Apply training augmentations to subvolumes before filling.') 162 | sparse_fill_parser.add_argument( 163 | '-bi', '--bounds-input-file', dest='bounds_input_file', default=None, 164 | help='Filename for bounds CSV input. Should contain "{volume}", which will be ' 165 | 'substituted with the volume name for each respective volume\'s bounds.') 166 | 167 | validate_parser = commandparsers.add_parser( # noqa 168 | 'validate', parents=[common_parser], 169 | help='Run a model on validation data.') 170 | 171 | evaluate_parser = commandparsers.add_parser( 172 | 'evaluate', parents=[common_parser], 173 | help='Evaluate a filling result versus a ground truth.') 174 | evaluate_parser.add_argument( 175 | '--border-threshold', dest='border_threshold', default=25, type=float, 176 | help='Region border threshold (in nm) to ignore. Official CREMI ' 177 | 'default is 25nm.') 178 | evaluate_parser.add_argument( 179 | '--partition-volumes', action='store_true', dest='partition_volumes', default=False, 180 | help='Partition volumes and only evaluate the validation partitions.') 181 | evaluate_parser.add_argument( 182 | 'ground_truth_name', default=None, 183 | help='Name of the ground truth volume.') 184 | evaluate_parser.add_argument( 185 | 'prediction_name', default=None, 186 | help='Name of the prediction volume.') 187 | 188 | view_parser = commandparsers.add_parser( 189 | 'view', parents=[common_parser], 190 | help='View a set of co-registered volumes in neuroglancer.') 191 | view_parser.add_argument( 192 | '--partition-volumes', action='store_true', dest='partition_volumes', default=False, 193 | help='Partition volumes and view centered at the validation ' 194 | 'partitions.') 195 | view_parser.add_argument( 196 | 'volume_name_regex', default='.', nargs='?', 197 | help='Regex to filter which volumes of those defined in the ' 198 | 'volume configuration should be loaded.') 199 | 200 | check_config_parser = commandparsers.add_parser( 201 | 'check-config', parents=[common_parser], 202 | help='Check a configuration value.') 203 | check_config_parser.add_argument( 204 | 'config_property', default=None, nargs='?', 205 | help='Name of the property to show, e.g., `training.batch_size`.') 206 | 207 | gen_subv_bounds_parser = commandparsers.add_parser( 208 | 'gen-subv-bounds', parents=[common_parser, bounds_common_parser], 209 | help='Generate subvolume bounds.') 210 | gen_subv_bounds_parser.add_argument( 211 | 'bounds_output_file', default=None, 212 | help='Filename for the CSV output. Should contain "{volume}", which will be ' 213 | 'substituted with the volume name for each respective volume\'s bounds.') 214 | gen_subv_bounds_parser.add_argument( 215 | 'num_bounds', default=None, type=int, 216 | help='Number of bounds to generate.') 217 | 218 | return parser 219 | 220 | 221 | def main(): 222 | """Entry point for the diluvian command line interface.""" 223 | parser = _make_main_parser() 224 | 225 | args = parser.parse_args() 226 | 227 | if args.log_level: 228 | logging.basicConfig(level=logging.getLevelName(args.log_level)) 229 | 230 | if args.config_files: 231 | CONFIG.from_toml(*args.config_files) 232 | 233 | if args.random_seed: 234 | CONFIG.random_seed = args.random_seed 235 | 236 | def init_seeds(): 237 | random.seed(CONFIG.random_seed) 238 | import numpy as np 239 | np.random.seed(CONFIG.random_seed) 240 | import tensorflow as tf 241 | tf.set_random_seed(CONFIG.random_seed) 242 | 243 | if args.command == 'train': 244 | # Late import to prevent loading large modules for short CLI commands. 245 | init_seeds() 246 | from .training import EarlyAbortException, train_network 247 | 248 | volumes = load_volumes(args.volume_files, args.in_memory) 249 | while True: 250 | try: 251 | train_network(model_file=args.model_file, 252 | volumes=volumes, 253 | model_output_filebase=args.model_output_filebase, 254 | model_checkpoint_file=args.model_checkpoint_file, 255 | tensorboard=args.tensorboard, 256 | viewer=args.viewer, 257 | metric_plot=args.metric_plot) 258 | except EarlyAbortException as inst: 259 | if args.early_restart: 260 | import numpy as np 261 | new_seed = CONFIG.random_seed 262 | while new_seed == CONFIG.random_seed: 263 | new_seed = np.random.randint(int(1e8)) 264 | CONFIG.random_seed = new_seed 265 | logging.warning(str(inst)) 266 | logging.warning('Training aborted, restarting with random seed %s', new_seed) 267 | init_seeds() 268 | continue 269 | else: 270 | logging.critical(str(inst)) 271 | break 272 | break 273 | 274 | elif args.command == 'fill': 275 | # Late import to prevent loading large modules for short CLI commands. 276 | init_seeds() 277 | from .diluvian import fill_volumes_with_model 278 | 279 | volumes = load_volumes(args.volume_files, args.in_memory) 280 | fill_volumes_with_model(args.model_file, 281 | volumes, 282 | args.segmentation_output_file, 283 | resume_filename=args.resume_filename, 284 | partition=args.partition_volumes, 285 | viewer=args.viewer, 286 | seed_generator=args.seed_generator, 287 | background_label_id=args.background_label_id, 288 | bias=args.bias, 289 | move_batch_size=args.move_batch_size, 290 | max_moves=args.max_moves, 291 | max_bodies=args.max_bodies, 292 | filter_seeds_by_mask=not args.ignore_mask, 293 | reject_early_termination=args.reject_early_termination, 294 | remask_interval=args.remask_interval, 295 | shuffle_seeds=args.shuffle_seeds) 296 | 297 | elif args.command == 'sparse-fill': 298 | # Late import to prevent loading large modules for short CLI commands. 299 | init_seeds() 300 | from .diluvian import fill_region_with_model 301 | 302 | volumes = load_volumes(args.volume_files, args.in_memory) 303 | fill_region_with_model(args.model_file, 304 | volumes=volumes, 305 | partition=args.partition_volumes, 306 | augment=args.augment, 307 | bounds_input_file=args.bounds_input_file, 308 | bias=args.bias, 309 | move_batch_size=args.move_batch_size, 310 | max_moves=args.max_moves, 311 | remask_interval=args.remask_interval, 312 | moves=args.bounds_num_moves) 313 | 314 | elif args.command == 'validate': 315 | # Late import to prevent loading large modules for short CLI commands. 316 | init_seeds() 317 | from .training import validate_model 318 | 319 | volumes = load_volumes(args.volume_files, args.in_memory) 320 | validate_model(args.model_file, volumes) 321 | 322 | elif args.command == 'evaluate': 323 | from .diluvian import evaluate_volume 324 | 325 | volumes = load_volumes(args.volume_files, args.in_memory) 326 | evaluate_volume(volumes, 327 | args.ground_truth_name, 328 | args.prediction_name, 329 | partition=args.partition_volumes, 330 | border_threshold=args.border_threshold) 331 | 332 | elif args.command == 'view': 333 | # Late import to prevent loading large modules for short CLI commands. 334 | from .diluvian import view_volumes 335 | 336 | volumes = load_volumes(args.volume_files, args.in_memory, name_regex=args.volume_name_regex) 337 | view_volumes(volumes, partition=args.partition_volumes) 338 | 339 | elif args.command == 'check-config': 340 | prop = CONFIG 341 | if args.config_property is not None: 342 | properties = args.config_property.split('.') 343 | for p in properties: 344 | prop = getattr(prop, p) 345 | print(prop) 346 | 347 | elif args.command == 'gen-subv-bounds': 348 | # Late import to prevent loading large modules for short CLI commands. 349 | init_seeds() 350 | from .diluvian import generate_subvolume_bounds 351 | 352 | volumes = load_volumes(args.volume_files, args.in_memory) 353 | generate_subvolume_bounds(args.bounds_output_file, 354 | volumes, 355 | args.num_bounds, 356 | moves=args.bounds_num_moves) 357 | 358 | 359 | def load_volumes(volume_files, in_memory, name_regex=None): 360 | """Load HDF5 volumes specified in a TOML description file. 361 | 362 | Parameters 363 | ---------- 364 | volume_file : list of str 365 | Filenames of the TOML volume descriptions to load. 366 | in_memory : bool 367 | If true, the entire dataset is read into an in-memory volume. 368 | 369 | Returns 370 | ------- 371 | diluvian.volumes.Volume 372 | """ 373 | # Late import to prevent loading large modules for short CLI commands. 374 | from .volumes import HDF5Volume 375 | 376 | print('Loading volumes...') 377 | if volume_files: 378 | volumes = {} 379 | for volume_file in volume_files: 380 | volumes.update(HDF5Volume.from_toml(volume_file)) 381 | else: 382 | volumes = HDF5Volume.from_toml(os.path.join(os.path.dirname(__file__), 'conf', 'cremi_datasets.toml')) 383 | 384 | if name_regex is not None: 385 | name_regex = re.compile(name_regex) 386 | volumes = {k: v for k, v in six.iteritems(volumes) if name_regex.match(k)} 387 | 388 | if in_memory: 389 | print('Copying volumes to memory...') 390 | volumes = {k: v.to_memory_volume() for k, v in six.iteritems(volumes)} 391 | 392 | print('Done.') 393 | return volumes 394 | 395 | 396 | if __name__ == "__main__": 397 | main() 398 | -------------------------------------------------------------------------------- /diluvian/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Global configuration objects. 3 | 4 | This module contains boilerplate configuration objects for storing and loading 5 | configuration state. 6 | """ 7 | 8 | 9 | from __future__ import division 10 | 11 | import os 12 | 13 | import numpy as np 14 | import pytoml as toml 15 | import six 16 | 17 | 18 | class BaseConfig(object): 19 | """Base class for configuration objects. 20 | 21 | String representation yields TOML that should parse back to a dictionary 22 | that will initialize the same configuration object. 23 | """ 24 | def __str__(self): 25 | sanitized = {} 26 | for k, v in six.iteritems(self.__dict__): 27 | if isinstance(v, np.ndarray): 28 | sanitized[k] = v.tolist() 29 | else: 30 | sanitized[k] = v 31 | return toml.dumps(sanitized) 32 | 33 | __repr__ = __str__ 34 | 35 | 36 | class VolumeConfig(BaseConfig): 37 | """Configuration for the use of volumes. 38 | 39 | Attributes 40 | ---------- 41 | resolution : sequence or ndarray of float 42 | Resolution to which volumes will be downsampled before processing. 43 | label_downsampling : str 44 | Method for downsampling label masks. One of 'majority' or 'conjunction'. 45 | """ 46 | def __init__(self, settings): 47 | self.resolution = np.array(settings.get('resolution', [1, 1, 1])) 48 | self.label_downsampling = str(settings.get('label_downsampling', 'majority')) 49 | 50 | 51 | class ModelConfig(BaseConfig): 52 | """Configuration for non-network aspects of the flood filling model. 53 | 54 | Attributes 55 | ---------- 56 | input_fov_shape : sequence or ndarray of int 57 | Input field of view shape in voxels for each flood filling move. 58 | output_fov_shape : sequence or ndarray of int 59 | Output field of view shape in voxels for each flood filling move. Can 60 | not be larger than ``input_fov_shape``. 61 | output_fov_move_fraction : int 62 | Move size as a fraction of the output field of view shape. 63 | v_true, v_false : float 64 | Soft target values for in-object and out-of-object mask voxels, 65 | respectively. 66 | t_move : float 67 | Threshold mask probability in the move check plane to queue a move 68 | to that position. 69 | t_final : float, optional 70 | Threshold mask probability to produce the final segmentation. Defaults 71 | to ``t_move``. 72 | move_check_thickness : int 73 | Thickness of move check plane in voxels. Setting this greater than 1 74 | is useful to make moves more robust even if the move grid aligns with 75 | missing sections or image artifacts. 76 | move_priority : str 77 | How to prioritize the move queue. Either 'descending' to order by 78 | descending mask probability in the move check plane (default), 79 | 'proximity' to prioritize moves minimizing L1 path distance from the 80 | seed, or 'random'. 81 | move_recheck : bool 82 | If true, when moves are retrieved from the queue a cube in the 83 | probability mask will be checked around the move location. If no voxels 84 | in this cube are greater than the move threshold, the move will be 85 | skipped. The cube size is one move step in each direction. 86 | training_subv_shape : sequence or ndarray of int, optional 87 | Shape of the subvolumes used during moving training. 88 | validation_subv_shape : sequence or ndarray of int, optional 89 | Shape of the subvolumes used during training validation. 90 | """ 91 | def __init__(self, settings): 92 | self.input_fov_shape = np.array(settings.get('input_fov_shape', [17, 33, 33])) 93 | self.output_fov_shape = np.array(settings.get('output_fov_shape', [17, 33, 33])) 94 | self.output_fov_move_fraction = int(settings.get('output_fov_move_fraction', 4)) 95 | self.v_true = float(settings.get('v_true', 0.95)) 96 | self.v_false = float(settings.get('v_false', 0.05)) 97 | self.t_move = float(settings.get('t_move', 0.9)) 98 | self.t_final = float(settings.get('t_final', self.t_move)) 99 | self.move_check_thickness = int(settings.get('move_check_thickness', 1)) 100 | self.move_priority = str(settings.get('move_priority', 'descending')) 101 | self.move_recheck = bool(settings.get('move_recheck', True)) 102 | self.training_subv_shape = np.array(settings.get('training_subv_shape', 103 | self.input_fov_shape + self.move_step * 2)) 104 | self.validation_subv_shape = np.array(settings.get('validation_subv_shape', 105 | self.input_fov_shape + self.move_step * 4)) 106 | 107 | @property 108 | def move_step(self): 109 | return (self.output_fov_shape - 1) // self.output_fov_move_fraction 110 | 111 | def subv_moves(self, shape): 112 | return np.prod((shape - self.input_fov_shape) // self.move_step + 1) 113 | 114 | @property 115 | def training_subv_moves(self): 116 | return self.subv_moves(self.training_subv_shape) 117 | 118 | @property 119 | def validation_subv_moves(self): 120 | return self.subv_moves(self.validation_subv_shape) 121 | 122 | 123 | class NetworkConfig(BaseConfig): 124 | """Configuration for the flood filling network architecture. 125 | 126 | Attributes 127 | ---------- 128 | factory : str 129 | Module and function name for a factory method for creating the flood 130 | filling network. This allows a custom architecture to be provided 131 | without needing to modify diluvian. 132 | transpose : bool 133 | If true, any loaded networks will reverse the order of axes for both 134 | inputs and outputs. Data is assumed to be ZYX row-major, but old 135 | versions of diluvian used XYZ, so this is necessary to load old 136 | networks. 137 | rescale_image : bool 138 | If true, rescale the input image intensity from [0, 1) to [-1, 1). 139 | num_modules : int 140 | Number of convolution modules to use, each module consisting of a skip 141 | link in parallel with ``num_layers_per_module`` convolution layers. 142 | num_layers_per_module : int 143 | Number of layers to use in each organizational module, e.g., the 144 | number of convolution layers in each convolution module or the number 145 | of convolution layers before and after each down- and up-sampling 146 | respectively in a U-Net level. 147 | convolution_dim : sequence or ndarray of int 148 | Shape of the convolution for each layer. 149 | convolution_filters : int 150 | Number of convolution filters for each layer. 151 | convolution_activation : str 152 | Name of the Keras activation function to apply after convolution layers. 153 | convolution_padding : str 154 | Name of the padding mode for convolutions, either 'same' (default) or 155 | 'valid'. 156 | initialization : str 157 | Name of the Keras initialization function to use for weight 158 | initialization of all layers. 159 | output_activation : str 160 | Name of the Keras activation function to use for the final network 161 | output. 162 | dropout_probability : float 163 | Probability for dropout layers. If zero, no dropout layers will be 164 | included. 165 | batch_normalization : bool 166 | Whether to apply batch normalization. Note that in included networks 167 | normalization is applied after activation, rather than before as in the 168 | original paper, because this is now more common practice. 169 | unet_depth : int 170 | For U-Net models, the total number of downsampled levels in the network. 171 | unet_downsample_rate : sequence or ndarray of int 172 | The frequency in levels to downsample each axis. For example, a standard 173 | U-Net downsamples all axes at each level, so this value would be all 174 | ones. If data is anisotropic and Z should only be downsampled every 175 | other level, this value could be [2, 1, 1]. Axes set to 0 are never 176 | downsampled. 177 | unet_downsample_mode: string 178 | The mode to use for downsampling. The two options are "fixed_rate", 179 | which will use the downsample rate previously defined, and "isotropy_approximating", 180 | which will downsample on lower resolution axes until the volume is as 181 | isotropic as possible. For example given a volume with resolution 182 | [40,4,4] and 4 unet layers, would downsample to 183 | [40,8,8],[40,16,16],[40,32,32],[80,64,64] 184 | resolution: sequence or ndarray of int 185 | The resolution of the input image data. This is necessary if you want 186 | to use "isotropy_approximating" for ``unet_downsampling_mode`` 187 | """ 188 | def __init__(self, settings): 189 | self.factory = str(settings.get('factory')) 190 | self.transpose = bool(settings.get('transpose', False)) 191 | self.rescale_image = bool(settings.get('rescale_image', False)) 192 | self.num_modules = int(settings.get('num_modules', 8)) 193 | self.num_layers_per_module = int(settings.get('num_layers_per_module', 2)) 194 | self.convolution_dim = np.array(settings.get('convolution_dim', [3, 3, 3])) 195 | self.convolution_filters = int(settings.get('convolution_filters', 32)) 196 | self.convolution_activation = str(settings.get('convolution_activation', 'relu')) 197 | self.convolution_padding = str(settings.get('convolution_padding', 'same')) 198 | self.initialization = str(settings.get('initialization', 'glorot_uniform')) 199 | self.output_activation = str(settings.get('output_activation', 'sigmoid')) 200 | self.dropout_probability = float(settings.get('dropout_probability', 0.0)) 201 | self.batch_normalization = bool(settings.get('batch_normalization', False)) 202 | self.unet_depth = int(settings.get('unet_depth', 4)) 203 | self.unet_downsample_rate = np.array(settings.get('unet_downsample_rate', [1, 1, 1])) 204 | 205 | self.unet_downsample_mode = np.array(settings.get("unet_downsample_mode", "fixed_rate")) 206 | self.resolution = np.array(settings.get("resolution", [1, 1, 1])) 207 | 208 | 209 | class OptimizerConfig(BaseConfig): 210 | """Configuration for the network optimizer. 211 | 212 | Any settings dict entries passed to this initializer will be added as 213 | configuration attributes and passed to the optimizer initializer as keyword 214 | arguments. 215 | 216 | Attributes 217 | ---------- 218 | klass : str 219 | Class name of the Keras optimizer to use. 220 | loss : str 221 | Name of the Keras loss function to use. 222 | """ 223 | def __init__(self, settings): 224 | for k, v in six.iteritems(settings): 225 | if k != 'klass' and k != 'loss': 226 | setattr(self, k, v) 227 | self.klass = str(settings.get('klass', 'SGD')) 228 | self.loss = str(settings.get('loss', 'binary_crossentropy')) 229 | 230 | 231 | class TrainingConfig(BaseConfig): 232 | """Configuration for model training. 233 | 234 | Attributes 235 | ---------- 236 | num_gpus : int 237 | Number of GPUs to use for data-parallelism. 238 | num_workers : int 239 | Number of worker queues to use for generating training data. 240 | gpu_batch_size : int 241 | Per-GPU batch size. The effective batch size will be this times 242 | ``num_gpus``. 243 | training_size : int 244 | Number of samples to use for training **from each volume**. 245 | validation_size : int 246 | Number of samples to use for validation **from each volume**. 247 | total_epochs : int 248 | Maximum number of training epochs. 249 | reset_generators : bool 250 | Reset training generators after each epoch, so that the training 251 | examples at each epoch are identical. 252 | fill_factor_bins : sequence of float 253 | Bin boundaries for filling fractions. If provided, sample loss will be 254 | weighted to increase loss contribution from less-frequent bins. 255 | Otherwise all samples are weighted equally. 256 | partitions : dict 257 | Dictionary mapping volume name regexes to a sequence of int indicating 258 | number of volume partitions along each axis. Only one axis should be 259 | greater than 1. Each volume should match at most one regex. 260 | training_partition, validation_partition : dict 261 | Dictionaries mapping volume name regexes to a sequence of int indicating 262 | index of the partitions to use for training and validation, 263 | respectively. Each volume should match at most one regex. 264 | validation_metric : dict 265 | Module and function name for a metric function taking a true and 266 | predicted region mask ('metric'). Boolean of whether to threshold the 267 | mask for the metric (true) or use the mask and target probabilities 268 | ('threshold'). 269 | String 'min' or 'max'for how to choose best validation metric value 270 | ('mode'). 271 | patience : int 272 | Number of epochs after the last minimal validation loss to terminate 273 | training. 274 | early_abort_epoch : int 275 | If provided, training will check at the end of this epoch 276 | whether validation loss is less than ``early_abort_loss``. If not, 277 | training will be aborted, and may be restarted with a new seed 278 | depending on CLI options. By default this is disabled. 279 | early_abort_loss : float 280 | See ``early_abort_epoch``. 281 | label_erosion : sequence or ndarray of int 282 | Amount to erode label mask for each training subvolume in each 283 | dimension, in pixels. For example, a value of [0, 1, 1] will result 284 | in erosion with a structuring element of size [1, 3, 3]. 285 | relabel_seed_component : bool 286 | Relabel training subvolumes to only include the seeded connected 287 | component. 288 | augment_validation : bool 289 | Whether validation data should also be augmented. 290 | augment_use_both : bool 291 | Whether to sequentially use both the augmented and unaugmented version 292 | of each subvolume. 293 | augment_mirrors : sequence of int 294 | Axes along which to mirror for data augmentation. 295 | augment_permute_axes : sequence of sequence of int 296 | Axis permutations to use for data augmentation. 297 | augment_missing_data : list of dict 298 | List of dictionaries with ``axis`` and ``prob`` keys, indicating 299 | an axis to perform data blanking along, and the probability to blank 300 | each plane in the axis, respectively. 301 | augment_noise : list of dict 302 | List of dictionaries with ``axis``, ``mul`` and `add`` keys, indicating 303 | an axis to perform independent Gaussian noise augmentation on, and the 304 | standard deviations of 1-mean multiplicative and 0-mean additive noise, 305 | respectively. 306 | augment_contrast : list of dict 307 | List of dictionaries with ``axis``, ``prob``, ``scaling_mean``, 308 | ``scaling_std``, ``center_mean`` and ``center_std`` keys. These 309 | specify the probability to alter the contrast of a section, the mean 310 | and standard deviation to draw from a normal distribution to scale 311 | contrast, and the mean and standard deviation to draw from a normal 312 | distribution to move the intensity center multiplicatively. 313 | augment_missing_data : list of dict 314 | List of dictionaries with ``axis``, ``prob`` and ``volume_file`` 315 | keys, indicating an axis to perform data artifacting along, the 316 | probability to add artifacts to each plane in the axis, and the 317 | volume configuration file from which to draw artifacts, respectively. 318 | """ 319 | def __init__(self, settings): 320 | self.num_gpus = int(settings.get('num_gpus', 1)) 321 | self.num_workers = int(settings.get('num_workers', 4)) 322 | self.gpu_batch_size = int(settings.get('gpu_batch_size', 8)) 323 | self.batch_size = self.num_gpus * self.gpu_batch_size 324 | self.training_size = int(settings.get('training_size', 256)) 325 | self.validation_size = int(settings.get('validation_size', 256)) 326 | self.total_epochs = int(settings.get('total_epochs', 100)) 327 | self.reset_generators = bool(settings.get('reset_generators', False)) 328 | self.fill_factor_bins = settings.get('fill_factor_bins', None) 329 | if self.fill_factor_bins is not None: 330 | self.fill_factor_bins = np.array(self.fill_factor_bins) 331 | self.partitions = settings.get('partitions', {'.*': [2, 1, 1]}) 332 | self.training_partition = settings.get('training_partition', {'.*': [0, 0, 0]}) 333 | self.validation_partition = settings.get('validation_partition', {'.*': [1, 0, 0]}) 334 | self.validation_metric = settings.get( 335 | 'validation_metric', 336 | {'metric': 'diluvian.util.binary_f_score', 'threshold': True, 'mode': 'max', 'args': {'beta': 0.5}}) 337 | self.patience = int(np.array(settings.get('patience', 10))) 338 | self.early_abort_epoch = settings.get('early_abort_epoch', None) 339 | self.early_abort_loss = settings.get('early_abort_loss', None) 340 | self.label_erosion = np.array(settings.get('label_erosion', [0, 1, 1]), dtype=np.int64) 341 | self.relabel_seed_component = bool(settings.get('relabel_seed_component', False)) 342 | self.augment_validation = bool(settings.get('augment_validation', True)) 343 | self.augment_use_both = bool(settings.get('augment_use_both', True)) 344 | self.augment_mirrors = [int(x) for x in settings.get('augment_mirrors', [0, 1, 2])] 345 | self.augment_permute_axes = settings.get('augment_permute_axes', [[0, 2, 1]]) 346 | self.augment_missing_data = settings.get('augment_missing_data', [{'axis': 0, 'prob': 0.01}]) 347 | self.augment_noise = settings.get('augment_noise', [{'axis': 0, 'mul': 0.1, 'add': 0.1}]) 348 | self.augment_contrast = settings.get( 349 | 'augment_contrast', 350 | [{'axis': 0, 'prob': 0.05, 'scaling_mean': 0.5, 'scaling_std': 0.1, 351 | 'center_mean': 1.2, 'center_std': 0.2}]) 352 | self.augment_artifacts = settings.get('augment_artifacts', []) 353 | 354 | 355 | class PostprocessingConfig(BaseConfig): 356 | """Configuration for segmentation processing after flood filling. 357 | 358 | Attributes 359 | ---------- 360 | closing_shape : sequence or ndarray of int 361 | Shape of the structuring element for morphological closing, in voxels. 362 | """ 363 | def __init__(self, settings): 364 | self.closing_shape = settings.get('closing_shape', None) 365 | 366 | 367 | class Config(object): 368 | """A complete collection of configuration objects. 369 | 370 | Attributes 371 | ---------- 372 | random_seed : int 373 | Seed for initializing the Python and NumPy random generators. 374 | """ 375 | 376 | def __init__(self, settings_collection=None): 377 | if settings_collection is not None: 378 | settings = settings_collection[0].copy() 379 | for s in settings_collection: 380 | for c in s: 381 | if c in settings and isinstance(settings[c], dict): 382 | settings[c].update(s[c]) 383 | else: 384 | settings[c] = s[c] 385 | else: 386 | settings = {} 387 | 388 | self.volume = VolumeConfig(settings.get('volume', {})) 389 | self.model = ModelConfig(settings.get('model', {})) 390 | self.network = NetworkConfig(settings.get('network', {})) 391 | self.optimizer = OptimizerConfig(settings.get('optimizer', {})) 392 | self.training = TrainingConfig(settings.get('training', {})) 393 | self.postprocessing = PostprocessingConfig(settings.get('postprocessing', {})) 394 | 395 | self.random_seed = int(settings.get('random_seed', 0)) 396 | 397 | def __str__(self): 398 | sanitized = {} 399 | for n, c in six.iteritems(self.__dict__): 400 | if not isinstance(c, BaseConfig): 401 | sanitized[n] = c 402 | continue 403 | sanitized[n] = {} 404 | for k, v in six.iteritems(c.__dict__): 405 | if isinstance(v, np.ndarray): 406 | sanitized[n][k] = v.tolist() 407 | else: 408 | sanitized[n][k] = v 409 | return toml.dumps(sanitized) 410 | 411 | def from_toml(self, *filenames): 412 | """Reinitializes this Config from a list of TOML configuration files. 413 | 414 | Existing settings are discarded. When multiple files are provided, 415 | configuration is overridden by later files in the list. 416 | 417 | Parameters 418 | ---------- 419 | filenames : interable of str 420 | Filenames of TOML configuration files to load. 421 | """ 422 | settings = [] 423 | for filename in filenames: 424 | with open(filename, 'rb') as fin: 425 | settings.append(toml.load(fin)) 426 | 427 | return self.__init__(settings) 428 | 429 | def to_toml(self, filename): 430 | with open(filename, 'w') as tomlfile: 431 | tomlfile.write(str(self)) 432 | 433 | 434 | CONFIG = Config() 435 | CONFIG.from_toml(os.path.join(os.path.dirname(__file__), 'conf', 'default.toml')) 436 | -------------------------------------------------------------------------------- /diluvian/diluvian.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from collections import deque 8 | import itertools 9 | import logging 10 | from multiprocessing import ( 11 | Manager, 12 | Process, 13 | ) 14 | import os 15 | import random 16 | 17 | import numpy as np 18 | import pytoml as toml 19 | import six 20 | from six.moves import input as raw_input 21 | from tqdm import tqdm 22 | 23 | from .config import CONFIG 24 | from . import preprocessing 25 | from .training import augment_subvolume_generator 26 | from .util import ( 27 | get_color_shader, 28 | Roundrobin, 29 | WrappedViewer, 30 | ) 31 | from .volumes import ( 32 | HDF5Volume, 33 | partition_volumes, 34 | SubvolumeBounds, 35 | ) 36 | from .regions import Region 37 | 38 | 39 | def generate_subvolume_bounds(filename, volumes, num_bounds, sparse=False, moves=None): 40 | if '{volume}' not in filename: 41 | raise ValueError('CSV filename must contain "{volume}" for volume name replacement.') 42 | 43 | if moves is None: 44 | moves = 5 45 | else: 46 | moves = np.asarray(moves) 47 | subv_shape = CONFIG.model.input_fov_shape + CONFIG.model.move_step * 2 * moves 48 | 49 | if sparse: 50 | gen_kwargs = {'sparse_margin': subv_shape} 51 | else: 52 | gen_kwargs = {'shape': subv_shape} 53 | for k, v in six.iteritems(volumes): 54 | bounds = v.downsample(CONFIG.volume.resolution)\ 55 | .subvolume_bounds_generator(**gen_kwargs) 56 | bounds = itertools.islice(bounds, num_bounds) 57 | SubvolumeBounds.iterable_to_csv(bounds, filename.format(volume=k)) 58 | 59 | 60 | def fill_volume_with_model( 61 | model_file, 62 | volume, 63 | resume_prediction=None, 64 | checkpoint_filename=None, 65 | checkpoint_label_interval=20, 66 | seed_generator='sobel', 67 | background_label_id=0, 68 | bias=True, 69 | move_batch_size=1, 70 | max_moves=None, 71 | max_bodies=None, 72 | num_workers=CONFIG.training.num_gpus, 73 | worker_prequeue=1, 74 | filter_seeds_by_mask=True, 75 | reject_non_seed_components=True, 76 | reject_early_termination=False, 77 | remask_interval=None, 78 | shuffle_seeds=True): 79 | subvolume = volume.get_subvolume(SubvolumeBounds(start=np.zeros(3, dtype=np.int64), stop=volume.shape)) 80 | # Create an output label volume. 81 | if resume_prediction is None: 82 | prediction = np.full_like(subvolume.image, background_label_id, dtype=np.uint64) 83 | label_id = 0 84 | else: 85 | if resume_prediction.shape != subvolume.image.shape: 86 | raise ValueError('Resume volume prediction is wrong shape.') 87 | prediction = resume_prediction 88 | prediction.flags.writeable = True 89 | label_id = prediction.max() 90 | # Create a conflict count volume that tracks locations where segmented 91 | # bodies overlap. For now the first body takes precedence in the 92 | # predicted labels. 93 | conflict_count = np.full_like(prediction, 0, dtype=np.uint32) 94 | 95 | def worker(worker_id, set_devices, model_file, image, seeds, results, lock, revoked): 96 | lock.acquire() 97 | import tensorflow as tf 98 | 99 | if set_devices: 100 | # Only make one GPU visible to Tensorflow so that it does not allocate 101 | # all available memory on all devices. 102 | # See: https://stackoverflow.com/questions/37893755 103 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 104 | os.environ['CUDA_VISIBLE_DEVICES'] = str(worker_id) 105 | 106 | with tf.device('/gpu:0'): 107 | # Late import to avoid Keras import until TF bindings are set. 108 | from .network import load_model 109 | 110 | logging.debug('Worker %s: loading model', worker_id) 111 | model = load_model(model_file, CONFIG.network) 112 | lock.release() 113 | 114 | def is_revoked(test_seed): 115 | ret = False 116 | lock.acquire() 117 | if tuple(test_seed) in revoked: 118 | ret = True 119 | revoked.remove(tuple(test_seed)) 120 | lock.release() 121 | return ret 122 | 123 | while True: 124 | seed = seeds.get(True) 125 | 126 | if not isinstance(seed, np.ndarray): 127 | logging.debug('Worker %s: got DONE', worker_id) 128 | break 129 | 130 | if is_revoked(seed): 131 | results.put((seed, None)) 132 | continue 133 | 134 | def stopping_callback(region): 135 | stop = is_revoked(seed) 136 | if reject_non_seed_components and \ 137 | region.bias_against_merge and \ 138 | region.mask[tuple(region.seed_vox)] < 0.5: 139 | stop = True 140 | return stop 141 | 142 | logging.debug('Worker %s: got seed %s', worker_id, np.array_str(seed)) 143 | 144 | # Flood-fill and get resulting mask. 145 | # Allow reading outside the image volume bounds to allow segmentation 146 | # to fill all the way to the boundary. 147 | region = Region(image, seed_vox=seed, sparse_mask=True, block_padding='reflect') 148 | region.bias_against_merge = bias 149 | early_termination = False 150 | try: 151 | six.next(region.fill( 152 | model, 153 | move_batch_size=move_batch_size, 154 | max_moves=max_moves, 155 | progress=2 + worker_id, 156 | stopping_callback=stopping_callback, 157 | remask_interval=remask_interval)) 158 | except Region.EarlyFillTermination: 159 | early_termination = True 160 | except StopIteration: 161 | pass 162 | if reject_early_termination and early_termination: 163 | body = None 164 | else: 165 | body = region.to_body() 166 | logging.debug('Worker %s: seed %s filled', worker_id, np.array_str(seed)) 167 | 168 | results.put((seed, body)) 169 | 170 | # Generate seeds from volume. 171 | generator = preprocessing.SEED_GENERATORS[seed_generator] 172 | seeds = generator(subvolume.image, CONFIG.volume.resolution) 173 | 174 | if filter_seeds_by_mask and volume.mask_data is not None: 175 | seeds = [s for s in seeds if volume.mask_data[tuple(volume.world_coord_to_local(s))]] 176 | 177 | pbar = tqdm(desc='Seed queue', total=len(seeds), miniters=1, smoothing=0.0) 178 | label_pbar = tqdm(desc='Labeled vox', total=prediction.size, miniters=1, smoothing=0.0, position=1) 179 | num_seeds = len(seeds) 180 | if shuffle_seeds: 181 | random.shuffle(seeds) 182 | seeds = iter(seeds) 183 | 184 | manager = Manager() 185 | # Queue of seeds to be picked up by workers. 186 | seed_queue = manager.Queue() 187 | # Queue of results from workers. 188 | results_queue = manager.Queue() 189 | # Dequeue of seeds that were put in seed_queue but have not yet been 190 | # combined by the main process. 191 | dispatched_seeds = deque() 192 | # Seeds that were placed in seed_queue but subsequently covered by other 193 | # results before their results have been processed. This allows workers to 194 | # abort working on these seeds by checking this list. 195 | revoked_seeds = manager.list() 196 | # Results that have been received by the main process but have not yet 197 | # been combined because they were not received in the dispatch order. 198 | unordered_results = {} 199 | 200 | def queue_next_seed(): 201 | total = 0 202 | for seed in seeds: 203 | if prediction[seed[0], seed[1], seed[2]] != background_label_id: 204 | # This seed has already been filled. 205 | total += 1 206 | continue 207 | dispatched_seeds.append(seed) 208 | seed_queue.put(seed) 209 | 210 | break 211 | 212 | return total 213 | 214 | for _ in range(min(num_seeds, num_workers * worker_prequeue)): 215 | processed_seeds = queue_next_seed() 216 | pbar.update(processed_seeds) 217 | 218 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 219 | set_devices = False 220 | num_workers = 1 221 | logging.warn('Environment variable CUDA_VISIBLE_DEVICES is set, so only one worker can be used.\n' 222 | 'See https://github.com/aschampion/diluvian/issues/11') 223 | else: 224 | set_devices = True 225 | 226 | workers = [] 227 | loading_lock = manager.Lock() 228 | for worker_id in range(num_workers): 229 | w = Process(target=worker, args=(worker_id, set_devices, model_file, subvolume.image, 230 | seed_queue, results_queue, loading_lock, revoked_seeds)) 231 | w.start() 232 | workers.append(w) 233 | 234 | last_checkpoint_label = label_id 235 | 236 | # For each seed, create region, fill, threshold, and merge to output volume. 237 | while dispatched_seeds: 238 | processed_seeds = 1 239 | expected_seed = dispatched_seeds.popleft() 240 | logging.debug('Expecting seed %s', np.array_str(expected_seed)) 241 | 242 | if tuple(expected_seed) in unordered_results: 243 | logging.debug('Expected seed %s is in old results', np.array_str(expected_seed)) 244 | seed = expected_seed 245 | body = unordered_results[tuple(seed)] 246 | del unordered_results[tuple(seed)] 247 | 248 | else: 249 | seed, body = results_queue.get(True) 250 | processed_seeds += queue_next_seed() 251 | 252 | while not np.array_equal(seed, expected_seed): 253 | logging.debug('Seed %s is early, stashing', np.array_str(seed)) 254 | unordered_results[tuple(seed)] = body 255 | seed, body = results_queue.get(True) 256 | processed_seeds += queue_next_seed() 257 | 258 | logging.debug('Processing seed at %s', np.array_str(seed)) 259 | pbar.set_description('Seed ' + np.array_str(seed)) 260 | pbar.update(processed_seeds) 261 | 262 | if prediction[seed[0], seed[1], seed[2]] != background_label_id: 263 | # This seed has already been filled. 264 | logging.debug('Seed (%s) was filled but has been covered in the meantime.', 265 | np.array_str(seed)) 266 | loading_lock.acquire() 267 | if tuple(seed) in revoked_seeds: 268 | revoked_seeds.remove(tuple(seed)) 269 | loading_lock.release() 270 | continue 271 | 272 | if body is None: 273 | logging.debug('Body was None.') 274 | continue 275 | 276 | if reject_non_seed_components and not body.is_seed_in_mask(): 277 | logging.debug('Seed (%s) is not in its body.', np.array_str(seed)) 278 | continue 279 | 280 | if reject_non_seed_components: 281 | mask, bounds = body.get_seeded_component(CONFIG.postprocessing.closing_shape) 282 | else: 283 | mask, bounds = body._get_bounded_mask() 284 | 285 | body_size = np.count_nonzero(mask) 286 | 287 | if body_size == 0: 288 | logging.debug('Body was empty.') 289 | continue 290 | 291 | # Generate a label ID for this region. 292 | label_id += 1 293 | if label_id == background_label_id: 294 | label_id += 1 295 | 296 | logging.debug('Adding body to prediction label volume.') 297 | bounds_shape = list(map(slice, bounds[0], bounds[1])) 298 | prediction_mask = prediction[bounds_shape] == background_label_id 299 | for seed in dispatched_seeds: 300 | if np.all(bounds[0] <= seed) and np.all(bounds[1] > seed) and mask[tuple(seed - bounds[0])]: 301 | loading_lock.acquire() 302 | if tuple(seed) not in revoked_seeds: 303 | revoked_seeds.append(tuple(seed)) 304 | loading_lock.release() 305 | conflict_count[bounds_shape][np.logical_and(np.logical_not(prediction_mask), mask)] += 1 306 | label_shape = np.logical_and(prediction_mask, mask) 307 | prediction[bounds_shape][np.logical_and(prediction_mask, mask)] = label_id 308 | 309 | label_pbar.set_description('Label {}'.format(label_id)) 310 | label_pbar.update(np.count_nonzero(label_shape)) 311 | logging.info('Filled seed (%s) with %s voxels labeled %s.', 312 | np.array_str(seed), body_size, label_id) 313 | 314 | if max_bodies and label_id >= max_bodies: 315 | # Drain the queues. 316 | while not seed_queue.empty(): 317 | seed_queue.get_nowait() 318 | break 319 | 320 | if checkpoint_filename is not None and label_id - last_checkpoint_label > checkpoint_label_interval: 321 | config = HDF5Volume.write_file( 322 | checkpoint_filename + '.hdf5', 323 | CONFIG.volume.resolution, 324 | label_data=prediction) 325 | config['name'] = 'segmentation checkpoint' 326 | with open(checkpoint_filename + '.toml', 'wb') as tomlfile: 327 | tomlfile.write('# Filling model: {}\n'.format(model_file)) 328 | tomlfile.write(str(toml.dumps({'dataset': [config]}))) 329 | 330 | for _ in range(num_workers): 331 | seed_queue.put('DONE') 332 | for wid, worker in enumerate(workers): 333 | worker.join() 334 | manager.shutdown() 335 | 336 | label_pbar.close() 337 | pbar.close() 338 | 339 | return prediction, conflict_count 340 | 341 | 342 | def fill_volumes_with_model( 343 | model_file, 344 | volumes, 345 | filename, 346 | resume_filename=None, 347 | partition=False, 348 | viewer=False, 349 | **kwargs): 350 | if '{volume}' not in filename: 351 | raise ValueError('HDF5 filename must contain "{volume}" for volume name replacement.') 352 | if resume_filename is not None and '{volume}' not in resume_filename: 353 | raise ValueError('TOML resume filename must contain "{volume}" for volume name replacement.') 354 | 355 | if partition: 356 | _, volumes = partition_volumes(volumes) 357 | 358 | for volume_name, volume in six.iteritems(volumes): 359 | logging.info('Filling volume %s...', volume_name) 360 | volume = volume.downsample(CONFIG.volume.resolution) 361 | if resume_filename is not None: 362 | resume_volume_filename = resume_filename.format(volume=volume_name) 363 | resume_volume = six.next(six.itervalues(HDF5Volume.from_toml(resume_volume_filename))) 364 | resume_prediction = resume_volume.to_memory_volume().label_data 365 | else: 366 | resume_prediction = None 367 | 368 | volume_filename = filename.format(volume=volume_name) 369 | checkpoint_filename = volume_filename + '_checkpoint' 370 | prediction, conflict_count = fill_volume_with_model( 371 | model_file, 372 | volume, 373 | resume_prediction=resume_prediction, 374 | checkpoint_filename=checkpoint_filename, 375 | **kwargs) 376 | 377 | config = HDF5Volume.write_file( 378 | volume_filename + '.hdf5', 379 | CONFIG.volume.resolution, 380 | label_data=prediction) 381 | config['name'] = volume_name + ' segmentation' 382 | with open(volume_filename + '.toml', 'wb') as tomlfile: 383 | tomlfile.write('# Filling model: {}\n'.format(model_file)) 384 | tomlfile.write('# Filling kwargs: {}\n'.format(str(kwargs))) 385 | tomlfile.write(str(toml.dumps({'dataset': [config]}))) 386 | 387 | if viewer: 388 | viewer = WrappedViewer(voxel_size=list(np.flipud(CONFIG.volume.resolution))) 389 | subvolume = volume.get_subvolume(SubvolumeBounds(start=np.zeros(3, dtype=np.int64), stop=volume.shape)) 390 | viewer.add(subvolume.image, name='Image') 391 | viewer.add(prediction, name='Labels') 392 | viewer.add(conflict_count, name='Conflicts') 393 | 394 | viewer.print_view_prompt() 395 | 396 | 397 | def fill_region_with_model( 398 | model_file, 399 | volumes=None, 400 | partition=False, 401 | augment=False, 402 | bounds_input_file=None, 403 | bias=True, 404 | move_batch_size=1, 405 | max_moves=None, 406 | remask_interval=None, 407 | sparse=False, 408 | moves=None): 409 | # Late import to avoid Keras import until TF bindings are set. 410 | from .network import load_model 411 | 412 | if volumes is None: 413 | raise ValueError('Volumes must be provided.') 414 | 415 | if partition: 416 | _, volumes = partition_volumes(volumes) 417 | 418 | if bounds_input_file is not None: 419 | gen_kwargs = { 420 | k: {'bounds_generator': iter(SubvolumeBounds.iterable_from_csv(bounds_input_file.format(volume=k)))} 421 | for k in volumes.iterkeys()} 422 | else: 423 | if moves is None: 424 | moves = 5 425 | else: 426 | moves = np.asarray(moves) 427 | subv_shape = CONFIG.model.input_fov_shape + CONFIG.model.move_step * 2 * moves 428 | 429 | if sparse: 430 | gen_kwargs = { 431 | k: {'sparse_margin': subv_shape} 432 | for k in volumes.iterkeys()} 433 | else: 434 | gen_kwargs = { 435 | k: {'shape': subv_shape} 436 | for k in volumes.iterkeys()} 437 | subvolumes = [ 438 | v.downsample(CONFIG.volume.resolution) 439 | .subvolume_generator(**gen_kwargs[k]) 440 | for k, v in six.iteritems(volumes)] 441 | if augment: 442 | subvolumes = map(augment_subvolume_generator, subvolumes) 443 | regions = Roundrobin(*[Region.from_subvolume_generator(v, block_padding='reflect') for v in subvolumes]) 444 | 445 | model = load_model(model_file, CONFIG.network) 446 | 447 | for region in regions: 448 | region.bias_against_merge = bias 449 | try: 450 | six.next(region.fill( 451 | model, 452 | progress=True, 453 | move_batch_size=move_batch_size, 454 | max_moves=max_moves, 455 | remask_interval=remask_interval)) 456 | except (StopIteration, Region.EarlyFillTermination): 457 | pass 458 | body = region.to_body() 459 | viewer = region.get_viewer() 460 | try: 461 | mask, bounds = body.get_seeded_component(CONFIG.postprocessing.closing_shape) 462 | viewer.add(mask.astype(np.float32), 463 | name='Body Mask', 464 | offset=bounds[0], 465 | shader=get_color_shader(2)) 466 | except ValueError: 467 | logging.info('Seed not in body.') 468 | print(viewer) 469 | while True: 470 | s = raw_input('Press Enter to continue, ' 471 | 'v to open in browser, ' 472 | 'a to export animation, ' 473 | 'r to 3D render body, ' 474 | 'q to quit...') 475 | if s == 'q': 476 | return 477 | elif s == 'a': 478 | region_copy = region.unfilled_copy() 479 | # Must assign the animation to a variable so that it is not GCed. 480 | ani = region_copy.fill_animation( # noqa 481 | 'export.mp4', 482 | model, 483 | progress=True, 484 | move_batch_size=move_batch_size, 485 | max_moves=max_moves, 486 | remask_interval=remask_interval) 487 | s = raw_input("Press Enter when animation is complete...") 488 | elif s == 'r': 489 | region.render_body() 490 | elif s == 'ra': 491 | region_copy = region.unfilled_copy() 492 | region_copy.fill_render( 493 | model, 494 | progress=True, 495 | move_batch_size=move_batch_size, 496 | max_moves=max_moves, 497 | remask_interval=remask_interval) 498 | elif s == 's': 499 | body.to_swc('{}.swc'.format('_'.join(map(str, tuple(body.seed))))) 500 | elif s == 'v': 501 | viewer.open_in_browser() 502 | else: 503 | break 504 | 505 | 506 | def evaluate_volume( 507 | volumes, 508 | gt_name, 509 | pred_name, 510 | partition=False, 511 | border_threshold=None, 512 | use_gt_mask=True, 513 | relabel=False): 514 | # TODO: This is very intrusive into Volumes and should be refactored to 515 | # handle much of the partioned access and resampling there. 516 | 517 | import cremi 518 | 519 | if partition: 520 | _, volumes = partition_volumes(volumes, downsample=False) 521 | 522 | def labels_to_cremi(v): 523 | label_data = v.label_data.copy() 524 | if hasattr(v, 'bounds'): 525 | label_data = label_data[list(map(slice, list(v.bounds[0]), list(v.bounds[1])))] 526 | volume = cremi.Volume(label_data, resolution=v.resolution) 527 | 528 | return volume 529 | 530 | gt_vol = volumes[gt_name] 531 | pred_vol = volumes[pred_name] 532 | logging.info('GT shape: %s\t Prediction shape:%s', gt_vol.shape, pred_vol.shape) 533 | 534 | pred_upsample = gt_vol._get_downsample_from_resolution(pred_vol.resolution) 535 | if np.any(pred_upsample > 0): 536 | scale = np.exp2(pred_upsample).astype(np.int64) 537 | logging.warn('Segmentation is different resolution than groundtruth. Upsampling by %s.', scale) 538 | 539 | pred_data = pred_vol.label_data 540 | if hasattr(pred_vol, 'bounds'): 541 | pred_data = pred_data[list(map(slice, list(pred_vol.bounds[0]), list(pred_vol.bounds[1])))] 542 | orig_shape = pred_data.shape 543 | pred_data = np.lib.stride_tricks.as_strided(pred_data, 544 | [b for a in zip(orig_shape, scale) for b in a], 545 | [b for a in zip(pred_data.strides, [0, 0, 0]) for b in a]) 546 | new_shape = np.array(orig_shape) * scale 547 | pred_data = pred_data.reshape(list(new_shape)) 548 | 549 | padding = np.array(gt_vol.shape) - new_shape 550 | if np.any(padding > 0): 551 | logging.warn('Padding segmentation (%s) to be groundtruth size (%s)', new_shape, gt_vol.shape) 552 | pred_data = np.pad(pred_data, zip([0, 0, 0], list(padding)), 'edge') 553 | 554 | pred = cremi.Volume(pred_data, resolution=gt_vol.resolution) 555 | else: 556 | pred = labels_to_cremi(pred_vol) 557 | 558 | gt = labels_to_cremi(gt_vol) 559 | 560 | # Some augmented CREMI volumes have not just a uint64 -1 as background, but 561 | # several large values. Set these all to background to avoid breaking 562 | # coo_matrix. 563 | gt.data[gt.data > np.uint64(-10)] = np.uint64(-1) 564 | background_label_id = 0 565 | pred.data[pred.data > np.uint64(-10)] = background_label_id 566 | 567 | if use_gt_mask and gt_vol.mask_data is not None: 568 | logging.warn('Groundtruth has a mask channel that will be applied to segmentation.') 569 | mask_data = gt_vol.mask_data 570 | if hasattr(gt_vol, 'bounds'): 571 | mask_data = mask_data[list(map(slice, list(gt_vol.bounds[0]), list(gt_vol.bounds[1])))] 572 | 573 | if relabel: 574 | mask_exiting_bodies = np.unique(pred.data[np.logical_not(mask_data)]) 575 | 576 | pred.data[np.logical_not(mask_data)] = background_label_id 577 | 578 | if relabel: 579 | from skimage import morphology 580 | 581 | pred_copy = np.zeros_like(pred.data) 582 | exiting_bodies_mask = np.isin(pred.data, mask_exiting_bodies) 583 | pred_copy[exiting_bodies_mask] = pred.data[exiting_bodies_mask] 584 | 585 | new_pred = morphology.label(pred_copy, background=background_label_id, connectivity=2) 586 | 587 | pred.data[exiting_bodies_mask] = new_pred[exiting_bodies_mask] 588 | 589 | gt_neuron_ids = cremi.evaluation.NeuronIds(gt, border_threshold=border_threshold) 590 | 591 | (voi_split, voi_merge) = gt_neuron_ids.voi(pred) 592 | adapted_rand = gt_neuron_ids.adapted_rand(pred) 593 | 594 | print('VOI split :', voi_split) 595 | print('VOI merge :', voi_merge) 596 | print('Adapted Rand-index:', adapted_rand) 597 | print('CREMI :', np.sqrt((voi_split + voi_merge) * adapted_rand)) 598 | 599 | 600 | def view_volumes(volumes, partition=False): 601 | """Display a set of volumes together in a neuroglancer viewer. 602 | 603 | Parameters 604 | ---------- 605 | volumes : dict 606 | Dictionary mapping volume name to diluvian.volumes.Volume. 607 | partition : bool 608 | If true, partition the volumes and put the view origin at the validaiton 609 | partition origin. 610 | """ 611 | 612 | if partition: 613 | _, volumes = partition_volumes(volumes, downsample=False) 614 | 615 | viewer = WrappedViewer() 616 | 617 | for volume_name, volume in six.iteritems(volumes): 618 | resolution = list(np.flipud(volume.resolution)) 619 | offset = getattr(volume, 'bounds', [np.zeros(3, dtype=np.int32)])[0] 620 | offset = np.flipud(-offset) 621 | 622 | viewer.add(volume.image_data, 623 | name='{} (Image)'.format(volume_name), 624 | voxel_size=resolution, 625 | voxel_offset=offset) 626 | if volume.label_data is not None: 627 | viewer.add(volume.label_data, 628 | name='{} (Labels)'.format(volume_name), 629 | voxel_size=resolution, 630 | voxel_offset=offset) 631 | if volume.mask_data is not None: 632 | viewer.add(volume.mask_data, 633 | name='{} (Mask)'.format(volume_name), 634 | voxel_size=resolution, 635 | voxel_offset=offset) 636 | 637 | viewer.print_view_prompt() 638 | -------------------------------------------------------------------------------- /diluvian/training.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Functions for generating training data and training networks.""" 3 | 4 | 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import collections 9 | import copy 10 | import itertools 11 | import logging 12 | import random 13 | 14 | import matplotlib as mpl 15 | # Use the 'Agg' backend to allow the generation of plots even if no X server 16 | # is available. The matplotlib backend must be set before importing pyplot. 17 | mpl.use('Agg') # noqa 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | import six 21 | from six.moves import range as xrange 22 | import tensorflow as tf 23 | from tqdm import tqdm 24 | 25 | import keras.backend as K 26 | from keras.callbacks import ( 27 | Callback, 28 | EarlyStopping, 29 | ModelCheckpoint, 30 | TensorBoard, 31 | ) 32 | 33 | from .config import CONFIG 34 | from .network import compile_network, load_model, make_parallel 35 | from .util import ( 36 | get_color_shader, 37 | get_function, 38 | pad_dims, 39 | Roundrobin, 40 | WrappedViewer, 41 | write_keras_history_to_csv, 42 | ) 43 | from .volumes import ( 44 | ClipSubvolumeImageGenerator, 45 | ContrastAugmentGenerator, 46 | ErodedMaskGenerator, 47 | GaussianNoiseAugmentGenerator, 48 | MaskedArtifactAugmentGenerator, 49 | MirrorAugmentGenerator, 50 | MissingDataAugmentGenerator, 51 | partition_volumes, 52 | PermuteAxesAugmentGenerator, 53 | RelabelSeedComponentGenerator, 54 | ) 55 | from .regions import ( 56 | Region, 57 | ) 58 | 59 | 60 | def plot_history(history): 61 | fig = plt.figure() 62 | ax = fig.add_subplot(111) 63 | ax.plot(history.history['loss']) 64 | ax.plot(history.history['val_loss']) 65 | ax.plot(history.history['val_subv_metric']) 66 | fig.suptitle('model loss') 67 | ax.set_ylabel('loss') 68 | ax.set_xlabel('epoch') 69 | ax.legend(['train', 'validation', 'val subvolumes'], loc='upper right') 70 | 71 | return fig 72 | 73 | 74 | def patch_prediction_copy(model): 75 | """Patch a Keras model to copy outputs to a kludge during training. 76 | 77 | This is necessary for mask updates to a region during training. 78 | 79 | Parameters 80 | ---------- 81 | model : keras.engine.Model 82 | """ 83 | model.train_function = None 84 | model.test_function = None 85 | 86 | model._orig_train_on_batch = model.train_on_batch 87 | 88 | def train_on_batch(self, x, y, **kwargs): 89 | kludge = x.pop('kludge', None) 90 | outputs = self._orig_train_on_batch(x, y, **kwargs) 91 | kludge['outputs'] = outputs.pop() 92 | if len(outputs) == 1: 93 | return outputs[0] 94 | return outputs 95 | 96 | model.train_on_batch = six.create_bound_method(train_on_batch, model) 97 | 98 | model._orig_test_on_batch = model.test_on_batch 99 | 100 | def test_on_batch(self, x, y, **kwargs): 101 | kludge = x.pop('kludge', None) 102 | outputs = self._orig_test_on_batch(x, y, **kwargs) 103 | kludge['outputs'] = outputs.pop() 104 | if len(outputs) == 1: 105 | return outputs[0] 106 | return outputs 107 | 108 | model.test_on_batch = six.create_bound_method(test_on_batch, model) 109 | 110 | # Below is copied and modified from Keras Model._make_train_function. 111 | # The only change is the addition of `self.outputs` to the train function. 112 | def _make_train_function(self): 113 | if not hasattr(self, 'train_function'): 114 | raise RuntimeError('You must compile your model before using it.') 115 | if self.train_function is None: 116 | inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights 117 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 118 | inputs += [K.learning_phase()] 119 | 120 | with K.name_scope('training'): 121 | with K.name_scope(self.optimizer.__class__.__name__): 122 | training_updates = self.optimizer.get_updates( 123 | params=self._collected_trainable_weights, 124 | loss=self.total_loss) 125 | updates = self.updates + training_updates 126 | # Gets loss and metrics. Updates weights at each call. 127 | self.train_function = K.function(inputs, 128 | [self.total_loss] + self.metrics_tensors + self.outputs, 129 | updates=updates, 130 | name='train_function', 131 | **self._function_kwargs) 132 | 133 | model._make_train_function = six.create_bound_method(_make_train_function, model) 134 | 135 | def _make_test_function(self): 136 | if not hasattr(self, 'test_function'): 137 | raise RuntimeError('You must compile your model before using it.') 138 | if self.test_function is None: 139 | inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights 140 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 141 | inputs += [K.learning_phase()] 142 | # Return loss and metrics, no gradient updates. 143 | # Does update the network states. 144 | self.test_function = K.function(inputs, 145 | [self.total_loss] + self.metrics_tensors + self.outputs, 146 | updates=self.state_updates, 147 | name='test_function', 148 | **self._function_kwargs) 149 | 150 | model._make_test_function = six.create_bound_method(_make_test_function, model) 151 | 152 | 153 | class GeneratorReset(Callback): 154 | """Keras epoch end callback to reset prediction copy kludges. 155 | """ 156 | def __init__(self, gens): 157 | self.gens = gens 158 | 159 | def on_epoch_end(self, epoch, logs=None): 160 | for gen in self.gens: 161 | gen.reset() 162 | 163 | 164 | class GeneratorSubvolumeMetric(Callback): 165 | """Add a data generator's subvolume metric to Keras' metric logs. 166 | 167 | Parameters 168 | ---------- 169 | gens : iterable of diluvian.training.MovingTrainingGenerator 170 | metric_name : string 171 | """ 172 | def __init__(self, gens, metric_name): 173 | self.gens = gens 174 | self.metric_name = metric_name 175 | 176 | def on_epoch_end(self, epoch, logs=None): 177 | if self.metric_name not in self.params['metrics']: 178 | self.params['metrics'].append(self.metric_name) 179 | if logs: 180 | metric = np.mean([np.mean(gen.get_epoch_metric()) for gen in self.gens]) 181 | logs[self.metric_name] = metric 182 | 183 | 184 | class EarlyAbortException(Exception): 185 | pass 186 | 187 | 188 | class EarlyAbort(Callback): 189 | """Keras epoch end callback that aborts if a metric is above a threshold. 190 | 191 | This is useful when convergence is sensitive to initial conditions and 192 | models are obviously not useful to continue training after only a few 193 | epochs. Unlike the early stopping callback, this is considered an 194 | abnormal termination and throws an exception so that behaviors like 195 | restarting with a new random seed are possible. 196 | """ 197 | def __init__(self, monitor='val_loss', threshold_epoch=None, threshold_value=None): 198 | if threshold_epoch is None or threshold_value is None: 199 | raise ValueError('Epoch and value to enforce threshold must be provided.') 200 | 201 | self.monitor = monitor 202 | self.threshold_epoch = threshold_epoch - 1 203 | self.threshold_value = threshold_value 204 | 205 | def on_epoch_end(self, epoch, logs=None): 206 | if epoch == self.threshold_epoch: 207 | current = logs.get(self.monitor) 208 | if current >= self.threshold_value: 209 | raise EarlyAbortException('Aborted after epoch {} because {} was {} >= {}'.format( 210 | self.threshold_epoch, self.monitor, current, self.threshold_value)) 211 | 212 | 213 | def preprocess_subvolume_generator(subvolume_generator): 214 | """Apply non-augmentation preprocessing to a subvolume generator. 215 | 216 | Parameters 217 | ---------- 218 | subvolume_generator : diluvian.volumes.SubvolumeGenerator 219 | 220 | Returns 221 | ------- 222 | diluvian.volumes.SubvolumeGenerator 223 | """ 224 | gen = subvolume_generator 225 | if np.any(CONFIG.training.label_erosion): 226 | gen = ErodedMaskGenerator(gen, CONFIG.training.label_erosion) 227 | if CONFIG.training.relabel_seed_component: 228 | gen = RelabelSeedComponentGenerator(gen) 229 | 230 | return gen 231 | 232 | 233 | def augment_subvolume_generator(subvolume_generator): 234 | """Apply data augmentations to a subvolume generator. 235 | 236 | Parameters 237 | ---------- 238 | subvolume_generator : diluvian.volumes.SubvolumeGenerator 239 | 240 | Returns 241 | ------- 242 | diluvian.volumes.SubvolumeGenerator 243 | """ 244 | gen = subvolume_generator 245 | for axes in CONFIG.training.augment_permute_axes: 246 | gen = PermuteAxesAugmentGenerator(gen, CONFIG.training.augment_use_both, axes) 247 | for axis in CONFIG.training.augment_mirrors: 248 | gen = MirrorAugmentGenerator(gen, CONFIG.training.augment_use_both, axis) 249 | for v in CONFIG.training.augment_noise: 250 | gen = GaussianNoiseAugmentGenerator(gen, CONFIG.training.augment_use_both, v['axis'], v['mul'], v['add']) 251 | for v in CONFIG.training.augment_artifacts: 252 | if 'cache' not in v: 253 | v['cache'] = {} 254 | gen = MaskedArtifactAugmentGenerator(gen, CONFIG.training.augment_use_both, 255 | v['axis'], v['prob'], v['volume_file'], v['cache']) 256 | for v in CONFIG.training.augment_missing_data: 257 | gen = MissingDataAugmentGenerator(gen, CONFIG.training.augment_use_both, v['axis'], v['prob']) 258 | for v in CONFIG.training.augment_contrast: 259 | gen = ContrastAugmentGenerator(gen, CONFIG.training.augment_use_both, v['axis'], v['prob'], 260 | v['scaling_mean'], v['scaling_std'], 261 | v['center_mean'], v['center_std']) 262 | gen = ClipSubvolumeImageGenerator(gen) 263 | 264 | return gen 265 | 266 | 267 | class MovingTrainingGenerator(six.Iterator): 268 | """Generate Keras moving FOV training tuples from a subvolume generator. 269 | 270 | This generator expects a subvolume generator that will provide subvolumes 271 | larger than the network FOV, and will allow the output of training at one 272 | batch to generate moves within these subvolumes to produce training data 273 | for the subsequent batch. 274 | 275 | Parameters 276 | ---------- 277 | subvolumes : generator of Subvolume 278 | batch_size : int 279 | kludge : dict 280 | A kludge object to allow this generator to provide inputs and receive 281 | outputs from the network. 282 | See ``diluvian.training.patch_prediction_copy``. 283 | f_a_bins : sequence of float, optional 284 | Bin boundaries for filling fractions. If provided, sample loss will be 285 | weighted to increase loss contribution from less-frequent f_a bins. 286 | Otherwise all samples are weighted equally. 287 | reset_generators : bool 288 | Whether to reset subvolume generators when this generator is reset. 289 | If true subvolumes will be sampled in the same order each epoch. 290 | subv_per_epoch : int, optional 291 | If specified, the generator will only return moves from this many 292 | subvolumes before being reset. Once this number of subvolumes is 293 | exceeded, the generator will yield garbage batches (this is 294 | necessary because Keras currently uses a fixed number of batches 295 | per epoch). If specified, once each subvolume is complete its 296 | total loss will be calculated. 297 | subv_metric_fn : function, option 298 | Metric function to run on subvolumes when `subv_per_epoch` is set. 299 | subv_metric_threshold : bool, optional 300 | Whether to threshold subvolume masks for metrics. 301 | subv_metric_args : dict, optional 302 | Keyword arguments that will be passed to the subvolume metric. 303 | """ 304 | def __init__(self, subvolumes, batch_size, kludge, 305 | f_a_bins=None, reset_generators=True, subv_per_epoch=None, 306 | subv_metric_fn=None, subv_metric_threshold=False, subv_metric_args=None): 307 | self.subvolumes = subvolumes 308 | self.batch_size = batch_size 309 | self.kludge = kludge 310 | self.reset_generators = reset_generators 311 | self.subv_per_epoch = subv_per_epoch 312 | self.subv_metric_fn = subv_metric_fn 313 | self.subv_metric_threshold = subv_metric_threshold 314 | self.subv_metric_args = subv_metric_args 315 | if self.subv_metric_args is None: 316 | self.subv_metric_args = {} 317 | 318 | self.regions = [None] * batch_size 319 | self.region_pos = [None] * batch_size 320 | self.move_counts = [0] * batch_size 321 | self.epoch_move_counts = [] 322 | self.epoch_subv_metrics = [] 323 | self.epoch_subvolumes = 0 324 | self.batch_image_input = [None] * batch_size 325 | 326 | self.f_a_bins = f_a_bins 327 | self.f_a_init = False 328 | if f_a_bins is not None: 329 | self.f_a_init = True 330 | self.f_a_counts = np.ones_like(f_a_bins, dtype=np.int64) 331 | self.f_as = np.zeros(batch_size) 332 | 333 | self.fake_block = None 334 | self.fake_mask = [False] * batch_size 335 | 336 | def __iter__(self): 337 | return self 338 | 339 | def reset(self): 340 | self.f_a_init = False 341 | if self.reset_generators: 342 | self.subvolumes.reset() 343 | self.regions = [None] * self.batch_size 344 | self.kludge['inputs'] = None 345 | self.kludge['outputs'] = None 346 | if len(self.epoch_move_counts): 347 | logging.info(' Average moves (%s): %s', 348 | self.subvolumes.name, 349 | sum(self.epoch_move_counts)/float(len(self.epoch_move_counts))) 350 | self.epoch_move_counts = [] 351 | self.epoch_subvolumes = 0 352 | self.epoch_subv_metrics = [] 353 | self.fake_mask = [False] * self.batch_size 354 | 355 | def get_epoch_metric(self): 356 | assert len(self.epoch_subv_metrics) == self.subv_per_epoch, \ 357 | 'Not all validation subvs completed: {}/{} (Finished moves: {}, ongoing: {})'.format( 358 | len(self.epoch_subv_metrics), self.subv_per_epoch, self.epoch_move_counts, self.move_counts) 359 | return self.epoch_subv_metrics 360 | 361 | def __next__(self): 362 | # If in the fixed-subvolumes-per-epoch mode and completed, yield fake 363 | # data quickly. 364 | if all(self.fake_mask): 365 | inputs = collections.OrderedDict({ 366 | 'image_input': np.repeat(pad_dims(self.fake_block['image']), 367 | CONFIG.training.num_gpus, axis=0), 368 | 'mask_input': np.repeat(pad_dims(self.fake_block['mask']), 369 | CONFIG.training.num_gpus, axis=0) 370 | }) 371 | inputs['kludge'] = self.kludge 372 | outputs = np.repeat(pad_dims(self.fake_block['target']), CONFIG.training.num_gpus, axis=0) 373 | return (inputs, outputs) 374 | 375 | # Before clearing last batches, reuse them to predict mask outputs 376 | # for move training. Add mask outputs to regions. 377 | active_regions = [n for n, region in enumerate(self.regions) if region is not None] 378 | if active_regions and self.kludge['outputs'] is not None and self.kludge['inputs'] is not None: 379 | for n in active_regions: 380 | assert np.array_equal(self.kludge['inputs'][n, :], 381 | self.batch_image_input[n, 0, 0, :, 0]) 382 | self.regions[n].add_mask(self.kludge['outputs'][n, :, :, :, 0], self.region_pos[n]) 383 | 384 | self.batch_image_input = [None] * self.batch_size 385 | batch_mask_input = [None] * self.batch_size 386 | batch_mask_target = [None] * self.batch_size 387 | 388 | for r, region in enumerate(self.regions): 389 | block_data = region.get_next_block() if region is not None else None 390 | if block_data is None: 391 | if self.subv_per_epoch: 392 | if region is not None: 393 | metric = region.prediction_metric( 394 | self.subv_metric_fn, 395 | threshold=self.subv_metric_threshold, 396 | **self.subv_metric_args) 397 | self.epoch_subv_metrics.append(metric) 398 | self.regions[r] = None 399 | if self.epoch_subvolumes >= self.subv_per_epoch: 400 | block_data = self.fake_block 401 | self.fake_mask[r] = True 402 | while block_data is None: 403 | subvolume = six.next(self.subvolumes) 404 | self.epoch_subvolumes += 1 405 | self.f_as[r] = subvolume.f_a() 406 | 407 | self.regions[r] = Region.from_subvolume(subvolume) 408 | if region is not None: 409 | self.epoch_move_counts.append(self.move_counts[r]) 410 | region = self.regions[r] 411 | self.move_counts[r] = 0 412 | block_data = region.get_next_block() 413 | else: 414 | self.move_counts[r] += 1 415 | 416 | if self.subv_per_epoch and self.fake_block is None: 417 | assert block_data is not None 418 | self.fake_block = copy.deepcopy(block_data) 419 | 420 | self.batch_image_input[r] = pad_dims(block_data['image']) 421 | batch_mask_input[r] = pad_dims(block_data['mask']) 422 | batch_mask_target[r] = pad_dims(block_data['target']) 423 | self.region_pos[r] = block_data['position'] 424 | 425 | self.batch_image_input = np.concatenate(self.batch_image_input) 426 | batch_mask_input = np.concatenate(batch_mask_input) 427 | batch_mask_target = np.concatenate(batch_mask_target) 428 | 429 | inputs = collections.OrderedDict({'image_input': self.batch_image_input, 430 | 'mask_input': batch_mask_input}) 431 | inputs['kludge'] = self.kludge 432 | # These inputs are only necessary for assurance the correct FOV is updated. 433 | self.kludge['inputs'] = self.batch_image_input[:, 0, 0, :, 0].copy() 434 | self.kludge['outputs'] = None 435 | 436 | if self.f_a_bins is None: 437 | return (inputs, 438 | [batch_mask_target]) 439 | else: 440 | f_a_inds = np.digitize(self.f_as, self.f_a_bins) - 1 441 | inds, counts = np.unique(f_a_inds, return_counts=True) 442 | if self.f_a_init: 443 | self.f_a_counts[inds] += counts.astype(np.int64) 444 | sample_weights = np.ones(self.f_as.size, dtype=np.float64) 445 | else: 446 | sample_weights = np.reciprocal(self.f_a_counts[f_a_inds], dtype=np.float64) * float(self.f_as.size) 447 | return (inputs, 448 | [batch_mask_target], 449 | sample_weights) 450 | 451 | 452 | DataGenerator = collections.namedtuple('DataGenerator', ['data', 'gens', 'callbacks', 'steps_per_epoch']) 453 | 454 | 455 | def get_output_margin(model_config): 456 | return np.floor_divide(model_config.input_fov_shape - model_config.output_fov_shape, 2) 457 | 458 | 459 | def build_validation_gen(validation_volumes): 460 | output_margin = get_output_margin(CONFIG.model) 461 | 462 | # If there is only one volume, duplicate since more than one is needed 463 | # for Keras queuing. 464 | if len(validation_volumes) == 1: 465 | single_vol = six.next(six.itervalues(validation_volumes)) 466 | validation_volumes = {'dupe {}'.format(n): single_vol for n in range(CONFIG.training.num_workers)} 467 | 468 | validation_gens = [ 469 | preprocess_subvolume_generator( 470 | v.subvolume_generator(shape=CONFIG.model.validation_subv_shape, 471 | label_margin=output_margin)) 472 | for v in six.itervalues(validation_volumes)] 473 | if CONFIG.training.augment_validation: 474 | validation_gens = list(map(augment_subvolume_generator, validation_gens)) 475 | 476 | # Divide training generators up for workers. 477 | validation_worker_gens = [ 478 | validation_gens[i::CONFIG.training.num_workers] 479 | for i in xrange(CONFIG.training.num_workers)] 480 | 481 | # Some workers may not receive any generators. 482 | validation_worker_gens = [g for g in validation_worker_gens if len(g) > 0] 483 | subv_per_worker = CONFIG.training.validation_size // len(validation_worker_gens) 484 | logging.debug('# of validation workers: %s', len(validation_worker_gens)) 485 | 486 | validation_metric = get_function(CONFIG.training.validation_metric['metric']) 487 | validation_kludges = [{'inputs': None, 'outputs': None} for _ in range(CONFIG.training.num_workers)] 488 | validation_data = [MovingTrainingGenerator( 489 | Roundrobin(*gen, name='validation {}'.format(i)), 490 | CONFIG.training.batch_size, 491 | kludge, 492 | f_a_bins=CONFIG.training.fill_factor_bins, 493 | reset_generators=True, 494 | subv_per_epoch=subv_per_worker, 495 | subv_metric_fn=validation_metric, 496 | subv_metric_threshold=CONFIG.training.validation_metric['threshold'], 497 | subv_metric_args=CONFIG.training.validation_metric['args']) 498 | for i, (gen, kludge) in enumerate(zip(validation_worker_gens, validation_kludges))] 499 | 500 | callbacks = [] 501 | callbacks.append(GeneratorSubvolumeMetric(validation_data, 'val_subv_metric')) 502 | callbacks.append(GeneratorReset(validation_data)) 503 | 504 | VALIDATION_STEPS = np.ceil(CONFIG.training.validation_size / CONFIG.training.batch_size) 505 | # Number of all-move sequences must be a multiple of number of worker gens. 506 | VALIDATION_STEPS = np.ceil(VALIDATION_STEPS / len(validation_worker_gens)) * len(validation_worker_gens) 507 | VALIDATION_STEPS = VALIDATION_STEPS * CONFIG.model.validation_subv_moves + len(validation_worker_gens) 508 | VALIDATION_STEPS = VALIDATION_STEPS.astype(np.int64) 509 | 510 | return DataGenerator( 511 | data=validation_data, 512 | gens=validation_worker_gens, 513 | callbacks=callbacks, 514 | steps_per_epoch=VALIDATION_STEPS) 515 | 516 | 517 | def build_training_gen(training_volumes): 518 | output_margin = get_output_margin(CONFIG.model) 519 | 520 | # If there is only one volume, duplicate since more than one is needed 521 | # for Keras queuing. 522 | if len(training_volumes) == 1: 523 | single_vol = six.next(six.itervalues(training_volumes)) 524 | training_volumes = {'dupe {}'.format(n): single_vol for n in range(CONFIG.training.num_workers)} 525 | 526 | training_gens = [ 527 | augment_subvolume_generator( 528 | preprocess_subvolume_generator( 529 | v.subvolume_generator(shape=CONFIG.model.training_subv_shape, 530 | label_margin=output_margin))) 531 | for v in six.itervalues(training_volumes)] 532 | random.shuffle(training_gens) 533 | 534 | # Divide training generators up for workers. 535 | worker_gens = [ 536 | training_gens[i::CONFIG.training.num_workers] 537 | for i in xrange(CONFIG.training.num_workers)] 538 | 539 | # Some workers may not receive any generators. 540 | worker_gens = [g for g in worker_gens if len(g) > 0] 541 | logging.debug('# of training workers: %s', len(worker_gens)) 542 | 543 | kludges = [{'inputs': None, 'outputs': None} for _ in range(CONFIG.training.num_workers)] 544 | # Create a training data generator for each worker. 545 | training_data = [MovingTrainingGenerator( 546 | Roundrobin(*gen, name='training {}'.format(i)), 547 | CONFIG.training.batch_size, 548 | kludge, 549 | f_a_bins=CONFIG.training.fill_factor_bins, 550 | reset_generators=CONFIG.training.reset_generators) 551 | for i, (gen, kludge) in enumerate(zip(worker_gens, kludges))] 552 | training_reset_callback = GeneratorReset(training_data) 553 | callbacks = [training_reset_callback] 554 | 555 | TRAINING_STEPS_PER_EPOCH = CONFIG.training.training_size // CONFIG.training.batch_size 556 | 557 | return DataGenerator( 558 | data=training_data, 559 | gens=worker_gens, 560 | callbacks=callbacks, 561 | steps_per_epoch=TRAINING_STEPS_PER_EPOCH) 562 | 563 | 564 | def train_network( 565 | model_file=None, 566 | volumes=None, 567 | model_output_filebase=None, 568 | model_checkpoint_file=None, 569 | tensorboard=False, 570 | viewer=False, 571 | metric_plot=False): 572 | random.seed(CONFIG.random_seed) 573 | 574 | tf_device = 'cpu:0' if CONFIG.training.num_gpus > 1 else 'gpu:0' 575 | 576 | if model_file is None: 577 | factory = get_function(CONFIG.network.factory) 578 | with tf.device(tf_device): 579 | ffn = factory(CONFIG.model.input_fov_shape, 580 | CONFIG.model.output_fov_shape, 581 | CONFIG.network) 582 | else: 583 | with tf.device(tf_device): 584 | ffn = load_model(model_file, CONFIG.network) 585 | 586 | # Multi-GPU models are saved as a single-GPU model prior to compilation, 587 | # so if loading from such a model file it will need to be recompiled. 588 | if not hasattr(ffn, 'optimizer'): 589 | if CONFIG.training.num_gpus > 1: 590 | ffn = make_parallel(ffn, CONFIG.training.num_gpus) 591 | compile_network(ffn, CONFIG.optimizer) 592 | 593 | patch_prediction_copy(ffn) 594 | 595 | if model_output_filebase is None: 596 | model_output_filebase = 'model_output' 597 | 598 | if volumes is None: 599 | raise ValueError('Volumes must be provided.') 600 | 601 | CONFIG.to_toml(model_output_filebase + '.toml') 602 | 603 | training_volumes, validation_volumes = partition_volumes(volumes) 604 | 605 | num_training = len(training_volumes) 606 | num_validation = len(validation_volumes) 607 | 608 | logging.info('Using {} volumes for training, {} for validation.'.format(num_training, num_validation)) 609 | 610 | validation = build_validation_gen(validation_volumes) 611 | training = build_training_gen(training_volumes) 612 | 613 | callbacks = [] 614 | callbacks.extend(validation.callbacks) 615 | callbacks.extend(training.callbacks) 616 | 617 | validation_mode = CONFIG.training.validation_metric['mode'] 618 | 619 | if CONFIG.training.early_abort_epoch is not None and \ 620 | CONFIG.training.early_abort_loss is not None: 621 | callbacks.append(EarlyAbort(threshold_epoch=CONFIG.training.early_abort_epoch, 622 | threshold_value=CONFIG.training.early_abort_loss)) 623 | 624 | callbacks.append(ModelCheckpoint(model_output_filebase + '.hdf5', 625 | monitor='val_subv_metric', 626 | save_best_only=True, 627 | mode=validation_mode)) 628 | if model_checkpoint_file: 629 | callbacks.append(ModelCheckpoint(model_checkpoint_file)) 630 | callbacks.append(EarlyStopping(monitor='val_subv_metric', 631 | patience=CONFIG.training.patience, 632 | mode=validation_mode)) 633 | # Activation histograms and weight images for TensorBoard will not work 634 | # because the Keras callback does not currently support validation data 635 | # generators. 636 | if tensorboard: 637 | callbacks.append(TensorBoard()) 638 | 639 | history = ffn.fit_generator( 640 | Roundrobin(*training.data, name='training outer'), 641 | steps_per_epoch=training.steps_per_epoch, 642 | epochs=CONFIG.training.total_epochs, 643 | max_queue_size=len(training.gens) - 1, 644 | workers=1, 645 | callbacks=callbacks, 646 | validation_data=Roundrobin(*validation.data, name='validation outer'), 647 | validation_steps=validation.steps_per_epoch) 648 | 649 | write_keras_history_to_csv(history, model_output_filebase + '.csv') 650 | 651 | if viewer: 652 | viz_ex = itertools.islice(validation.data[0], 1) 653 | 654 | for inputs, targets in viz_ex: 655 | viewer = WrappedViewer(voxel_size=list(np.flipud(CONFIG.volume.resolution))) 656 | output_offset = np.array(inputs['image_input'].shape[1:4]) - np.array(targets[0].shape[1:4]) 657 | output_offset = np.flipud(output_offset // 2) 658 | viewer.add(inputs['image_input'][0, :, :, :, 0], 659 | name='Image') 660 | viewer.add(inputs['mask_input'][0, :, :, :, 0], 661 | name='Mask Input', 662 | shader=get_color_shader(2)) 663 | viewer.add(targets[0][0, :, :, :, 0], 664 | name='Mask Target', 665 | shader=get_color_shader(0), 666 | voxel_offset=output_offset) 667 | output = ffn.predict_on_batch(inputs) 668 | viewer.add(output[0, :, :, :, 0], 669 | name='Mask Output', 670 | shader=get_color_shader(1), 671 | voxel_offset=output_offset) 672 | 673 | viewer.print_view_prompt() 674 | 675 | if metric_plot: 676 | fig = plot_history(history) 677 | fig.savefig(model_output_filebase + '.png') 678 | 679 | return history 680 | 681 | 682 | def validate_model(model_file, volumes): 683 | from .network import load_model 684 | 685 | _, volumes = partition_volumes(volumes) 686 | 687 | validation = build_validation_gen(volumes) 688 | 689 | tf_device = 'cpu:0' if CONFIG.training.num_gpus > 1 else 'gpu:0' 690 | with tf.device(tf_device): 691 | model = load_model(model_file, CONFIG.network) 692 | 693 | # Multi-GPU models are saved as a single-GPU model prior to compilation, 694 | # so if loading from such a model file it will need to be recompiled. 695 | if not hasattr(model, 'optimizer'): 696 | if CONFIG.training.num_gpus > 1: 697 | model = make_parallel(model, CONFIG.training.num_gpus) 698 | compile_network(model, CONFIG.optimizer) 699 | 700 | patch_prediction_copy(model) 701 | 702 | pbar = tqdm(desc='Validation batches', total=validation.steps_per_epoch) 703 | finished = [False] * len(validation.gens) 704 | 705 | for n, data in itertools.cycle(enumerate(validation.data)): 706 | if all(finished): 707 | break 708 | 709 | pbar.update(1) 710 | 711 | if all(data.fake_mask): 712 | finished[n] = True 713 | continue 714 | 715 | batch = six.next(data) 716 | model.test_on_batch(*batch) 717 | 718 | pbar.close() 719 | 720 | metrics = [] 721 | for gen in validation.data: 722 | metrics.extend(gen.get_epoch_metric()) 723 | 724 | print('Metric: ', np.mean(metrics)) 725 | print('All: ', metrics) 726 | --------------------------------------------------------------------------------