├── .autorc ├── .dockerignore ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ ├── feature_request.md │ ├── maintenance.md │ └── question.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── ci.yml │ ├── guide-notebooks-ec2.yml │ ├── publish.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .zenodo.json ├── CHANGELOG.md ├── CITATION ├── LICENSE ├── MANIFEST.in ├── README.md ├── docker ├── README.md ├── cpu.Dockerfile └── gpu.Dockerfile ├── nobrainer ├── __init__.py ├── _version.py ├── bayesian_utils.py ├── cli │ ├── __init__.py │ ├── main.py │ └── tests │ │ ├── __init__.py │ │ └── main_test.py ├── dataset.py ├── distributed_learning │ └── dwc.py ├── intensity_transforms.py ├── io.py ├── layers │ ├── Conv_v.py │ ├── InstanceNorm.py │ ├── __init__.py │ ├── dropout.py │ ├── groupnorm.py │ ├── max_pool4d.py │ ├── padding.py │ └── tests │ │ ├── __init__.py │ │ ├── dropout_test.py │ │ └── layers_test.py ├── losses.py ├── metrics.py ├── models │ ├── __init__.py │ ├── attention_unet.py │ ├── attention_unet_with_inception.py │ ├── autoencoder.py │ ├── bayesian_meshnet.py │ ├── bayesian_vnet.py │ ├── bayesian_vnet_semi.py │ ├── brainsiam.py │ ├── dcgan.py │ ├── highresnet.py │ ├── meshnet.py │ ├── progressiveae.py │ ├── progressivegan.py │ ├── tests │ │ ├── __init__.py │ │ └── models_test.py │ ├── unet.py │ ├── unet_lstm.py │ ├── unetr.py │ ├── vnet.py │ └── vox2vox.py ├── prediction.py ├── processing │ ├── __init__.py │ ├── base.py │ ├── checkpoint.py │ ├── generation.py │ └── segmentation.py ├── spatial_transforms.py ├── tests │ ├── __init__.py │ ├── checkpoint_test.py │ ├── dataset_test.py │ ├── io_test.py │ ├── losses_test.py │ ├── metrics_test.py │ ├── prediction_test.py │ ├── test_intensity_transforms.py │ ├── test_spatial_transforms.py │ ├── test_utils.py │ ├── tfrecord_test.py │ ├── transform_test.py │ ├── utils.py │ └── volume_test.py ├── tfrecord.py ├── training.py ├── transform.py ├── utils.py ├── validation.py └── volume.py ├── pyproject.toml ├── setup.cfg ├── setup.py └── versioneer.py /.autorc: -------------------------------------------------------------------------------- 1 | { 2 | "onlyPublishWithReleaseLabel": true, 3 | "baseBranch": "master", 4 | "author": "Nobrainer Bot ", 5 | "noVersionPrefix": true, 6 | "plugins": ["git-tag"] 7 | } 8 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git/ 2 | docker/ 3 | .idea/ 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | nobrainer/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report a bug (e.g., something not working as described, missing/incorrect documentation). 4 | title: '' 5 | labels: 'bug' 6 | assignees: '' 7 | 8 | --- 9 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation improvement 3 | about: Request improvements to the documentation and tutorials. 
4 | title: '' 5 | labels: 'documentation' 6 | assignees: '' 7 | 8 | --- 9 | 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Propose a new feature or a change to an existing feature. 4 | title: '' 5 | labels: 'feature' 6 | assignees: '' 7 | 8 | --- 9 | 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/maintenance.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Maintenance and delivery 3 | about: Suggestions and requests regarding the infrastructure for development, testing, and delivery. 4 | title: '' 5 | labels: 'maintenance' 6 | assignees: '' 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Not sure if you are using Nobrainer correctly, or other questions? This is the place. 4 | title: '' 5 | labels: 'question' 6 | assignees: '' 7 | 8 | --- 9 | 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Types of changes 2 | 3 | - [ ] Bug fix (non-breaking change which fixes an issue) 4 | - [ ] New feature (non-breaking change which adds functionality) 5 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 6 | 7 | ## Summary 8 | 9 | 10 | ## Checklist 11 | 12 | - [ ] I have added tests to cover my changes 13 | - [ ] I have updated documentation (if necessary) 14 | 15 | ## Acknowledgment 16 | - [ ] I acknowledge that this contribution will be available under the Apache 2 license. 17 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | branches: [master] 6 | 7 | jobs: 8 | build: 9 | 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-20.04] 14 | python-version: ["3.11", "3.10", "3.9"] 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | with: # no need for the history 19 | fetch-depth: 1 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip setuptools 28 | pip install --upgrade --force-reinstall --no-cache-dir --editable=".[dev]" 29 | - name: Test with pytest 30 | run: | 31 | pytest 32 | image-build: 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | matrix: 36 | os: [ubuntu-20.04] 37 | python-version: ["3.11", "3.10", "3.9"] 38 | steps: 39 | - uses: actions/checkout@v3 40 | with: 41 | fetch-depth: 1 42 | - name: Test CPU Docker image build 43 | run: | 44 | docker build -t neuronets/nobrainer:master-cpu -f docker/cpu.Dockerfile . 45 | - name: Test GPU Docker image build 46 | run: | 47 | docker build -t neuronets/nobrainer:master-gpu -f docker/gpu.Dockerfile . 
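          # the trailing "." passes the repository checkout as the Docker build context for both images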
48 | -------------------------------------------------------------------------------- /.github/workflows/guide-notebooks-ec2.yml: -------------------------------------------------------------------------------- 1 | name: Guide Notebooks Regression - EC2 2 | run-name: ${{ github.ref_name }} - Guide Notebooks Regression - EC2 3 | on: [push] 4 | jobs: 5 | start-runner: 6 | name: Start self-hosted EC2 runner 7 | runs-on: ubuntu-latest 8 | outputs: 9 | label: ${{ steps.start-ec2-runner.outputs.label }} 10 | ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} 11 | steps: 12 | - name: Configure AWS credentials 13 | uses: aws-actions/configure-aws-credentials@v1 14 | with: 15 | aws-access-key-id: ${{ secrets.AWS_KEY_ID }} 16 | aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }} 17 | aws-region: ${{ vars.AWS_REGION }} 18 | - name: Start EC2 runner 19 | id: start-ec2-runner 20 | uses: machulav/ec2-github-runner@v2 21 | with: 22 | mode: start 23 | github-token: ${{ secrets.GH_TOKEN }} 24 | ec2-image-id: ${{ vars.AWS_IMAGE_ID }} 25 | ec2-instance-type: ${{ vars.AWS_INSTANCE_TYPE }} 26 | subnet-id: ${{ vars.AWS_SUBNET }} 27 | security-group-id: ${{ vars.AWS_SECURITY_GROUP }} 28 | 29 | guide_notebooks_regression_ec2: 30 | needs: start-runner # required to start the main job when the runner is ready 31 | runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner 32 | steps: 33 | - name: clone 34 | uses: actions/checkout@v3 35 | - name: install 36 | run: | 37 | set -xe 38 | cd ${{ github.workspace }} 39 | source /opt/tensorflow/bin/activate 40 | export LD_LIBRARY_PATH=opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib 41 | echo $LD_LIBRARY_PATH 42 | pip install matplotlib nilearn 43 | pip install -U tensorflow 44 | pip install -e . 
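          # sanity-check the editable install before running the notebooks (prints Python, system, and timestamp info)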
45 | nobrainer info 46 | - name: run 47 | run: | 48 | set -xe 49 | cd ${{ github.workspace }} 50 | git clone https://github.com/neuronets/nobrainer-book.git 51 | cd nobrainer-book 52 | 53 | # if there is a matching book branch, switch to it 54 | if [ $(git ls-remote --heads https://github.com/neuronets/nobrainer-book.git ${{ github.ref_name }} | wc -l) -ne 0 ]; then 55 | echo "Checking out branch ${{ github.ref_name }}" 56 | git checkout ${{ github.ref_name }}; 57 | else 58 | echo "No matching branch found, sticking with the default" 59 | fi 60 | 61 | cd ${{ github.workspace }} 62 | export LD_LIBRARY_PATH=opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/cuda/efa/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib 63 | source /opt/tensorflow/bin/activate 64 | for notebook_script in $(ls nobrainer-book/docs/nobrainer-guides/scripts/*.py); do 65 | echo "running ${notebook_script}" 66 | python ${notebook_script} 67 | done 68 | 69 | stop-runner: 70 | name: Stop self-hosted EC2 runner 71 | needs: 72 | - start-runner # required to get output from the start-runner job 73 | - guide_notebooks_regression_ec2 # required to wait when the main job is done 74 | runs-on: ubuntu-latest 75 | if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs 76 | steps: 77 | - name: Configure AWS credentials 78 | uses: aws-actions/configure-aws-credentials@v1 79 | with: 80 | aws-access-key-id: ${{ secrets.AWS_KEY_ID }} 81 | aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }} 82 | aws-region: ${{ vars.AWS_REGION }} 83 | - name: Stop EC2 runner 84 | uses: machulav/ec2-github-runner@v2 85 | with: 86 | mode: stop 87 | github-token: ${{ secrets.GH_TOKEN }} 88 | label: ${{ needs.start-runner.outputs.label }} 89 | ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} 90 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to pypi on Github release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi-release: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.9 17 | 18 | - name: Install build & twine 19 | run: python -m pip install build twine 20 | 21 | - name: Publish to pypi 22 | run: | 23 | python -m build 24 | twine upload dist/* 25 | env: 26 | TWINE_USERNAME: __token__ 27 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 28 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Auto-release on PR merge 2 | 3 | on: 4 | # ATM, this is the closest trigger to a PR merging 5 | push: 6 | branches: 7 | - master 8 | 9 | env: 10 | AUTO_VERSION: v11.0.5 11 | 12 | jobs: 13 | auto-release: 14 | runs-on: ubuntu-latest 15 | if: "!contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'skip ci')" 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Prepare repository 20 | # Fetch full git history and tags 21 | run: git fetch --unshallow --tags 22 | 23 | - name: Unset header 24 | # checkout@v2 adds a header that makes branch protection report errors 25 | # because the Github 
action bot is not a collaborator on the repo 26 | run: git config --local --unset http.https://github.com/.extraheader 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: 3.9 32 | 33 | - name: Download auto 34 | run: | 35 | auto_download_url="$(curl -fsSL https://api.github.com/repos/intuit/auto/releases/tags/$AUTO_VERSION | jq -r '.assets[] | select(.name == "auto-linux.gz") | .browser_download_url')" 36 | wget -O- "$auto_download_url" | gunzip > ~/auto 37 | chmod a+x ~/auto 38 | 39 | - name: Create release 40 | run: | 41 | ~/auto shipit -vv 42 | env: 43 | GH_TOKEN: ${{ secrets.AUTO_USER_TOKEN }} 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # pycharm 141 | .idea/ 142 | 143 | # guide data 144 | guide/data/ 145 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | ci: 4 | skip: [codespell] 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.5.0 9 | hooks: 10 | - id: check-added-large-files 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | - repo: https://github.com/psf/black 15 | rev: 24.3.0 16 | hooks: 17 | - id: black 18 | - repo: https://github.com/PyCQA/flake8 19 | rev: 7.0.0 20 | hooks: 21 | - id: flake8 22 | - repo: https://github.com/PyCQA/isort 23 | rev: 5.13.2 24 | hooks: 25 | - id: isort 26 | exclude: ^(nobrainer/_version\.py|versioneer\.py)$ 27 | - repo: https://github.com/codespell-project/codespell 28 | rev: v2.2.6 29 | hooks: 30 | - id: codespell 31 | exclude: ^(nobrainer/_version\.py|versioneer\.py|pyproject\.toml|CHANGELOG\.md)$ 32 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "creators": [ 3 | { 4 | "affiliation": "Stony Brook University", 5 | "name": "Kaczmarzyk, Jakub", 6 | "orcid": "0000-0002-5544-7577" 7 | }, 8 | { 9 | "affiliation": "NIMH", 10 | "name": "McClure, Patrick" 11 | }, 12 | { 13 | "affiliation": "MIT", 14 | "name": "Zulfikar, Wazeer" 15 | }, 16 | { 17 | "affiliation": "MIT", 18 | "name": "Rana, Aakanksha", 19 | "orcid": "0000-0002-8350-7602" 20 | }, 21 | { 22 | "affiliation": "MIT", 23 | "name": "Rajaei, Hoda", 24 | "orcid": "0000-0002-0754-5586" 25 | }, 26 | { 27 | "affiliation": "University of Washington", 28 | "name": "Richie-Halford, Adam", 29 | "orcid": "0000-0001-9276-9084" 30 | }, 31 | { 32 | "affiliation": "Department of Psychology, Stanford University", 33 | "name": "Bansal, Shashank", 34 | "orcid": "0000-0002-1252-8772" 35 | }, 36 | { 37 | "affiliation": "MIT", 38 | "name": "Jarecka, Dorota", 39 | "orcid": "0000-0001-8282-2988" 40 | }, 41 | { 42 | "affiliation": "NIMH", 43 | "name": "Lee, John" 44 | }, 45 | { 46 | "affiliation": "MIT, HMS", 47 | "name": "Ghosh, Satrajit", 48 | "orcid": "0000-0002-5312-6729" 49 | } 50 | ], 51 | "keywords": [ 52 | "neuroimaging", 53 | "deep learning", 54 | "bayesian neural network" 55 | ], 56 | "license": "Apache-2.0", 57 | "upload_type": "software" 58 | } 59 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | Please follow this DOI 
(https://doi.org/10.5281/zenodo.4995077) to find 2 | the latest citation on Zenodo. The different citation formats are available 3 | in the Share and Export sections of the page. On a desktop browser these 4 | are on the bottom right of the page. 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 The Nobrainer Authors. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # This line includes versioneer.py in sdists, which is necessary for wheels 2 | # built from sdists to have the version set in their metadata. 3 | include versioneer.py 4 | include CHANGELOG.md tox.ini 5 | 6 | graft nobrainer 7 | 8 | global-exclude *.py[cod] 9 | include nobrainer/_version.py 10 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Nobrainer in a container 2 | 3 | The Dockerfiles in this directory can be used to create Docker images to use _Nobrainer_ on CPU or GPU. 4 | 5 | ## Build images 6 | 7 | ```bash 8 | cd /code/nobrainer # Top-level nobrainer directory 9 | docker build -t neuronets/nobrainer:master-cpu -f docker/cpu.Dockerfile . 10 | docker build -t neuronets/nobrainer:master-gpu -f docker/gpu.Dockerfile . 11 | ``` 12 | 13 | # Convert Docker images to Singularity containers 14 | 15 | Using Singularity version 3.x, Docker images can be converted to Singularity containers using the `singularity` command-line tool. 16 | 17 | ## Pulling from DockerHub 18 | 19 | In most cases (e.g., working on a HPC cluster), the _Nobrainer_ singularity container can be created with: 20 | 21 | ```bash 22 | singularity pull docker://neuronets/nobrainer:master-gpu 23 | ``` 24 | 25 | ## Building from local Docker cache 26 | 27 | If you built a _Nobrainer_ Docker images locally and would like to convert it to a Singularity container, you can do so with: 28 | 29 | ```bash 30 | sudo singularity pull docker-daemon://neuronets/nobrainer:master-gpu 31 | ``` 32 | 33 | Please note the use of `sudo` here. This is necessary for interacting with the Docker daemon. 
34 | -------------------------------------------------------------------------------- /docker/cpu.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.15.0.post1-jupyter 2 | COPY [".", "/opt/nobrainer"] 3 | RUN cd /opt/nobrainer \ 4 | && sed -i 's/tensorflow >=/tensorflow-cpu >=/g' setup.cfg 5 | RUN python3 -m pip install --no-cache-dir /opt/nobrainer 6 | ENV LC_ALL=C.UTF-8 \ 7 | LANG=C.UTF-8 8 | WORKDIR "/work" 9 | LABEL maintainer="Satrajit Ghosh " 10 | ENTRYPOINT ["nobrainer"] 11 | -------------------------------------------------------------------------------- /docker/gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.15.0.post1-gpu-jupyter 2 | COPY [".", "/opt/nobrainer"] 3 | RUN cd /opt/nobrainer 4 | RUN python3 -m pip install --no-cache-dir /opt/nobrainer 5 | ENV LC_ALL=C.UTF-8 \ 6 | LANG=C.UTF-8 7 | WORKDIR "/work" 8 | LABEL maintainer="Satrajit Ghosh " 9 | ENTRYPOINT ["nobrainer"] 10 | -------------------------------------------------------------------------------- /nobrainer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( # noqa: F401 2 | _version, 3 | dataset, 4 | io, 5 | layers, 6 | losses, 7 | metrics, 8 | models, 9 | prediction, 10 | training, 11 | transform, 12 | utils, 13 | volume, 14 | ) 15 | 16 | __version__ = _version.get_versions()["version"] 17 | -------------------------------------------------------------------------------- /nobrainer/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/cli/__init__.py -------------------------------------------------------------------------------- /nobrainer/cli/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/cli/tests/__init__.py -------------------------------------------------------------------------------- /nobrainer/cli/tests/main_test.py: -------------------------------------------------------------------------------- 1 | """Tests for `nobrainer.cli.main`.""" 2 | 3 | import csv 4 | from pathlib import Path 5 | 6 | from click.testing import CliRunner 7 | import nibabel as nib 8 | import numpy as np 9 | import pytest 10 | 11 | from .. 
import main as climain 12 | from ...io import read_csv 13 | from ...models.meshnet import meshnet 14 | from ...models.progressivegan import progressivegan 15 | from ...utils import get_data 16 | 17 | 18 | def test_convert_nonscalar_labels(tmp_path): 19 | runner = CliRunner() 20 | with runner.isolated_filesystem(): 21 | csvpath = get_data(tmp_path) 22 | tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") 23 | tfrecords_template.parent.mkdir(exist_ok=True) 24 | args = """\ 25 | convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 26 | --examples-per-shard=2 --to-ras --no-verify-volumes 27 | """.format( 28 | csvpath, tfrecords_template 29 | ) 30 | result = runner.invoke(climain.cli, args.split()) 31 | assert result.exit_code == 0 32 | assert Path("data/shard-000.tfrecords").is_file() 33 | assert Path("data/shard-001.tfrecords").is_file() 34 | assert Path("data/shard-002.tfrecords").is_file() 35 | assert Path("data/shard-003.tfrecords").is_file() 36 | assert Path("data/shard-004.tfrecords").is_file() 37 | assert not Path("data/shard-005.tfrecords").is_file() 38 | 39 | 40 | def test_convert_scalar_int_labels(tmp_path): 41 | runner = CliRunner() 42 | with runner.isolated_filesystem(): 43 | csvpath = get_data(str(tmp_path)) 44 | # Make labels scalars. 45 | data = [(x, 0) for (x, _) in read_csv(csvpath)] 46 | csvpath = tmp_path.with_suffix(".new.csv") 47 | with open(csvpath, "w", newline="") as myfile: 48 | wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) 49 | wr.writerows(data) 50 | tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") 51 | tfrecords_template.parent.mkdir(exist_ok=True) 52 | args = """\ 53 | convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 54 | --examples-per-shard=2 --to-ras --no-verify-volumes 55 | """.format( 56 | csvpath, tfrecords_template 57 | ) 58 | result = runner.invoke(climain.cli, args.split()) 59 | assert result.exit_code == 0 60 | assert Path("data/shard-000.tfrecords").is_file() 61 | assert Path("data/shard-001.tfrecords").is_file() 62 | assert Path("data/shard-002.tfrecords").is_file() 63 | assert Path("data/shard-003.tfrecords").is_file() 64 | assert Path("data/shard-004.tfrecords").is_file() 65 | assert not Path("data/shard-005.tfrecords").is_file() 66 | 67 | 68 | def test_convert_scalar_float_labels(tmp_path): 69 | runner = CliRunner() 70 | with runner.isolated_filesystem(): 71 | csvpath = get_data(str(tmp_path)) 72 | # Make labels scalars. 
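        # (keep each feature path from the CSV, but swap the label volume path for a scalar float)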
73 | data = [(x, 1.0) for (x, _) in read_csv(csvpath)] 74 | csvpath = tmp_path.with_suffix(".new.csv") 75 | with open(csvpath, "w", newline="") as myfile: 76 | wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) 77 | wr.writerows(data) 78 | tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") 79 | tfrecords_template.parent.mkdir(exist_ok=True) 80 | args = """\ 81 | convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 82 | --examples-per-shard=2 --to-ras --no-verify-volumes 83 | """.format( 84 | csvpath, tfrecords_template 85 | ) 86 | result = runner.invoke(climain.cli, args.split()) 87 | assert result.exit_code == 0 88 | assert Path("data/shard-000.tfrecords").is_file() 89 | assert Path("data/shard-001.tfrecords").is_file() 90 | assert Path("data/shard-002.tfrecords").is_file() 91 | assert Path("data/shard-003.tfrecords").is_file() 92 | assert Path("data/shard-004.tfrecords").is_file() 93 | assert not Path("data/shard-005.tfrecords").is_file() 94 | 95 | 96 | def test_convert_multi_resolution(tmp_path): 97 | runner = CliRunner() 98 | with runner.isolated_filesystem(): 99 | csvpath = get_data(str(tmp_path)) 100 | # Make labels scalars. 101 | data = [(x, 1.0) for (x, _) in read_csv(csvpath)] 102 | csvpath = tmp_path.with_suffix(".new.csv") 103 | with open(csvpath, "w", newline="") as myfile: 104 | wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) 105 | wr.writerows(data) 106 | tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") 107 | tfrecords_template.parent.mkdir(exist_ok=True) 108 | args = """\ 109 | convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 --start-resolution 64 110 | --examples-per-shard=2 --no-verify-volumes --multi-resolution 111 | """.format( 112 | csvpath, tfrecords_template 113 | ) 114 | result = runner.invoke(climain.cli, args.split()) 115 | assert result.exit_code == 0 116 | 117 | resolutions = [64, 128, 256] 118 | for res in resolutions: 119 | assert Path("data/shard-000-res-{:03d}.tfrecords".format(res)).is_file() 120 | assert Path("data/shard-001-res-{:03d}.tfrecords".format(res)).is_file() 121 | assert Path("data/shard-002-res-{:03d}.tfrecords".format(res)).is_file() 122 | assert Path("data/shard-003-res-{:03d}.tfrecords".format(res)).is_file() 123 | assert Path("data/shard-004-res-{:03d}.tfrecords".format(res)).is_file() 124 | assert not Path("data/shard-005-res-{:03d}.tfrecords".format(res)).is_file() 125 | 126 | 127 | @pytest.mark.xfail 128 | def test_merge(): 129 | assert False 130 | 131 | 132 | def test_predict(): 133 | runner = CliRunner() 134 | with runner.isolated_filesystem(): 135 | model = meshnet(1, (10, 10, 10, 1)) 136 | model_path = "model.h5" 137 | model.save(model_path) 138 | 139 | img_path = "features.nii.gz" 140 | nib.Nifti1Image(np.random.randn(20, 20, 20), np.eye(4)).to_filename(img_path) 141 | out_path = "predictions.nii.gz" 142 | 143 | args = """\ 144 | predict --model={} --block-shape 10 10 10 --resize-features-to 20 20 20 145 | --largest-label --rotate-and-predict {} {} 146 | """.format( 147 | model_path, img_path, out_path 148 | ) 149 | 150 | result = runner.invoke(climain.cli, args.split()) 151 | assert result.exit_code == 0 152 | assert Path("predictions.nii.gz").is_file() 153 | assert nib.load(out_path).shape == (20, 20, 20) 154 | 155 | 156 | def test_generate(): 157 | runner = CliRunner() 158 | with runner.isolated_filesystem(): 159 | generator, _ = progressivegan( 160 | latent_size=256, g_fmap_base=1024, d_fmap_base=1024 161 | ) 162 | resolutions = [8, 16] 163 | Path("models").mkdir(exist_ok=True) 
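        # Grow the generator one resolution at a time: add a resolution block, call the model once
        # on a random latent batch to build its weights, then save that stage for the CLI to load.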
164 | for res in resolutions: 165 | generator.add_resolution() 166 | generator([np.random.random((1, 256)), 1.0]) # to build the model by a call 167 | model_path = "models/generator_res_{}".format(res) 168 | generator.save(model_path) 169 | assert Path(model_path).is_dir() 170 | 171 | out_path = "generated.nii.gz" 172 | 173 | args = """\ 174 | generate --model {} --multi-resolution --latent-size 256 {} 175 | """.format( 176 | "models", out_path 177 | ) 178 | result = runner.invoke(climain.cli, args.split()) 179 | assert result.exit_code == 0 180 | for res in resolutions: 181 | assert Path("generated_res_{}.nii.gz".format(res)).is_file() 182 | assert nib.load("generated_res_{}.nii.gz".format(res)).shape == ( 183 | res, 184 | res, 185 | res, 186 | ) 187 | 188 | 189 | @pytest.mark.xfail 190 | def test_save(): 191 | assert False 192 | 193 | 194 | @pytest.mark.xfail 195 | def test_evaluate(): 196 | assert False 197 | 198 | 199 | def test_info(): 200 | runner = CliRunner() 201 | result = runner.invoke(climain.cli, ["info"]) 202 | assert result.exit_code == 0 203 | assert "Python" in result.output 204 | assert "System" in result.output 205 | assert "Timestamp" in result.output 206 | -------------------------------------------------------------------------------- /nobrainer/distributed_learning/dwc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Distributed weight consolidation for Bayesian Deep Neural Networks 4 | # Implemented according to the: 5 | # McClure, Patrick, et al. Distributed weight consolidation: a brain segmentation case study. 6 | # Advances in neural information processing systems 31 (2018): 4093. 7 | 8 | 9 | def distributed_weight_consolidation(model_weights, model_priors): 10 | # model_weights is a list of weights of client-models; models = [model1, model2, model3...] 
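    # (each entry is a flat list of layer arrays, e.g. from model.get_weights(), with the
    #  mean arrays at even indices and the standard-deviation arrays at odd indices)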
11 | # model_priors is a list of priors of client models sames as models 12 | num_layers = int(len(model_weights[0]) / 2.0) 13 | num_datasets = np.shape(model_weights)[0] 14 | consolidated_model = model_weights[0] 15 | mean_idx = [i for i in range(0, len(model_weights[0])) if i % 2 == 0] 16 | std_idx = [i for i in range(0, len(model_weights[0])) if i % 2 != 0] 17 | ep = 1e-5 18 | for i in range(num_layers): 19 | num_1 = 0 20 | num_2 = 0 21 | den_1 = 0 22 | den_2 = 0 23 | for m in range(num_datasets): 24 | model = model_weights[m] 25 | prior = model_priors[m] 26 | mu_s = model[mean_idx[i]] 27 | mu_o = prior[mean_idx[i]] 28 | sig_s = model[std_idx[i]] 29 | sig_o = prior[std_idx[i]] 30 | d1 = np.power(sig_s, 2) + ep 31 | d2 = np.power(sig_o, 2) + ep 32 | num_1 = num_1 + (mu_s / d1) 33 | num_2 = num_2 + (mu_o / d2) 34 | den_1 = den_1 + (1.0 / d1) 35 | den_2 = den_2 + (1.0 / d2) 36 | consolidated_model[mean_idx[i]] = (num_1 - num_2) / (den_1 - den_2) 37 | consolidated_model[std_idx[i]] = 1 / (den_1 - den_2) 38 | return consolidated_model 39 | -------------------------------------------------------------------------------- /nobrainer/io.py: -------------------------------------------------------------------------------- 1 | """Input/output methods.""" 2 | 3 | import csv 4 | import functools 5 | import multiprocessing 6 | import os 7 | 8 | from fsspec.implementations.local import LocalFileSystem 9 | import nibabel as nib 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from .utils import get_num_parallel 14 | 15 | _TFRECORDS_FEATURES_DTYPE = "float32" 16 | 17 | 18 | def read_csv(filepath, skip_header=True, delimiter=","): 19 | """Return list of tuples from a CSV, where each tuple contains the items 20 | in a row. 21 | """ 22 | with open(filepath, newline="") as csvfile: 23 | reader = csv.reader(csvfile, delimiter=delimiter) 24 | if skip_header: 25 | next(reader) 26 | return [tuple(row) for row in reader] 27 | 28 | 29 | def read_mapping(filepath, skip_header=True, delimiter=","): 30 | """Read CSV to dictionary, where first column becomes keys and second 31 | columns becomes values. Other columns are ignored. Keys and values are 32 | coerced to integers. 33 | """ 34 | mapping = read_csv(filepath, skip_header=skip_header, delimiter=delimiter) 35 | if not all(map(lambda r: len(r) >= 2, mapping)): 36 | raise ValueError("not all rows in the mapping have at least 2 values") 37 | try: 38 | return {int(row[0]): int(row[1]) for row in mapping} 39 | except ValueError: 40 | raise ValueError("mapping values must be integers but non-integer encountered") 41 | 42 | 43 | def read_volume(filepath, dtype=None, return_affine=False, to_ras=False): 44 | """Return numpy array of data from a neuroimaging file.""" 45 | img = nib.load(filepath) 46 | if to_ras: 47 | img = nib.as_closest_canonical(img) 48 | data = img.get_fdata(caching="unchanged") 49 | if dtype is not None: 50 | data = data.astype(dtype) 51 | return data if not return_affine else (data, img.affine) 52 | 53 | 54 | def verify_features_labels( 55 | volume_filepaths, 56 | volume_shape=(256, 256, 256), 57 | check_shape=True, 58 | check_labels_int=True, 59 | check_labels_gte_zero=True, 60 | num_parallel_calls=None, 61 | verbose=1, 62 | ): 63 | """Verify a list of files. This function is meant to be run before 64 | converting volumes to TFRecords. 65 | 66 | Parameters 67 | ---------- 68 | volume_filepaths: nested list. Every sublist in the list should contain two 69 | items: (1) path to feature volume and (2) path to label volume or a scalar. 
70 | volume_shape: tuple of three ints. Shape that both volumes should be. 71 | check_shape: boolean, if true, validate that the shape of both volumes is 72 | equal to 'volume_shape'. 73 | check_labels_int: boolean, if true, validate that every labels volume is an 74 | integer type or can be safely converted to an integer type. 75 | check_labels_gte_zero: boolean, if true, validate that every labels volume 76 | has values greater than or equal to zero. 77 | num_parallel_calls: int, number of processes to use for multiprocessing. If 78 | None, will use all available processes. 79 | verbose: {0, 1, 2}, verbosity of the progress bar. 0 is silent, 1 is verbose, 80 | and 2 is semi-verbose. 81 | 82 | Returns 83 | ------- 84 | List of invalid pairs of filepaths. If the list is empty, all filepaths are 85 | valid. 86 | """ 87 | from nobrainer.tfrecord import _labels_all_scalar 88 | 89 | for pair in volume_filepaths: 90 | if len(pair) != 2: 91 | raise ValueError( 92 | "all items in 'volume_filepaths' must have length of 2, but" 93 | " found at least one item with length != 2." 94 | ) 95 | 96 | labels = (y for _, y in volume_filepaths) 97 | scalar_labels = _labels_all_scalar(labels) 98 | 99 | for pair in volume_filepaths: 100 | if not os.path.exists(pair[0]): 101 | raise ValueError("file does not exist: {}".format(pair[0])) 102 | if not scalar_labels: 103 | if not os.path.exists(pair[1]): 104 | raise ValueError("file does not exist: {}".format(pair[1])) 105 | 106 | if scalar_labels: 107 | map_fn = functools.partial( 108 | _verify_features_scalar_labels, 109 | volume_shape=volume_shape, 110 | check_shape=check_shape, 111 | ) 112 | else: 113 | map_fn = functools.partial( 114 | _verify_features_nonscalar_labels, 115 | volume_shape=volume_shape, 116 | check_shape=check_shape, 117 | check_labels_int=check_labels_int, 118 | check_labels_gte_zero=check_labels_gte_zero, 119 | ) 120 | if num_parallel_calls is None: 121 | num_parallel_calls = get_num_parallel() 122 | 123 | print("Verifying {} examples".format(len(volume_filepaths))) 124 | progbar = tf.keras.utils.Progbar(len(volume_filepaths), verbose=verbose) 125 | progbar.update(0) 126 | 127 | outputs = [] 128 | if num_parallel_calls == 1: 129 | for vf in volume_filepaths: 130 | valid = map_fn(vf) 131 | outputs.append(valid) 132 | progbar.add(1) 133 | else: 134 | with multiprocessing.Pool(num_parallel_calls) as p: 135 | for valid in p.imap(map_fn, volume_filepaths, chunksize=2): 136 | outputs.append(valid) 137 | progbar.add(1) 138 | invalid_files = [ 139 | pair for valid, pair in zip(outputs, volume_filepaths) if not valid 140 | ] 141 | return invalid_files 142 | 143 | 144 | def _verify_features_nonscalar_labels( 145 | pair_of_paths, *, volume_shape, check_shape, check_labels_int, check_labels_gte_zero 146 | ): 147 | """Verify a pair of features and labels volumes.""" 148 | x = nib.load(pair_of_paths[0]) 149 | y = nib.load(pair_of_paths[1]) 150 | if check_shape: 151 | if not volume_shape: 152 | raise ValueError( 153 | "`volume_shape` must be specified if `check_shape` is true." 154 | ) 155 | if x.shape != volume_shape: 156 | return False 157 | if x.shape != y.shape: 158 | return False 159 | if check_labels_int: 160 | # Quick check of integer type. 161 | if not np.issubdtype(y.dataobj.dtype, np.integer): 162 | return False 163 | y = y.get_fdata(caching="unchanged", dtype=np.float32) 164 | # Longer check that all values in labels can be cast to int. 
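        # (np.mod(y, 1) is zero exactly where a voxel's float value has no fractional part)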
165 | if not np.all(np.mod(y, 1) == 0): 166 | return False 167 | if check_labels_gte_zero: 168 | if not np.all(y >= 0): 169 | return False 170 | return True 171 | 172 | 173 | def _verify_features_scalar_labels(path_scalar, *, volume_shape, check_shape): 174 | """Check that feature has the desired shape and that label is scalar.""" 175 | from nobrainer.tfrecord import _is_int_or_float 176 | 177 | feature, label = path_scalar 178 | x = nib.load(feature) 179 | if check_shape: 180 | if not volume_shape: 181 | raise ValueError( 182 | "`volume_shape` must be specified if `check_shape` is true." 183 | ) 184 | if x.shape != volume_shape: 185 | return False 186 | if not _is_int_or_float(label): 187 | return False 188 | return True 189 | 190 | 191 | def _is_gzipped(filepath, filesys=None): 192 | """Return True if the file is gzip-compressed, False otherwise.""" 193 | fs = filesys if filesys is not None else LocalFileSystem() 194 | with fs.open(filepath, "rb") as f: 195 | return f.read(2) == b"\x1f\x8b" 196 | -------------------------------------------------------------------------------- /nobrainer/layers/InstanceNorm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ..layers.groupnorm import GroupNormalization 4 | 5 | 6 | class InstanceNormalization(GroupNormalization): 7 | """Instance normalization layer. 8 | Instance Normalization is an specific case of ```GroupNormalization```since 9 | it normalizes all features of one channel. The Groupsize is equal to the 10 | channel size. Empirically, its accuracy is more stable than batch norm in a 11 | wide range of small batch sizes, if learning rate is adjusted linearly 12 | with batch sizes. 13 | Arguments 14 | axis: Integer, the axis that should be normalized. 15 | epsilon: Small float added to variance to avoid dividing by zero. 16 | center: If True, add offset of `beta` to normalized tensor. 17 | If False, `beta` is ignored. 18 | scale: If True, multiply by `gamma`. 19 | If False, `gamma` is not used. 20 | beta_initializer: Initializer for the beta weight. 21 | gamma_initializer: Initializer for the gamma weight. 22 | beta_regularizer: Optional regularizer for the beta weight. 23 | gamma_regularizer: Optional regularizer for the gamma weight. 24 | beta_constraint: Optional constraint for the beta weight. 25 | gamma_constraint: Optional constraint for the gamma weight. 26 | Input shape 27 | Arbitrary. Use the keyword argument `input_shape` 28 | (tuple of integers, does not include the samples axis) 29 | when using this layer as the first layer in a model. 30 | Output shape 31 | Same shape as input. 
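    Example
        A minimal sketch (editor's illustration; assumes a rank-5 volume tensor of
        shape `(batch, x, y, z, channels)`, the usual input for the 3D models here):
        >>> import tensorflow as tf
        >>> from nobrainer.layers.InstanceNorm import InstanceNormalization
        >>> norm = InstanceNormalization(axis=-1)
        >>> y = norm(tf.random.normal((1, 8, 8, 8, 4)))
        >>> y.shape  # same shape as the input, per the note above
        TensorShape([1, 8, 8, 8, 4])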
32 | References 33 | - [Instance Normalization: The Missing Ingredient for Fast Stylization] 34 | (https://arxiv.org/abs/1607.08022) 35 | """ 36 | 37 | def __init__(self, **kwargs): 38 | if "groups" in kwargs: 39 | logging.warning("The given value for groups will be overwritten.") 40 | 41 | kwargs["groups"] = -1 42 | super().__init__(**kwargs) 43 | -------------------------------------------------------------------------------- /nobrainer/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dropout import BernoulliDropout, ConcreteDropout, GaussianDropout 2 | from .padding import ZeroPadding3DChannels 3 | 4 | __all__ = [ 5 | "BernoulliDropout", 6 | "ConcreteDropout", 7 | "GaussianDropout", 8 | "ZeroPadding3DChannels", 9 | ] 10 | -------------------------------------------------------------------------------- /nobrainer/layers/max_pool4d.py: -------------------------------------------------------------------------------- 1 | # import tensorflow as tf 2 | # from tensorflow.python.ops import gen_nn_ops 3 | 4 | # def _get_sequence(value, n, channel_index, name): 5 | # """Formats a value input for gen_nn_ops.""" 6 | # # Performance is fast-pathed for common cases: 7 | # # `None`, `list`, `tuple` and `int`. 8 | # if value is None: 9 | # return [1] * (n + 2) 10 | 11 | # # Always convert `value` to a `list`. 12 | # if isinstance(value, list): 13 | # pass 14 | # elif isinstance(value, tuple): 15 | # value = list(value) 16 | # elif isinstance(value, int): 17 | # value = [value] 18 | # elif not isinstance(value, collections_abc.Sized): 19 | # value = [value] 20 | # else: 21 | # value = list(value) # Try casting to a list. 22 | 23 | # len_value = len(value) 24 | 25 | # # Fully specified, including batch and channel dims. 26 | # if len_value == n + 2: 27 | # return value 28 | 29 | # # Apply value to spatial dims only. 30 | # if len_value == 1: 31 | # value = value * n # Broadcast to spatial dimensions. 32 | # elif len_value != n: 33 | # raise ValueError(f"{name} should be of length 1, {n} or {n + 2}. " 34 | # f"Received: {name}={value} of length {len_value}") 35 | 36 | # # Add batch and channel dims (always 1). 37 | # if channel_index == 1: 38 | # return [1, 1] + value 39 | # else: 40 | # return [1] + value + [1] 41 | 42 | # @tf_export("nn.max_pool4d") 43 | # @dispatch.add_dispatch_support 44 | # def max_pool4d(input, ksize, strides, padding, data_format="NVDHWC", name=None): 45 | # """Performs the max pooling on the input. 46 | # Args: 47 | # input: A 6-D `Tensor` of the format specified by `data_format`. 48 | # ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of 49 | # the window for each dimension of the input tensor. 50 | # strides: An int or list of `ints` that has length `1`, `3` or `5`. The 51 | # stride of the sliding window for each dimension of the input tensor. 52 | # padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See 53 | # [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2) 54 | # for more information. 55 | # data_format: An optional string from: "NVDHWC", "NCVDHW". Defaults to "NVDHWC". 56 | # The data format of the input and output data. With the default format 57 | # "NVDHWC", the data is stored in the order of: [batch, in_depth, in_height, 58 | # in_width, in_channels]. Alternatively, the format could be "NCVDHW", the 59 | # data storage order is: [batch, in_channels, in_volumes, in_depth, in_height, 60 | # in_width]. 
61 | # name: A name for the operation (optional). 62 | # Returns: 63 | # A `Tensor` of format specified by `data_format`. 64 | # The max pooled output tensor. 65 | # """ 66 | # with ops.name_scope(name, "MaxPool4D", [input]) as name: 67 | # if data_format is None: 68 | # data_format = "NVDHWC" 69 | # channel_index = 1 if data_format.startswith("NC") else 5 70 | 71 | # ksize = _get_sequence(ksize, 3, channel_index, "ksize") 72 | # strides = _get_sequence(strides, 3, channel_index, "strides") 73 | 74 | # return gen_nn_ops.max_pool4d( 75 | # input, 76 | # ksize=ksize, 77 | # strides=strides, 78 | # padding=padding, 79 | # data_format=data_format, 80 | # name=name) 81 | -------------------------------------------------------------------------------- /nobrainer/layers/padding.py: -------------------------------------------------------------------------------- 1 | """Custom padding layers for nobrainer.""" 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras import layers 5 | 6 | 7 | class ZeroPadding3DChannels(layers.Layer): 8 | """Pad the last dimension of a 5D tensor symmetrically with zeros. 9 | 10 | This is meant for 3D convolutions, where tensors are 5D. 11 | """ 12 | 13 | def __init__(self, padding, **kwds): 14 | self.padding = padding 15 | # batch, x, y, z, channels 16 | self._paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [self.padding, self.padding]] 17 | super(ZeroPadding3DChannels, self).__init__(**kwds) 18 | 19 | def call(self, x): 20 | return tf.pad(x, paddings=self._paddings, mode="CONSTANT", constant_values=0) 21 | -------------------------------------------------------------------------------- /nobrainer/layers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/layers/tests/__init__.py -------------------------------------------------------------------------------- /nobrainer/layers/tests/layers_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from ..padding import ZeroPadding3DChannels 5 | 6 | 7 | def test_zeropadding3dchannels(): 8 | # This test function is a much shorter version of 9 | # `tensorflow.python.keras.testing_utils.layer_test`. 10 | input_data_shape = (4, 32, 32, 32, 1) 11 | input_data = 10 * np.random.random(input_data_shape) 12 | 13 | x = tf.keras.layers.Input(shape=input_data_shape[1:], dtype=input_data.dtype) 14 | y = ZeroPadding3DChannels(4)(x) 15 | model = tf.keras.Model(x, y) 16 | 17 | actual_output = model.predict(input_data) 18 | actual_output_shape = actual_output.shape 19 | assert actual_output_shape == (4, 32, 32, 32, 9) 20 | assert not actual_output[..., :4].any() 21 | assert actual_output[..., 4].any() 22 | assert not actual_output[..., 5:].any() 23 | 24 | return actual_output 25 | -------------------------------------------------------------------------------- /nobrainer/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def average_volume_difference(): 5 | raise NotImplementedError() 6 | 7 | 8 | def dice(y_true, y_pred, axis=(1, 2, 3, 4)): 9 | """Calculate Dice similarity between labels and predictions. 10 | Dice similarity is in [0, 1], where 1 is perfect overlap and 0 is no 11 | overlap. If both labels and predictions are empty (e.g., all background), 12 | then Dice similarity is 1. 
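    For instance (an editor's sketch with arbitrary toy shapes), two empty masks
    give a Dice of 1 for every item in the batch:
    >>> import tensorflow as tf
    >>> empty = tf.zeros((2, 4, 4, 4, 1))
    >>> dice(empty, empty).numpy()
    array([1., 1.], dtype=float32)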
13 | If we assume the inputs are rank 5 [`(batch, x, y, z, classes)`], then an 14 | axis parameter of `(1, 2, 3)` will result in a tensor that contains a Dice 15 | score for every class in every item in the batch. The shape of this tensor 16 | will be `(batch, classes)`. If the inputs only have one class (e.g., binary 17 | segmentation), then an axis parameter of `(1, 2, 3, 4)` should be used. 18 | This will result in a tensor of shape `(batch,)`, where every value is the 19 | Dice similarity for that prediction. 20 | Implemented according to https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ6 21 | Returns 22 | ------- 23 | Tensor of Dice similarities. 24 | Citations 25 | --------- 26 | Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation: 27 | analysis, selection, and tool. BMC Med Imaging. 2015;15:29. Published 2015 28 | Aug 12. doi:10.1186/s12880-015-0068-x 29 | """ 30 | y_pred = tf.convert_to_tensor(y_pred) 31 | y_true = tf.cast(y_true, y_pred.dtype) 32 | eps = tf.keras.backend.epsilon() 33 | 34 | intersection = tf.reduce_sum(y_true * y_pred, axis=axis) 35 | summation = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis) 36 | return (2 * intersection + eps) / (summation + eps) 37 | 38 | 39 | def generalized_dice(y_true, y_pred, axis=(1, 2, 3)): 40 | """Calculate Generalized Dice similarity. This is useful for multi-class 41 | predictions. 42 | If we assume the inputs are rank 5 [`(batch, x, y, z, classes)`], then an 43 | axis parameter of `(1, 2, 3)` should be used. This will result in a tensor 44 | of shape `(batch,)`, where every value is the Generalized Dice similarity 45 | for that prediction, across all classes. 46 | Returns 47 | ------- 48 | Tensor of Generalized Dice similarities. 49 | """ 50 | y_pred = tf.convert_to_tensor(y_pred) 51 | y_true = tf.cast(y_true, y_pred.dtype) 52 | 53 | if y_true.get_shape().ndims < 2 or y_pred.get_shape().ndims < 2: 54 | raise ValueError("y_true and y_pred must be at least rank 2.") 55 | 56 | epsilon = tf.keras.backend.epsilon() 57 | 58 | w = tf.math.reciprocal(tf.square(tf.reduce_sum(y_true, axis=axis))) 59 | w = tf.where(tf.math.is_finite(w), w, epsilon) 60 | num = 2 * tf.reduce_sum(w * tf.reduce_sum(y_true * y_pred, axis=axis), axis=-1) 61 | den = tf.reduce_sum(w * tf.reduce_sum(y_true + y_pred, axis=axis), axis=-1) 62 | gdice = (num + epsilon) / (den + epsilon) 63 | return gdice 64 | 65 | 66 | def hamming(y_true, y_pred, axis=(1, 2, 3)): 67 | y_pred = tf.convert_to_tensor(y_pred) 68 | y_true = tf.cast(y_true, y_pred.dtype) 69 | return tf.reduce_mean(tf.not_equal(y_pred, y_true), axis=axis) 70 | 71 | 72 | def haussdorf(): 73 | raise NotADirectoryError() 74 | 75 | 76 | def jaccard(y_true, y_pred, axis=(1, 2, 3, 4)): 77 | """Calculate Jaccard similarity between labels and predictions. 78 | Jaccard similarity is in [0, 1], where 1 is perfect overlap and 0 is no 79 | overlap. If both labels and predictions are empty (e.g., all background), 80 | then Jaccard similarity is 1. 81 | If we assume the inputs are rank 5 [`(batch, x, y, z, classes)`], then an 82 | axis parameter of `(1, 2, 3)` will result in a tensor that contains a Jaccard 83 | score for every class in every item in the batch. The shape of this tensor 84 | will be `(batch, classes)`. If the inputs only have one class (e.g., binary 85 | segmentation), then an axis parameter of `(1, 2, 3, 4)` should be used. 86 | This will result in a tensor of shape `(batch,)`, where every value is the 87 | Jaccard similarity for that prediction. 
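    As a quick illustration (an editor's sketch with toy one-hot tensors), the
    per-class form returns one score per class for each item in the batch:
    >>> import tensorflow as tf
    >>> y = tf.one_hot(tf.zeros((2, 4, 4, 4), dtype=tf.int32), depth=3)
    >>> jaccard(y, y, axis=(1, 2, 3)).shape
    TensorShape([2, 3])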
88 | Implemented according to https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ7 89 | Returns 90 | ------- 91 | Tensor of Jaccard similarities. 92 | Citations 93 | --------- 94 | Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation: 95 | analysis, selection, and tool. BMC Med Imaging. 2015;15:29. Published 2015 96 | Aug 12. doi:10.1186/s12880-015-0068-x 97 | """ 98 | y_pred = tf.convert_to_tensor(y_pred) 99 | y_true = tf.cast(y_true, y_pred.dtype) 100 | eps = tf.keras.backend.epsilon() 101 | 102 | intersection = tf.reduce_sum(y_true * y_pred, axis=axis) 103 | union = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis) 104 | return (intersection + eps) / (union - intersection + eps) 105 | 106 | 107 | def tversky(y_true, y_pred, axis=(1, 2, 3), alpha=0.3, beta=0.7): 108 | y_pred = tf.convert_to_tensor(y_pred) 109 | y_true = tf.cast(y_true, y_pred.dtype) 110 | 111 | if y_true.get_shape().ndims < 2 or y_pred.get_shape().ndims < 2: 112 | raise ValueError("y_true and y_pred must be at least rank 2.") 113 | 114 | eps = tf.keras.backend.epsilon() 115 | 116 | num = tf.reduce_sum(y_pred * y_true, axis=axis) 117 | den = ( 118 | num 119 | + alpha * tf.reduce_sum(y_pred * (1 - y_true), axis=axis) 120 | + beta * tf.reduce_sum((1 - y_pred) * y_true, axis=axis) 121 | ) 122 | # Sum over classes. 123 | return tf.reduce_sum((num + eps) / (den + eps), axis=-1) 124 | -------------------------------------------------------------------------------- /nobrainer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from .attention_unet import attention_unet 4 | from .attention_unet_with_inception import attention_unet_with_inception 5 | from .autoencoder import autoencoder 6 | from .bayesian_meshnet import variational_meshnet 7 | from .dcgan import dcgan 8 | from .highresnet import highresnet 9 | from .meshnet import meshnet 10 | from .progressiveae import progressiveae 11 | from .progressivegan import progressivegan 12 | from .unet import unet 13 | from .unetr import unetr 14 | 15 | __all__ = ["get", "list_available_models"] 16 | 17 | _models = { 18 | "highresnet": highresnet, 19 | "meshnet": meshnet, 20 | "unet": unet, 21 | "autoencoder": autoencoder, 22 | "progressivegan": progressivegan, 23 | "progressiveae": progressiveae, 24 | "dcgan": dcgan, 25 | "attention_unet": attention_unet, 26 | "attention_unet_with_inception": attention_unet_with_inception, 27 | "unetr": unetr, 28 | "variational_meshnet": variational_meshnet, 29 | } 30 | 31 | 32 | def get(name): 33 | """Return callable that creates a particular `tf.keras.Model`. 34 | 35 | Parameters 36 | ---------- 37 | name: str, the name of the model (case-insensitive). 38 | 39 | Returns 40 | ------- 41 | Callable, which instantiates a `tf.keras.Model` object. 42 | """ 43 | if not isinstance(name, str): 44 | raise ValueError("Model name must be a string.") 45 | 46 | try: 47 | return _models[name.lower()] 48 | except KeyError: 49 | avail = ", ".join(_models.keys()) 50 | raise ValueError( 51 | "Unknown model: '{}'. 
Available models are {}.".format(name, avail) 52 | ) 53 | 54 | 55 | def available_models(): 56 | return list(_models) 57 | 58 | 59 | def list_available_models(): 60 | pprint(available_models()) 61 | -------------------------------------------------------------------------------- /nobrainer/models/attention_unet.py: -------------------------------------------------------------------------------- 1 | """Model definition for Attention U-Net. 2 | Adapted from https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/blob/main/TensorFlow/attention-unet.py 3 | """ # noqa: E501 4 | 5 | from tensorflow.keras import layers 6 | import tensorflow.keras.layers as L 7 | from tensorflow.keras.models import Model 8 | 9 | 10 | def conv_block(x, num_filters): 11 | x = L.Conv3D(num_filters, 3, padding="same")(x) 12 | x = L.BatchNormalization()(x) 13 | x = L.Activation("relu")(x) 14 | 15 | x = L.Conv3D(num_filters, 3, padding="same")(x) 16 | x = L.BatchNormalization()(x) 17 | x = L.Activation("relu")(x) 18 | 19 | return x 20 | 21 | 22 | def encoder_block(x, num_filters): 23 | x = conv_block(x, num_filters) 24 | p = L.MaxPool3D()(x) 25 | return x, p 26 | 27 | 28 | def attention_gate(g, s, num_filters): 29 | Wg = L.Conv3D(num_filters, 1, padding="same")(g) 30 | Wg = L.BatchNormalization()(Wg) 31 | 32 | Ws = L.Conv3D(num_filters, 1, padding="same")(s) 33 | Ws = L.BatchNormalization()(Ws) 34 | 35 | out = L.Activation("relu")(Wg + Ws) 36 | out = L.Conv3D(num_filters, 1, padding="same")(out) 37 | out = L.Activation("sigmoid")(out) 38 | 39 | return out * s 40 | 41 | 42 | def decoder_block(x, s, num_filters): 43 | x = L.UpSampling3D()(x) 44 | s = attention_gate(x, s, num_filters) 45 | x = L.Concatenate()([x, s]) 46 | x = conv_block(x, num_filters) 47 | return x 48 | 49 | 50 | def attention_unet(n_classes, input_shape): 51 | """Inputs""" 52 | inputs = L.Input(input_shape) 53 | 54 | """ Encoder """ 55 | s1, p1 = encoder_block(inputs, 64) 56 | s2, p2 = encoder_block(p1, 128) 57 | s3, p3 = encoder_block(p2, 256) 58 | 59 | b1 = conv_block(p3, 512) 60 | 61 | """ Decoder """ 62 | d1 = decoder_block(b1, s3, 256) 63 | d2 = decoder_block(d1, s2, 128) 64 | d3 = decoder_block(d2, s1, 64) 65 | 66 | """ Outputs """ 67 | outputs = L.Conv3D(n_classes, 1, padding="same")(d3) 68 | 69 | final_activation = "sigmoid" if n_classes == 1 else "softmax" 70 | outputs = layers.Activation(final_activation)(outputs) 71 | 72 | """ Model """ 73 | return Model(inputs=inputs, outputs=outputs, name="Attention_U-Net") 74 | 75 | 76 | if __name__ == "__main__": 77 | n_classes = 50 78 | input_shape = (256, 256, 256, 3) 79 | model = attention_unet(n_classes, input_shape) 80 | model.summary() 81 | -------------------------------------------------------------------------------- /nobrainer/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | """Model definition for Autoencoder. 2 | """ 3 | 4 | import math 5 | 6 | from tensorflow.keras import layers, models 7 | 8 | 9 | def autoencoder( 10 | input_shape, 11 | encoding_dim=512, 12 | n_base_filters=16, 13 | batchnorm=True, 14 | batch_size=None, 15 | ): 16 | """Instantiate Autoencoder Architecture. 17 | 18 | Parameters 19 | ---------- 20 | input_shape: list or tuple of four ints, the shape of the input data. Should be 21 | scaled to [0,1]. Omit the batch dimension, and include the number of channels. 22 | Currently, only squares and cubes supported. 23 | encoding_dim: int, the dimensions of the encoding of the input data. 
This would 24 | translate to a latent code of dimensions encoding_dimx1. 25 | n_base_filters: int, number of base filters the models first convolutional layer. 26 | The subsequent layers have n_filters which are multiples of n_base_filters. 27 | batchnorm: bool, whether to use batch normalization in the network. 28 | batch_size: int, number of samples in each batch. This must be set when training on 29 | TPUs. 30 | name: str, name to give to the resulting model object. 31 | 32 | Returns 33 | ------- 34 | Model object. 35 | """ 36 | 37 | conv_kwds = {"kernel_size": 4, "activation": None, "padding": "same", "strides": 2} 38 | 39 | conv_transpose_kwds = { 40 | "kernel_size": 4, 41 | "strides": 2, 42 | "activation": None, 43 | "padding": "same", 44 | } 45 | 46 | dimensions = input_shape[:-1] 47 | n_dims = len(dimensions) 48 | 49 | if not (n_dims in [2, 3] and dimensions[1:] == dimensions[:-1]): 50 | raise ValueError("Dimensions should be of square or cube!") 51 | 52 | Conv = getattr(layers, "Conv{}D".format(n_dims)) 53 | ConvTranspose = getattr(layers, "Conv{}DTranspose".format(n_dims)) 54 | n_layers = int(math.log(dimensions[0], 2)) 55 | 56 | # Input layer 57 | inputs = x = layers.Input(shape=input_shape, batch_size=batch_size, name="inputs") 58 | 59 | # Encoder 60 | for i in range(n_layers): 61 | n_filters = min(n_base_filters * (2 ** (i)), encoding_dim) 62 | 63 | x = Conv(n_filters, **conv_kwds)(x) 64 | if batchnorm: 65 | x = layers.BatchNormalization()(x) 66 | x = layers.ReLU()(x) 67 | 68 | # Encoding of the input image 69 | x = layers.Flatten(name="Encoding")(x) 70 | 71 | # Decoder 72 | x = layers.Reshape((1,) * n_dims + (encoding_dim,))(x) 73 | for i in range(n_layers)[::-1]: 74 | n_filters = min(n_base_filters * (2 ** (i)), encoding_dim) 75 | 76 | x = ConvTranspose(n_filters, **conv_transpose_kwds)(x) 77 | if batchnorm: 78 | x = layers.BatchNormalization()(x) 79 | x = layers.LeakyReLU()(x) 80 | 81 | # Output layer 82 | outputs = Conv(1, 3, activation="sigmoid", padding="same")(x) 83 | 84 | return models.Model(inputs=inputs, outputs=outputs) 85 | -------------------------------------------------------------------------------- /nobrainer/models/bayesian_meshnet.py: -------------------------------------------------------------------------------- 1 | """Implementations of Bayesian neural networks.""" 2 | 3 | import tensorflow as tf 4 | import tensorflow_probability as tfp 5 | 6 | from ..bayesian_utils import divergence_fn_bayesian, prior_fn_for_bayesian 7 | from ..layers.dropout import BernoulliDropout, ConcreteDropout 8 | 9 | tfk = tf.keras 10 | tfkl = tfk.layers 11 | tfpl = tfp.layers 12 | tfd = tfp.distributions 13 | 14 | 15 | def variational_meshnet( 16 | n_classes, 17 | input_shape, 18 | receptive_field=67, 19 | filters=71, 20 | no_examples=3000, 21 | is_monte_carlo=False, 22 | dropout=None, 23 | activation=tf.nn.relu, 24 | batch_size=None, 25 | name="variational_meshnet", 26 | ): 27 | """Instantiate variational MeshNet model. 28 | 29 | Please see https://arxiv.org/abs/1805.10863 for Meshnet related information. 30 | 31 | Parameters 32 | ---------- 33 | n_classes: int, number of classes to classify. For binary applications, use 34 | a value of 1. 35 | input_shape: list or tuple of four ints, the shape of the input data. Omit 36 | the batch dimension, and include the number of channels. 37 | receptive_field: {37, 67, 129}, the receptive field of the model. According 38 | to the MeshNet manuscript, the receptive field should be similar to your 39 | input shape. 
The actual receptive field is the cube of the value provided. 40 | filters: int, number of filters per volumetric convolution. The original 41 | MeshNet manuscript uses 21 filters for a binary segmentation task 42 | (i.e., brain extraction) and 71 filters for a multi-class segmentation task. 43 | activation: str or optimizer object, the non-linearity to use. 44 | scale_factor: A tf float 32 variable to scale up the KLD loss. 45 | is_monte_carlo: bool, only Related to dropout version! 46 | dropout: string, type of dropout layer. 47 | batch_size: int, number of samples in each batch. This must be set when 48 | training on TPUs. 49 | name: str, name to give to the resulting model object. 50 | Set priors, divergence and posteriors for training with this model. 51 | 52 | Returns 53 | ------- 54 | Model object. 55 | 56 | Raises 57 | ------ 58 | ValueError if receptive field is not an allowable value. 59 | """ 60 | 61 | if receptive_field not in {37, 67, 129}: 62 | raise ValueError("unknown receptive field. Legal values are 37, 67, and 129.") 63 | 64 | def one_layer(x, layer_num, no_examples=3000, dilation_rate=(1, 1, 1)): 65 | x = tfpl.Convolution3DFlipout( 66 | filters, 67 | kernel_size=3, 68 | padding="same", 69 | dilation_rate=dilation_rate, 70 | kernel_prior_fn=prior_fn_for_bayesian(), 71 | kernel_divergence_fn=divergence_fn_bayesian( 72 | prior_std=1.0, examples_per_epoch=no_examples 73 | ), 74 | name="layer{}/vwnconv3d".format(layer_num), 75 | )(x) 76 | if dropout is None: 77 | pass 78 | elif dropout == "bernoulli": 79 | x = BernoulliDropout( 80 | rate=0.5, 81 | is_monte_carlo=is_monte_carlo, 82 | scale_during_training=False, 83 | name="layer{}/bernoulli_dropout".format(layer_num), 84 | )(x) 85 | elif dropout == "concrete": 86 | x = ConcreteDropout( 87 | is_monte_carlo=is_monte_carlo, 88 | temperature=0.02, 89 | use_expectation=is_monte_carlo, 90 | name="layer{}/concrete_dropout".format(layer_num), 91 | )(x) 92 | else: 93 | raise ValueError("unknown dropout layer, {}".format(dropout)) 94 | x = tfkl.Activation(activation, name="layer{}/activation".format(layer_num))(x) 95 | return x 96 | 97 | inputs = tfkl.Input(shape=input_shape, batch_size=batch_size, name="inputs") 98 | 99 | if receptive_field == 37: 100 | x = one_layer(inputs, 1) 101 | x = one_layer(x, 2) 102 | x = one_layer(x, 3) 103 | x = one_layer(x, 4, dilation_rate=(2, 2, 2)) 104 | x = one_layer(x, 5, dilation_rate=(4, 4, 4)) 105 | x = one_layer(x, 6, dilation_rate=(8, 8, 8)) 106 | x = one_layer(x, 7) 107 | elif receptive_field == 67: 108 | x = one_layer(inputs, 1) 109 | x = one_layer(x, 2) 110 | x = one_layer(x, 3, dilation_rate=(2, 2, 2)) 111 | x = one_layer(x, 4, dilation_rate=(4, 4, 4)) 112 | x = one_layer(x, 5, dilation_rate=(8, 8, 8)) 113 | x = one_layer(x, 6, dilation_rate=(16, 16, 16)) 114 | x = one_layer(x, 7) 115 | elif receptive_field == 129: 116 | x = one_layer(inputs, 1) 117 | x = one_layer(x, 2, dilation_rate=(2, 2, 2)) 118 | x = one_layer(x, 3, dilation_rate=(4, 4, 4)) 119 | x = one_layer(x, 4, dilation_rate=(8, 8, 8)) 120 | x = one_layer(x, 5, dilation_rate=(16, 16, 16)) 121 | x = one_layer(x, 6, dilation_rate=(32, 32, 32)) 122 | x = one_layer(x, 7) 123 | 124 | x = tfpl.Convolution3DFlipout( 125 | filters=n_classes, 126 | kernel_size=1, 127 | padding="same", 128 | name="classification/vwnconv3d", 129 | )(x) 130 | 131 | final_activation = "sigmoid" if n_classes == 1 else "softmax" 132 | x = tfkl.Activation(final_activation, name="segmentation/activation")(x) 133 | 134 | return tf.keras.Model(inputs=inputs, outputs=x, 
name=name) 135 | -------------------------------------------------------------------------------- /nobrainer/models/bayesian_vnet_semi.py: -------------------------------------------------------------------------------- 1 | # Model definition for a Semi-Bayesian VNet with deterministic 2 | # encoder and Bayesian decoder 3 | from tensorflow.keras.layers import ( 4 | Conv3D, 5 | Input, 6 | MaxPooling3D, 7 | UpSampling3D, 8 | concatenate, 9 | ) 10 | from tensorflow.keras.models import Model 11 | import tensorflow_probability as tfp 12 | 13 | from ..bayesian_utils import prior_fn_for_bayesian 14 | from ..layers.groupnorm import GroupNormalization 15 | 16 | tfd = tfp.distributions 17 | 18 | 19 | def down_stage(inputs, filters, kernel_size=3, activation="relu", padding="SAME"): 20 | """encoding blocks of the Semi-Bayesian VNet model. 21 | 22 | Parameters 23 | ---------- 24 | inputs: tf.layer for encoding stage. 25 | filters: list or tuple of four ints, the shape of the input data. Omit 26 | the batch dimension, and include the number of channels. 27 | kernal_size: int, size of the kernel of conv layers. Default kernel size 28 | is set to be 3. 29 | activation: str or optimizer object, the non-linearity to use. All 30 | tf.activations are allowed to use 31 | 32 | Returns 33 | ---------- 34 | encoding module 35 | """ 36 | conv = Conv3D(filters, kernel_size, activation=activation, padding=padding)(inputs) 37 | conv = GroupNormalization()(conv) 38 | conv = Conv3D(filters, kernel_size, activation=activation, padding=padding)(conv) 39 | conv = GroupNormalization()(conv) 40 | pool = MaxPooling3D()(conv) 41 | return conv, pool 42 | 43 | 44 | def up_stage( 45 | inputs, 46 | skip, 47 | filters, 48 | prior_fn, 49 | kernel_posterior_fn, 50 | kld, 51 | kernel_size=3, 52 | activation="relu", 53 | padding="SAME", 54 | ): 55 | """decoding blocks of the Semi-Bayesian VNet model. 56 | 57 | Parameters 58 | ---------- 59 | inputs: tf.layer for encoding stage. 60 | skip: setting skip connections 61 | kld: a func to compute KL Divergence loss, default is set None. 62 | KLD can be set as (lambda q, p, ignore: kl_lib.kl_divergence(q, p)) 63 | prior_fn: a func to initialize priors distributions 64 | kernel_posterior_fn:a func to initlaize kernel posteriors 65 | (loc, scale and weightnorms) 66 | filters: list or tuple of four ints, the shape of the input data. Omit 67 | the batch dimension, and include the number of channels. 68 | kernal_size: int, size of the kernel of conv layers. Default kernel size 69 | is set to be 3. 70 | activation: str or optimizer object, the non-linearity to use. All 71 | tf.activations are allowed to use 72 | 73 | Returns 74 | ---------- 75 | decoding module. 
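    Example (a minimal sketch with toy shapes; it reuses the same default
    prior/posterior helpers that `bayesian_vnet_semi` itself uses, and
    passes kld=None as in the model's defaults):

    >>> import tensorflow_probability as tfp
    >>> from tensorflow.keras.layers import Input
    >>> from nobrainer.bayesian_utils import prior_fn_for_bayesian
    >>> x = Input((16, 16, 16, 32))       # coarse decoder feature map
    >>> skip = Input((32, 32, 32, 16))    # matching encoder skip tensor
    >>> out = up_stage(x, skip, 16, prior_fn_for_bayesian(),
    ...                tfp.layers.default_mean_field_normal_fn(), None)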
76 | """ 77 | up = UpSampling3D()(inputs) 78 | up = tfp.layers.Convolution3DFlipout( 79 | filters, 80 | 2, 81 | activation=activation, 82 | padding=padding, 83 | kernel_divergence_fn=kld, 84 | kernel_posterior_fn=kernel_posterior_fn, 85 | kernel_prior_fn=prior_fn, 86 | )(up) 87 | up = GroupNormalization()(up) 88 | 89 | merge = concatenate([skip, up]) 90 | merge = GroupNormalization()(merge) 91 | 92 | conv = tfp.layers.Convolution3DFlipout( 93 | filters, 94 | kernel_size, 95 | activation=activation, 96 | padding=padding, 97 | kernel_divergence_fn=kld, 98 | kernel_posterior_fn=kernel_posterior_fn, 99 | kernel_prior_fn=prior_fn, 100 | )(merge) 101 | conv = GroupNormalization()(conv) 102 | conv = tfp.layers.Convolution3DFlipout( 103 | filters, 104 | kernel_size, 105 | activation=activation, 106 | padding=padding, 107 | kernel_divergence_fn=kld, 108 | kernel_posterior_fn=kernel_posterior_fn, 109 | kernel_prior_fn=prior_fn, 110 | )(conv) 111 | conv = GroupNormalization()(conv) 112 | 113 | return conv 114 | 115 | 116 | def end_stage( 117 | inputs, 118 | prior_fn, 119 | kernel_posterior_fn, 120 | kld, 121 | n_classes=1, 122 | kernel_size=3, 123 | activation="relu", 124 | padding="SAME", 125 | ): 126 | """last logit layer of Semi-Bayesian VNet. 127 | 128 | Parameters 129 | ---------- 130 | inputs: tf.model layer. 131 | kld: a func to compute KL Divergence loss, default is set None. 132 | KLD can be set as (lambda q, p, ignore: kl_lib.kl_divergence(q, p)) 133 | prior_fn: a func to initialize priors distributions 134 | kernel_posterior_fn:a func to initlaize kernel posteriors 135 | (loc, scale and weightnorms) 136 | n_classes: int, for binary class use the value 1. 137 | kernal_size: int, size of the kernel of conv layers. Default kernel size 138 | is set to be 3. 139 | activation: str or optimizer object, the non-linearity to use. All 140 | tf.activations are allowed to use 141 | 142 | Result 143 | ---------- 144 | prediction probabilities. 145 | """ 146 | conv = tfp.layers.Convolution3DFlipout( 147 | n_classes, 148 | kernel_size, 149 | activation=activation, 150 | padding="SAME", 151 | kernel_divergence_fn=kld, 152 | kernel_posterior_fn=kernel_posterior_fn, 153 | kernel_prior_fn=prior_fn, 154 | )(inputs) 155 | if n_classes == 1: 156 | conv = tfp.layers.Convolution3DFlipout( 157 | n_classes, 158 | 1, 159 | activation="sigmoid", 160 | kernel_divergence_fn=kld, 161 | kernel_posterior_fn=kernel_posterior_fn, 162 | kernel_prior_fn=prior_fn, 163 | )(conv) 164 | else: 165 | conv = tfp.layers.Convolution3DFlipout( 166 | n_classes, 167 | 1, 168 | activation="softmax", 169 | kernel_divergence_fn=kld, 170 | kernel_posterior_fn=kernel_posterior_fn, 171 | kernel_prior_fn=prior_fn, 172 | )(conv) 173 | return conv 174 | 175 | 176 | def bayesian_vnet_semi( 177 | n_classes=1, 178 | input_shape=(256, 256, 256, 1), 179 | kernel_size=3, 180 | prior_fn=prior_fn_for_bayesian(), 181 | kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(), 182 | kld=None, 183 | activation="relu", 184 | padding="SAME", 185 | ): 186 | """Instantiate a 3D Semi-Bayesian VNet Architecture. 187 | 188 | Adapted from Deterministic VNet: https://arxiv.org/pdf/1606.04797.pdf 189 | Encoder has 3D Convolutional layers and Decoder has 3D 190 | Flipout(variational layers). 191 | 192 | Parameters 193 | ---------- 194 | n_classes: int, number of classes to classify. For binary applications, use 195 | a value of 1. 196 | input_shape: list or tuple of four ints, the shape of the input data. 
Omit 197 | the batch dimension, and include the number of channels. 198 | kernal_size(int): size of the kernel of conv layers 199 | activation(str): all tf.keras.activations are allowed 200 | kld: a func to compute KL Divergence loss, default is set None. 201 | KLD can be set as (lambda q, p, ignore: kl_lib.kl_divergence(q, p)) 202 | prior_fn: a func to initialize priors distributions 203 | kernel_posterior_fn:a func to initlaize kernel posteriors 204 | (loc, scale and weightnorms) 205 | See Bayesian Utils for more options for kld, prior_fn and kernal_posterior_fn 206 | activation: str or optimizer object, the non-linearity to use. All 207 | tf.activations are allowed to use. 208 | 209 | Returns 210 | ---------- 211 | Bayesian model object. 212 | """ 213 | inputs = Input(input_shape) 214 | 215 | conv1, pool1 = down_stage( 216 | inputs, 16, kernel_size=kernel_size, activation=activation, padding=padding 217 | ) 218 | conv2, pool2 = down_stage( 219 | pool1, 32, kernel_size=kernel_size, activation=activation, padding=padding 220 | ) 221 | conv3, pool3 = down_stage( 222 | pool2, 64, kernel_size=kernel_size, activation=activation, padding=padding 223 | ) 224 | conv4, _ = down_stage( 225 | pool3, 128, kernel_size=kernel_size, activation=activation, padding=padding 226 | ) 227 | 228 | conv5 = up_stage( 229 | conv4, 230 | conv3, 231 | 64, 232 | prior_fn, 233 | kernel_posterior_fn, 234 | kld, 235 | kernel_size=kernel_size, 236 | activation=activation, 237 | padding=padding, 238 | ) 239 | conv6 = up_stage( 240 | conv5, 241 | conv2, 242 | 32, 243 | prior_fn, 244 | kernel_posterior_fn, 245 | kld, 246 | kernel_size=kernel_size, 247 | activation=activation, 248 | padding=padding, 249 | ) 250 | conv7 = up_stage( 251 | conv6, 252 | conv1, 253 | 16, 254 | prior_fn, 255 | kernel_posterior_fn, 256 | kld, 257 | kernel_size=kernel_size, 258 | activation=activation, 259 | padding=padding, 260 | ) 261 | 262 | conv8 = end_stage( 263 | conv7, 264 | prior_fn, 265 | kernel_posterior_fn, 266 | kld, 267 | n_classes=n_classes, 268 | kernel_size=kernel_size, 269 | activation=activation, 270 | padding=padding, 271 | ) 272 | 273 | return Model(inputs=inputs, outputs=conv8) 274 | -------------------------------------------------------------------------------- /nobrainer/models/brainsiam.py: -------------------------------------------------------------------------------- 1 | """ 2 | SimSiam network architecture for 3D brain volumes 3 | Ref: [Simple Siamese Representation Learning by Chen et al.](https://arxiv.org/abs/2011.10566). 4 | author: Dhritiman Das 5 | """ 6 | 7 | import tensorflow as tf 8 | from tensorflow.keras import layers, regularizers 9 | 10 | import nobrainer 11 | 12 | 13 | def brainsiam( 14 | n_classes, 15 | input_shape, 16 | weight_decay=0.0005, 17 | projection_dim=2048, 18 | latent_dim=512, 19 | name="brainsiam", 20 | **kwargs 21 | ): 22 | """Instantiate Brain Siamese Network model.""" 23 | 24 | """ Parameters 25 | ------------------ 26 | input_shape: list or tuple of four ints, the shape of the input data. Should be 27 | scaled to [0,1]. Omit the batch dimension, and include the number of channels. 28 | Currently, only squares and cubes supported. 29 | projection_dim: int, the dimensions of the encoding of the input data. 30 | latent_dim: int, the dimensions of the latent space of the input data. 31 | n_classes: int, the number of classes in the input data. 32 | weight_decay: float, rate of decay for weights in the the l2 regularizer 33 | 34 | Returns 35 | ------- 36 | 2 Model objects: encoder, predictor. 
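    Example (a minimal sketch with toy 32-voxel cubes; the SimSiam training
    loop that pairs the two networks is not shown here):

    >>> import numpy as np
    >>> encoder, predictor = brainsiam(n_classes=1, input_shape=(32, 32, 32, 1))
    >>> z = encoder(np.random.rand(1, 32, 32, 32, 1).astype("float32"))
    >>> p = predictor(z)  # z and p both have shape (1, projection_dim)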
37 | """ 38 | 39 | print("projection dimension is: ", projection_dim) 40 | print("latent dimension: ", latent_dim) 41 | 42 | def encoder(): 43 | resnet = nobrainer.models.highresnet( 44 | n_classes=n_classes, input_shape=input_shape 45 | ) 46 | 47 | input = tf.keras.layers.Input(shape=input_shape) 48 | 49 | resnet_out = resnet(input) 50 | 51 | x = layers.GlobalAveragePooling3D(name="backbone_pool")(resnet_out) 52 | 53 | x = layers.Dense( 54 | projection_dim, 55 | use_bias=False, 56 | kernel_regularizer=regularizers.l2(weight_decay), 57 | )(x) 58 | x = layers.BatchNormalization()(x) 59 | x = layers.LeakyReLU()(x) 60 | x = layers.Dense( 61 | projection_dim, 62 | use_bias=False, 63 | kernel_regularizer=regularizers.l2(weight_decay), 64 | )(x) 65 | output = layers.BatchNormalization()(x) 66 | 67 | encoder_model = tf.keras.Model(input, output, name="encoder") 68 | return encoder_model 69 | 70 | def predictor(): 71 | predictor_model = tf.keras.Sequential( 72 | [ 73 | # Note the AutoEncoder-like structure. 74 | tf.keras.layers.InputLayer((projection_dim,)), 75 | tf.keras.layers.Dense( 76 | latent_dim, 77 | use_bias=False, 78 | kernel_regularizer=regularizers.l2(weight_decay), 79 | ), 80 | tf.keras.layers.LeakyReLU(), 81 | tf.keras.layers.BatchNormalization(), 82 | tf.keras.layers.Dense(projection_dim), 83 | ], 84 | name="predictor", 85 | ) 86 | 87 | return predictor_model 88 | 89 | encoder_model = encoder() 90 | predictor_model = predictor() 91 | 92 | encoder_model.summary() 93 | predictor_model.summary() 94 | 95 | return encoder_model, predictor_model 96 | -------------------------------------------------------------------------------- /nobrainer/models/dcgan.py: -------------------------------------------------------------------------------- 1 | """Model definition for DCGAN. 2 | """ 3 | 4 | import math 5 | 6 | from tensorflow.keras import layers, models 7 | 8 | 9 | def dcgan( 10 | output_shape, 11 | z_dim=256, 12 | n_base_filters=16, 13 | batchnorm=True, 14 | batch_size=None, 15 | name="dcgan", 16 | ): 17 | """Instantiate DCGAN Architecture. 18 | 19 | Parameters 20 | ---------- 21 | output_shape: list or tuple of four ints, the shape of the output images. Should be 22 | scaled to [0,1]. Omit the batch dimension, and include the number of channels. 23 | Currently, only squares and cubes supported. 24 | z_dim: int, the dimensions of the encoding of the latent code. This would translate 25 | to a latent code of dimensions encoding_dimx1. 26 | n_base_filters: int, number of base filters the models first convolutional layer. 27 | The subsequent layers have n_filters which are multiples of n_base_filters. 28 | batchnorm: bool, whether to use batch normalization in the network. 29 | batch_size: int, number of samples in each batch. This must be set when 30 | training on TPUs. 31 | name: str, name to give to the resulting model object. 32 | 33 | Returns 34 | ------- 35 | Generator Model object. 36 | Discriminator Model object. 
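    Example (a small sketch mirroring the shapes exercised in the unit tests):

    >>> import numpy as np
    >>> generator, discriminator = dcgan(output_shape=(32, 32, 32, 1), z_dim=32)
    >>> z = np.random.random((1, 32))
    >>> fake = generator.predict(z)          # shape (1, 32, 32, 32, 1)
    >>> score = discriminator.predict(fake)  # PatchGAN output, shape (1, 8, 8, 8, 1)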
37 | """ 38 | 39 | conv_kwds = {"kernel_size": 4, "activation": None, "padding": "same", "strides": 2} 40 | 41 | conv_transpose_kwds = { 42 | "kernel_size": 4, 43 | "strides": 2, 44 | "activation": None, 45 | "padding": "same", 46 | } 47 | 48 | dimensions = output_shape[:-1] 49 | n_dims = len(dimensions) 50 | 51 | if not (n_dims in [2, 3] and dimensions[1:] == dimensions[:-1]): 52 | raise ValueError("Dimensions should be of square or cube!") 53 | 54 | Conv = getattr(layers, "Conv{}D".format(n_dims)) 55 | ConvTranspose = getattr(layers, "Conv{}DTranspose".format(n_dims)) 56 | n_layers = int(math.log(dimensions[0], 2)) 57 | 58 | # Generator 59 | z_input = layers.Input(shape=(z_dim,), batch_size=batch_size) 60 | 61 | project = layers.Dense(pow(4, n_dims) * z_dim)(z_input) 62 | project = layers.ReLU()(project) 63 | project = layers.Reshape((4,) * n_dims + (z_dim,))(project) 64 | x = project 65 | 66 | for i in range(n_layers - 2)[::-1]: 67 | n_filters = min(n_base_filters * (2 ** (i)), z_dim) 68 | 69 | x = ConvTranspose(n_filters, **conv_transpose_kwds)(x) 70 | if batchnorm: 71 | x = layers.BatchNormalization()(x) 72 | x = layers.LeakyReLU()(x) 73 | 74 | outputs = Conv(1, 3, activation="sigmoid", padding="same")(x) 75 | 76 | generator = models.Model( 77 | inputs=[z_input], outputs=[outputs], name=name + "_generator" 78 | ) 79 | 80 | # PatchGAN Discriminator with output of 8x8(x8) 81 | inputs = layers.Input(shape=(output_shape), batch_size=batch_size) 82 | x = inputs 83 | for i in range(n_layers - 3): 84 | n_filters = min(n_base_filters * (2 ** (i)), z_dim) 85 | 86 | x = Conv(n_filters, **conv_kwds)(x) 87 | if batchnorm: 88 | x = layers.BatchNormalization()(x) 89 | x = layers.ReLU()(x) 90 | 91 | pred = Conv(1, 3, padding="same", activation="sigmoid")(x) 92 | 93 | discriminator = models.Model( 94 | inputs=[inputs], outputs=[pred], name=name + "_discriminator" 95 | ) 96 | 97 | return generator, discriminator 98 | -------------------------------------------------------------------------------- /nobrainer/models/highresnet.py: -------------------------------------------------------------------------------- 1 | """Model definition for HighResNet. 2 | """ 3 | 4 | import tensorflow as tf 5 | from tensorflow.keras import layers 6 | 7 | from ..layers.padding import ZeroPadding3DChannels 8 | 9 | 10 | def highresnet( 11 | n_classes, input_shape, activation="relu", dropout_rate=0, name="highresnet" 12 | ): 13 | """ 14 | Instantiate a 3D HighResnet Architecture. 15 | Implementation is according to the 16 | https://arxiv.org/abs/1707.01992 17 | Args: 18 | n_classes(int): number of classes 19 | input_shape(tuple):four ints representing the shape of 3D input 20 | activation(str): all tf.keras.activations are allowed 21 | dropout_rate(int): [0,1]. 
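    Example (a minimal sketch with a toy 32-voxel cube instead of a full brain volume):
        model = highresnet(n_classes=1, input_shape=(32, 32, 32, 1))
        model.compile("adam", "binary_crossentropy")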
22 | """ 23 | 24 | conv_kwds = {"kernel_size": (3, 3, 3), "padding": "same"} 25 | 26 | n_base_filters = 16 27 | 28 | inputs = layers.Input(shape=input_shape) 29 | x = layers.Conv3D(n_base_filters, **conv_kwds)(inputs) 30 | 31 | for ii in range(3): 32 | skip = x 33 | x = layers.BatchNormalization()(x) 34 | x = layers.Activation(activation)(x) 35 | x = layers.Conv3D(n_base_filters, **conv_kwds)(x) 36 | x = layers.BatchNormalization()(x) 37 | x = layers.Activation(activation)(x) 38 | x = layers.Conv3D(n_base_filters, **conv_kwds)(x) 39 | x = layers.Add()([x, skip]) 40 | 41 | x = ZeroPadding3DChannels(8)(x) 42 | for ii in range(3): 43 | skip = x 44 | x = layers.BatchNormalization()(x) 45 | x = layers.Activation(activation)(x) 46 | x = layers.Conv3D(n_base_filters * 2, dilation_rate=2, **conv_kwds)(x) 47 | x = layers.BatchNormalization()(x) 48 | x = layers.Activation(activation)(x) 49 | x = layers.Conv3D(n_base_filters * 2, dilation_rate=2, **conv_kwds)(x) 50 | x = layers.Add()([x, skip]) 51 | 52 | x = ZeroPadding3DChannels(16)(x) 53 | for ii in range(3): 54 | skip = x 55 | x = layers.BatchNormalization()(x) 56 | x = layers.Activation(activation)(x) 57 | x = layers.Conv3D(n_base_filters * 4, dilation_rate=4, **conv_kwds)(x) 58 | x = layers.BatchNormalization()(x) 59 | x = layers.Activation(activation)(x) 60 | x = layers.Conv3D(n_base_filters * 4, dilation_rate=4, **conv_kwds)(x) 61 | x = layers.Add()([x, skip]) 62 | 63 | x = layers.Conv3D(filters=n_classes, kernel_size=(1, 1, 1), padding="same")(x) 64 | 65 | final_activation = "sigmoid" if n_classes == 1 else "softmax" 66 | x = layers.Activation(final_activation)(x) 67 | 68 | # QUESTION: where should dropout go? 69 | 70 | return tf.keras.Model(inputs=inputs, outputs=x, name=name) 71 | -------------------------------------------------------------------------------- /nobrainer/models/meshnet.py: -------------------------------------------------------------------------------- 1 | """Model definition for MeshNet. 2 | 3 | Implemented according to the [MeshNet manuscript](https://arxiv.org/abs/1612.00940) 4 | """ 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras import layers 8 | 9 | 10 | def meshnet( 11 | n_classes, 12 | input_shape, 13 | receptive_field=67, 14 | filters=71, 15 | activation="relu", 16 | dropout_rate=0.25, 17 | batch_size=None, 18 | name="meshnet", 19 | ): 20 | """Instantiate MeshNet model. 21 | 22 | Parameters 23 | ---------- 24 | n_classes: int, number of classes to classify. For binary applications, use 25 | a value of 1. 26 | input_shape: list or tuple of four ints, the shape of the input data. Omit 27 | the batch dimension, and include the number of channels. 28 | receptive_field: {37, 67, 129}, the receptive field of the model. According 29 | to the MeshNet manuscript, the receptive field should be similar to your 30 | input shape. The actual receptive field is the cube of the value provided. 31 | filters: int, number of filters per volumetric convolution. The original 32 | MeshNet manuscript uses 21 filters for a binary segmentation task 33 | (i.e., brain extraction) and 71 filters for a multi-class segmentation task. 34 | activation: str or optimizer object, the non-linearity to use. 35 | dropout_rate: float between 0 and 1, the fraction of input units to drop. 36 | batch_size: int, number of samples in each batch. This must be set when 37 | training on TPUs. 38 | name: str, name to give to the resulting model object. 39 | 40 | Returns 41 | ------- 42 | Model object. 
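    Example (a minimal sketch with a toy input cube and the smallest receptive field):

    >>> model = meshnet(n_classes=1, input_shape=(32, 32, 32, 1), receptive_field=37)
    >>> model.compile("adam", "binary_crossentropy")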
43 | 44 | Raises 45 | ------ 46 | ValueError if receptive field is not an allowable value. 47 | """ 48 | 49 | if receptive_field not in {37, 67, 129}: 50 | raise ValueError("unknown receptive field. Legal values are 37, 67, and 129.") 51 | 52 | def one_layer(x, layer_num, dilation_rate=(1, 1, 1)): 53 | x = layers.Conv3D( 54 | filters, 55 | kernel_size=(3, 3, 3), 56 | padding="same", 57 | dilation_rate=dilation_rate, 58 | name="layer{}/conv3d".format(layer_num), 59 | )(x) 60 | x = layers.BatchNormalization(name="layer{}/batchnorm".format(layer_num))(x) 61 | x = layers.Activation(activation, name="layer{}/activation".format(layer_num))( 62 | x 63 | ) 64 | x = layers.Dropout(dropout_rate, name="layer{}/dropout".format(layer_num))(x) 65 | return x 66 | 67 | inputs = layers.Input(shape=input_shape, batch_size=batch_size, name="inputs") 68 | 69 | if receptive_field == 37: 70 | x = one_layer(inputs, 1) 71 | x = one_layer(x, 2) 72 | x = one_layer(x, 3) 73 | x = one_layer(x, 4, dilation_rate=(2, 2, 2)) 74 | x = one_layer(x, 5, dilation_rate=(4, 4, 4)) 75 | x = one_layer(x, 6, dilation_rate=(8, 8, 8)) 76 | x = one_layer(x, 7) 77 | elif receptive_field == 67: 78 | x = one_layer(inputs, 1) 79 | x = one_layer(x, 2) 80 | x = one_layer(x, 3, dilation_rate=(2, 2, 2)) 81 | x = one_layer(x, 4, dilation_rate=(4, 4, 4)) 82 | x = one_layer(x, 5, dilation_rate=(8, 8, 8)) 83 | x = one_layer(x, 6, dilation_rate=(16, 16, 16)) 84 | x = one_layer(x, 7) 85 | elif receptive_field == 129: 86 | x = one_layer(inputs, 1) 87 | x = one_layer(x, 2, dilation_rate=(2, 2, 2)) 88 | x = one_layer(x, 3, dilation_rate=(4, 4, 4)) 89 | x = one_layer(x, 4, dilation_rate=(8, 8, 8)) 90 | x = one_layer(x, 5, dilation_rate=(16, 16, 16)) 91 | x = one_layer(x, 6, dilation_rate=(32, 32, 32)) 92 | x = one_layer(x, 7) 93 | 94 | x = layers.Conv3D( 95 | filters=n_classes, 96 | kernel_size=(1, 1, 1), 97 | padding="same", 98 | name="classification/conv3d", 99 | )(x) 100 | 101 | final_activation = "sigmoid" if n_classes == 1 else "softmax" 102 | x = layers.Activation(final_activation, name="classification/activation")(x) 103 | 104 | return tf.keras.Model(inputs=inputs, outputs=x, name=name) 105 | -------------------------------------------------------------------------------- /nobrainer/models/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/models/tests/__init__.py -------------------------------------------------------------------------------- /nobrainer/models/tests/models_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | 7 | from nobrainer.bayesian_utils import default_mean_field_normal_fn 8 | 9 | from ..attention_unet import attention_unet 10 | from ..attention_unet_with_inception import attention_unet_with_inception 11 | from ..autoencoder import autoencoder 12 | from ..bayesian_meshnet import variational_meshnet 13 | from ..bayesian_vnet import bayesian_vnet 14 | from ..bayesian_vnet_semi import bayesian_vnet_semi 15 | from ..brainsiam import brainsiam 16 | from ..dcgan import dcgan 17 | from ..highresnet import highresnet 18 | from ..meshnet import meshnet 19 | from ..progressivegan import progressivegan 20 | from ..unet import unet 21 | from ..unet_lstm import unet_lstm 22 | from ..unetr import unetr 23 | from ..vnet import vnet 24 | from ..vox2vox import 
Vox_ensembler, vox_gan 25 | 26 | IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" 27 | 28 | 29 | def model_test(model_cls, n_classes, input_shape, kwds={}): 30 | """Tests for models.""" 31 | x = 10 * np.random.random(input_shape) 32 | y = np.random.choice([True, False], input_shape) 33 | 34 | # Assume every model class has n_classes and input_shape arguments. 35 | model = model_cls(n_classes=n_classes, input_shape=input_shape[1:], **kwds) 36 | model.compile(tf.optimizers.Adam(), "binary_crossentropy") 37 | model.fit(x, y) 38 | 39 | actual_output = model.predict(x) 40 | assert actual_output.shape == x.shape[:-1] + (n_classes,) 41 | 42 | 43 | def test_highresnet(): 44 | model_test(highresnet, n_classes=1, input_shape=(1, 32, 32, 32, 1)) 45 | 46 | 47 | def test_meshnet(): 48 | model_test( 49 | meshnet, 50 | n_classes=1, 51 | input_shape=(1, 32, 32, 32, 1), 52 | kwds={"receptive_field": 37}, 53 | ) 54 | model_test( 55 | meshnet, 56 | n_classes=1, 57 | input_shape=(1, 32, 32, 32, 1), 58 | kwds={"receptive_field": 67}, 59 | ) 60 | model_test( 61 | meshnet, 62 | n_classes=1, 63 | input_shape=(1, 32, 32, 32, 1), 64 | kwds={"receptive_field": 129}, 65 | ) 66 | with pytest.raises(ValueError): 67 | model_test( 68 | meshnet, 69 | n_classes=1, 70 | input_shape=(1, 32, 32, 32, 1), 71 | kwds={"receptive_field": 50}, 72 | ) 73 | 74 | 75 | def test_unet(): 76 | model_test(unet, n_classes=1, input_shape=(1, 32, 32, 32, 1)) 77 | 78 | 79 | def test_autoencoder(): 80 | """Special test for autoencoder.""" 81 | 82 | input_shape = (1, 32, 32, 32, 1) 83 | x = 10 * np.random.random(input_shape) 84 | 85 | model = autoencoder(input_shape[1:], encoding_dim=128, n_base_filters=32) 86 | model.compile(tf.optimizers.Adam(), "mse") 87 | model.fit(x, x) 88 | 89 | actual_output = model.predict(x) 90 | assert actual_output.shape == x.shape 91 | 92 | 93 | def test_brainsiam(): 94 | """Testing the encoder-projector and predictor structures of the brainsiam architecture""" 95 | input_shape = (1, 32, 32, 32, 1) 96 | x = 10 * np.random.random(input_shape) 97 | 98 | n_classes = 1 99 | weight_decay = 0.0005 100 | projection_dim = 2048 101 | latent_dim = 512 102 | 103 | encoder, predictor = brainsiam( 104 | n_classes, 105 | input_shape=input_shape[1:], 106 | weight_decay=weight_decay, 107 | projection_dim=projection_dim, 108 | latent_dim=latent_dim, 109 | ) 110 | 111 | encoder_output = encoder(x[1:]) 112 | enc_output_shape = encoder_output.get_shape().as_list() 113 | 114 | predictor_out = predictor(encoder_output) 115 | pred_output_shape = predictor_out.get_shape().as_list() 116 | 117 | assert ( 118 | enc_output_shape[1] == projection_dim 119 | ), "encoder output shape not the same as projection dim" 120 | assert ( 121 | pred_output_shape[1] == projection_dim 122 | ), "predictor output shape not the same as projection dim" 123 | 124 | 125 | def test_progressivegan(): 126 | """Test for both discriminator and generator of progressive gan""" 127 | 128 | latent_size = 256 129 | label_size = 2 130 | g_fmap_base = 1024 131 | d_fmap_base = 1024 132 | alpha = 1.0 133 | 134 | generator, discriminator = progressivegan( 135 | latent_size, 136 | label_size=label_size, 137 | g_fmap_base=g_fmap_base, 138 | d_fmap_base=d_fmap_base, 139 | ) 140 | 141 | resolutions = [8, 16] 142 | 143 | for res in resolutions: 144 | generator.add_resolution() 145 | discriminator.add_resolution() 146 | 147 | latent_input = np.random.random((10, latent_size)) 148 | real_image_input = np.random.random((10, res, res, res, 1)) 149 | 150 | fake_images = 
generator([latent_input, alpha]) 151 | real_pred, real_labels_pred = discriminator([real_image_input, alpha]) 152 | fake_pred, fake_labels_pred = discriminator([fake_images, alpha]) 153 | 154 | assert fake_images.shape == real_image_input.shape 155 | assert real_pred.shape == (real_image_input.shape[0],) 156 | assert fake_pred.shape == (real_image_input.shape[0],) 157 | assert real_labels_pred.shape == (real_image_input.shape[0], label_size) 158 | assert fake_labels_pred.shape == (real_image_input.shape[0], label_size) 159 | 160 | 161 | def test_dcgan(): 162 | """Special test for dcgan.""" 163 | 164 | output_shape = (1, 32, 32, 32, 1) 165 | z_dim = 32 166 | z = np.random.random((1, z_dim)) 167 | 168 | pred_shape = (1, 8, 8, 8, 1) 169 | 170 | generator, discriminator = dcgan(output_shape[1:], z_dim=z_dim) 171 | generator.compile(tf.optimizers.Adam(), "mse") 172 | discriminator.compile(tf.optimizers.Adam(), "mse") 173 | 174 | fake_images = generator.predict(z) 175 | fake_pred = discriminator.predict(fake_images) 176 | 177 | assert fake_images.shape == output_shape and fake_pred.shape == pred_shape 178 | 179 | 180 | def test_vnet(): 181 | model_test(vnet, n_classes=1, input_shape=(1, 32, 32, 32, 1)) 182 | 183 | 184 | def model_test_bayesian(model_cls, n_classes, input_shape, kernel_posterior_fn): 185 | """Tests for models.""" 186 | x = 10 * np.random.random(input_shape) 187 | y = np.random.choice([True, False], input_shape) 188 | 189 | # Assume every model class has n_classes and input_shape arguments. 190 | model = model_cls( 191 | n_classes=n_classes, 192 | input_shape=input_shape[1:], 193 | kernel_posterior_fn=kernel_posterior_fn, 194 | ) 195 | model.compile(tf.optimizers.Adam(), "binary_crossentropy") 196 | model.fit(x, y) 197 | 198 | actual_output = model.predict(x) 199 | assert actual_output.shape == x.shape[:-1] + (n_classes,) 200 | 201 | 202 | def test_bayesian_vnet_semi(): 203 | model_test_bayesian( 204 | bayesian_vnet_semi, 205 | n_classes=1, 206 | input_shape=(1, 32, 32, 32, 1), 207 | kernel_posterior_fn=default_mean_field_normal_fn(weightnorm=True), 208 | ) 209 | 210 | 211 | def test_bayesian_vnet(): 212 | model_test_bayesian( 213 | bayesian_vnet, 214 | n_classes=1, 215 | input_shape=(1, 32, 32, 32, 1), 216 | kernel_posterior_fn=default_mean_field_normal_fn(weightnorm=True), 217 | ) 218 | 219 | 220 | def test_unet_lstm(): 221 | input_shape = (1, 32, 32, 32, 32) 222 | n_classes = 1 223 | x = 10 * np.random.random(input_shape) 224 | y = 10 * np.random.random(input_shape) 225 | model = unet_lstm(input_shape=(32, 32, 32, 32, 1), n_classes=1) 226 | actual_output = model.predict(x) 227 | assert actual_output.shape == y.shape[:-1] + (n_classes,) 228 | 229 | 230 | def test_vox2vox(): 231 | input_shape = (1, 32, 32, 32, 1) 232 | n_classes = 1 233 | x = 10 * np.random.random(input_shape) 234 | y = np.random.choice([True, False], input_shape) 235 | 236 | # testing ensembler 237 | model_test(Vox_ensembler, n_classes, input_shape) 238 | 239 | # testing Vox2VoxGan 240 | vox_generator, vox_discriminator = vox_gan(n_classes, input_shape[1:]) 241 | 242 | # testing generator 243 | vox_generator.compile(tf.optimizers.Adam(), "binary_crossentropy") 244 | vox_generator.fit(x, y) 245 | actual_output = vox_generator.predict(x) 246 | assert actual_output.shape == x.shape[:-1] + (n_classes,) 247 | 248 | # testing descriminator 249 | pred_shape = (1, 2, 2, 2, 1) 250 | out = vox_discriminator(inputs=[y, x]) 251 | assert out.shape == pred_shape 252 | 253 | 254 | def test_attention_unet(): 255 | 
model_test(attention_unet, n_classes=1, input_shape=(1, 64, 64, 64, 1)) 256 | 257 | 258 | def test_attention_unet_with_inception(): 259 | model_test( 260 | attention_unet_with_inception, n_classes=1, input_shape=(1, 64, 64, 64, 1) 261 | ) 262 | 263 | 264 | @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Cannot test in GitHub Actions") 265 | def test_unetr(): 266 | model_test(unetr, n_classes=1, input_shape=(1, 96, 96, 96, 1)) 267 | 268 | 269 | def test_variational_meshnet(): 270 | model_test( 271 | variational_meshnet, 272 | n_classes=1, 273 | input_shape=(1, 128, 128, 128, 1), 274 | kwds={"filters": 4}, 275 | ) 276 | -------------------------------------------------------------------------------- /nobrainer/models/unet.py: -------------------------------------------------------------------------------- 1 | """Model definition for UNet. 2 | """ 3 | 4 | import tensorflow as tf 5 | from tensorflow.keras import layers 6 | 7 | 8 | def unet( 9 | n_classes, 10 | input_shape, 11 | activation="relu", 12 | batchnorm=False, 13 | batch_size=None, 14 | name="unet", 15 | ): 16 | """ 17 | Instantiate a 3D UNet Architecture 18 | UNet model: a 3D deep neural network model from 19 | https://arxiv.org/abs/1606.06650 20 | Args: 21 | n_classes(int): number of classes 22 | input_shape(tuple):four ints representing the shape of 3D input 23 | activation(str): all tf.keras.activations are allowed 24 | batch_size(int): batch size. 25 | """ 26 | 27 | conv_kwds = { 28 | "kernel_size": (3, 3, 3), 29 | "activation": None, 30 | "padding": "same", 31 | # 'kernel_regularizer': tf.keras.regularizers.l2(0.001), 32 | } 33 | 34 | conv_transpose_kwds = { 35 | "kernel_size": (2, 2, 2), 36 | "strides": 2, 37 | "padding": "same", 38 | # 'kernel_regularizer': tf.keras.regularizers.l2(0.001), 39 | } 40 | 41 | n_base_filters = 16 42 | inputs = layers.Input(shape=input_shape, batch_size=batch_size) 43 | 44 | # Begin analysis path (encoder). 45 | 46 | x = layers.Conv3D(n_base_filters, **conv_kwds)(inputs) 47 | if batchnorm: 48 | x = layers.BatchNormalization()(x) 49 | x = layers.Activation(activation)(x) 50 | x = layers.Conv3D(n_base_filters * 2, **conv_kwds)(x) 51 | if batchnorm: 52 | x = layers.BatchNormalization()(x) 53 | skip_1 = x = layers.Activation(activation)(x) 54 | x = layers.MaxPool3D(2)(x) 55 | 56 | x = layers.Conv3D(n_base_filters * 2, **conv_kwds)(x) 57 | if batchnorm: 58 | x = layers.BatchNormalization()(x) 59 | x = layers.Activation(activation)(x) 60 | x = layers.Conv3D(n_base_filters * 4, **conv_kwds)(x) 61 | if batchnorm: 62 | x = layers.BatchNormalization()(x) 63 | skip_2 = x = layers.Activation(activation)(x) 64 | x = layers.MaxPool3D(2)(x) 65 | 66 | x = layers.Conv3D(n_base_filters * 4, **conv_kwds)(x) 67 | if batchnorm: 68 | x = layers.BatchNormalization()(x) 69 | x = layers.Activation(activation)(x) 70 | x = layers.Conv3D(n_base_filters * 8, **conv_kwds)(x) 71 | if batchnorm: 72 | x = layers.BatchNormalization()(x) 73 | skip_3 = x = layers.Activation(activation)(x) 74 | x = layers.MaxPool3D(2)(x) 75 | 76 | x = layers.Conv3D(n_base_filters * 8, **conv_kwds)(x) 77 | if batchnorm: 78 | x = layers.BatchNormalization()(x) 79 | x = layers.Activation(activation)(x) 80 | x = layers.Conv3D(n_base_filters * 16, **conv_kwds)(x) 81 | if batchnorm: 82 | x = layers.BatchNormalization()(x) 83 | x = layers.Activation(activation)(x) 84 | 85 | # End analysis path (encoder). 86 | # Begin synthesis path (decoder). 
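    # Each decoder stage below mirrors an encoder stage: a strided
    # Conv3DTranspose doubles the spatial resolution, the matching skip
    # connection (skip_3, then skip_2, then skip_1) is concatenated on the
    # channel axis, and two Conv3D blocks refine the merged features while
    # the filter count halves at each stage.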
87 | 88 | x = layers.Conv3DTranspose(n_base_filters * 16, **conv_transpose_kwds)(x) 89 | 90 | x = layers.Concatenate(axis=-1)([skip_3, x]) 91 | x = layers.Conv3D(n_base_filters * 8, **conv_kwds)(x) 92 | if batchnorm: 93 | x = layers.BatchNormalization()(x) 94 | x = layers.Activation(activation)(x) 95 | x = layers.Conv3D(n_base_filters * 8, **conv_kwds)(x) 96 | if batchnorm: 97 | x = layers.BatchNormalization()(x) 98 | x = layers.Activation(activation)(x) 99 | 100 | x = layers.Conv3DTranspose(n_base_filters * 8, **conv_transpose_kwds)(x) 101 | 102 | x = layers.Concatenate(axis=-1)([skip_2, x]) 103 | x = layers.Conv3D(n_base_filters * 4, **conv_kwds)(x) 104 | if batchnorm: 105 | x = layers.BatchNormalization()(x) 106 | x = layers.Activation(activation)(x) 107 | x = layers.Conv3D(n_base_filters * 4, **conv_kwds)(x) 108 | if batchnorm: 109 | x = layers.BatchNormalization()(x) 110 | x = layers.Activation(activation)(x) 111 | 112 | x = layers.Conv3DTranspose(n_base_filters * 4, **conv_transpose_kwds)(x) 113 | 114 | x = layers.Concatenate(axis=-1)([skip_1, x]) 115 | x = layers.Conv3D(n_base_filters * 2, **conv_kwds)(x) 116 | if batchnorm: 117 | x = layers.BatchNormalization()(x) 118 | x = layers.Activation(activation)(x) 119 | x = layers.Conv3D(n_base_filters * 2, **conv_kwds)(x) 120 | if batchnorm: 121 | x = layers.BatchNormalization()(x) 122 | x = layers.Activation(activation)(x) 123 | 124 | x = layers.Conv3D(filters=n_classes, kernel_size=1)(x) 125 | 126 | final_activation = "sigmoid" if n_classes == 1 else "softmax" 127 | x = layers.Activation(final_activation)(x) 128 | 129 | return tf.keras.Model(inputs=inputs, outputs=x, name=name) 130 | -------------------------------------------------------------------------------- /nobrainer/models/vnet.py: -------------------------------------------------------------------------------- 1 | # Adaptation of the VNet model from https://arxiv.org/pdf/1606.04797.pdf 2 | # This 3D deep neural network model is regularized with 3D spatial dropout 3 | # and Group normalization. 4 | 5 | from tensorflow.keras.layers import ( 6 | Conv3D, 7 | Input, 8 | MaxPooling3D, 9 | SpatialDropout3D, 10 | UpSampling3D, 11 | concatenate, 12 | ) 13 | from tensorflow.keras.models import Model 14 | 15 | from ..layers.groupnorm import GroupNormalization 16 | 17 | 18 | def down_stage(inputs, filters, kernel_size=3, activation="relu", padding="SAME"): 19 | """encoding block of the VNet model. 20 | 21 | Parameters 22 | ---------- 23 | inputs: tf.layer for encoding stage. 24 | filters: list or tuple of four ints, the shape of the input data. Omit 25 | the batch dimension, and include the number of channels. 26 | kernal_size: int, size of the kernel of conv layers. Default kernel size 27 | is set to be 3. 28 | activation: str or optimizer object, the non-linearity to use. All 29 | tf.activations are allowed to use 30 | 31 | Returns 32 | ---------- 33 | encoding module. 34 | """ 35 | convd = Conv3D(filters, kernel_size, activation=activation, padding=padding)(inputs) 36 | convd = GroupNormalization()(convd) 37 | convd = Conv3D(filters, kernel_size, activation=activation, padding=padding)(convd) 38 | convd = GroupNormalization()(convd) 39 | pool = MaxPooling3D()(convd) 40 | return convd, pool 41 | 42 | 43 | def up_stage(inputs, skip, filters, kernel_size=3, activation="relu", padding="SAME"): 44 | """decoding block of the VNet model. 45 | 46 | Parameters 47 | ---------- 48 | inputs: tf.layer for encoding stage. 49 | filters: list or tuple of four ints, the shape of the input data. 
Omit 50 | the batch dimension, and include the number of channels. 51 | kernal_size: int, size of the kernel of conv layers. Default kernel size 52 | is set to be 3. 53 | activation: str or optimizer object, the non-linearity to use. All 54 | tf.activations are allowed to use 55 | 56 | Returns 57 | ---------- 58 | decoded module. 59 | """ 60 | up = UpSampling3D()(inputs) 61 | up = Conv3D(filters, 2, activation=activation, padding=padding)(up) 62 | up = GroupNormalization()(up) 63 | 64 | merge = concatenate([skip, up]) 65 | merge = GroupNormalization()(merge) 66 | 67 | convu = Conv3D(filters, kernel_size, activation=activation, padding=padding)(merge) 68 | convu = GroupNormalization()(convu) 69 | convu = Conv3D(filters, kernel_size, activation=activation, padding=padding)(convu) 70 | convu = GroupNormalization()(convu) 71 | convu = SpatialDropout3D(0.5)(convu, training=True) 72 | 73 | return convu 74 | 75 | 76 | def end_stage(inputs, n_classes=1, kernel_size=3, activation="relu", padding="SAME"): 77 | """last logit layer. 78 | 79 | Parameters 80 | ---------- 81 | inputs: tf.model layer. 82 | n_classes: int, for binary class use the value 1. 83 | kernal_size: int, size of the kernel of conv layers. Default kernel size 84 | is set to be 3. 85 | activation: str or optimizer object, the non-linearity to use. All 86 | tf.activations are allowed to use 87 | 88 | Result 89 | ---------- 90 | prediction probabilities 91 | """ 92 | conv = Conv3D( 93 | filters=n_classes, 94 | kernel_size=kernel_size, 95 | activation=activation, 96 | padding="SAME", 97 | )(inputs) 98 | if n_classes == 1: 99 | conv = Conv3D(n_classes, 1, activation="sigmoid")(conv) 100 | else: 101 | conv = Conv3D(n_classes, 1, activation="softmax")(conv) 102 | 103 | return conv 104 | 105 | 106 | def vnet( 107 | n_classes=1, 108 | input_shape=(128, 128, 128, 1), 109 | kernel_size=3, 110 | activation="relu", 111 | padding="SAME", 112 | **kwargs 113 | ): 114 | """Instantiate a 3D VNet Architecture. 115 | 116 | VNet model: a 3D deep neural network model adapted from 117 | https://arxiv.org/pdf/1606.04797.pdf adatptations include groupnorm 118 | and spatial dropout. 119 | 120 | Parameters 121 | ---------- 122 | n_classes: int, number of classes to classify. For binary applications, use 123 | a value of 1. 124 | input_shape: list or tuple of four ints, the shape of the input data. Omit 125 | the batch dimension, and include the number of channels. 126 | kernal_size: int, size of the kernel of conv layers. Default kernel size 127 | is set to be 3. 128 | activation: str or optimizer object, the non-linearity to use. All 129 | tf.activations are allowed to use 130 | 131 | Returns 132 | ---------- 133 | Model object. 
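    Example (a minimal sketch; a 32-voxel cube instead of the 128-voxel default):

    >>> model = vnet(n_classes=1, input_shape=(32, 32, 32, 1))
    >>> model.compile("adam", "binary_crossentropy")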
134 | 135 | """ 136 | inputs = Input(input_shape) 137 | 138 | conv1, pool1 = down_stage( 139 | inputs, 16, kernel_size=kernel_size, activation=activation, padding=padding 140 | ) 141 | conv2, pool2 = down_stage( 142 | pool1, 32, kernel_size=kernel_size, activation=activation, padding=padding 143 | ) 144 | conv3, pool3 = down_stage( 145 | pool2, 64, kernel_size=kernel_size, activation=activation, padding=padding 146 | ) 147 | conv4, _ = down_stage( 148 | pool3, 128, kernel_size=kernel_size, activation=activation, padding=padding 149 | ) 150 | conv4 = SpatialDropout3D(0.5)(conv4, training=True) 151 | 152 | conv5 = up_stage( 153 | conv4, 154 | conv3, 155 | 64, 156 | kernel_size=kernel_size, 157 | activation=activation, 158 | padding=padding, 159 | ) 160 | conv6 = up_stage( 161 | conv5, 162 | conv2, 163 | 32, 164 | kernel_size=kernel_size, 165 | activation=activation, 166 | padding=padding, 167 | ) 168 | conv7 = up_stage( 169 | conv6, 170 | conv1, 171 | 16, 172 | kernel_size=kernel_size, 173 | activation=activation, 174 | padding=padding, 175 | ) 176 | 177 | conv8 = end_stage( 178 | conv7, 179 | n_classes=n_classes, 180 | kernel_size=kernel_size, 181 | activation=activation, 182 | padding=padding, 183 | ) 184 | 185 | return Model(inputs=inputs, outputs=conv8) 186 | -------------------------------------------------------------------------------- /nobrainer/processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/processing/__init__.py -------------------------------------------------------------------------------- /nobrainer/processing/base.py: -------------------------------------------------------------------------------- 1 | """Base classes for all estimators.""" 2 | 3 | import inspect 4 | import os 5 | from pathlib import Path 6 | import pickle as pk 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def get_strategy(multi_gpu): 12 | if multi_gpu: 13 | return tf.distribute.MirroredStrategy() 14 | return tf.distribute.get_strategy() 15 | 16 | 17 | class BaseEstimator: 18 | """Base class for all high-level models in Nobrainer.""" 19 | 20 | state_variables = [] 21 | model_ = None 22 | 23 | def __init__(self, checkpoint_filepath=None, multi_gpu=True): 24 | self.checkpoint_tracker = None 25 | if checkpoint_filepath: 26 | from .checkpoint import CheckpointTracker 27 | 28 | self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath) 29 | 30 | self.strategy = get_strategy(multi_gpu) 31 | 32 | @property 33 | def model(self): 34 | return self.model_ 35 | 36 | def save(self, save_dir): 37 | """Saves a trained model""" 38 | if self.model_ is None: 39 | raise ValueError("Model is undefined. Please train or load a model") 40 | self.model_.save(save_dir) 41 | model_info = {"classname": self.__class__.__name__, "__init__": {}} 42 | for key in inspect.signature(self.__init__).parameters: 43 | # TODO this assumes that all parameters passed to __init__ 44 | # are stored as members, which doesn't leave room for 45 | # parameters that are specific to the runtime context. 46 | # (e.g. multi_gpu). 
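        # Record the constructor arguments so that `load` can re-instantiate
        # this estimator class later; runtime-only settings (multi_gpu,
        # checkpoint_filepath) are skipped below.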
47 | if key == "multi_gpu" or key == "checkpoint_filepath": 48 | continue 49 | model_info["__init__"][key] = getattr(self, key) 50 | for val in self.state_variables: 51 | model_info[val] = getattr(self, val) 52 | model_file = Path(save_dir) / "model_params.pkl" 53 | with open(model_file, "wb") as fp: 54 | pk.dump(model_info, fp) 55 | 56 | @classmethod 57 | def load( 58 | cls, 59 | model_dir, 60 | multi_gpu=True, 61 | custom_objects=None, 62 | compile=False, 63 | ): 64 | """Loads a trained model from a save directory""" 65 | model_dir = Path(str(model_dir).rstrip(os.pathsep)) 66 | assert model_dir.exists() and model_dir.is_dir() 67 | model_file = model_dir / "model_params.pkl" 68 | with open(model_file, "rb") as fp: 69 | model_info = pk.load(fp) 70 | if model_info["classname"] != cls.__name__: 71 | raise ValueError(f"Model class does not match {cls.__name__}") 72 | del model_info["classname"] 73 | 74 | klass = cls(**model_info["__init__"]) 75 | del model_info["__init__"] 76 | for key, value in model_info.items(): 77 | setattr(klass, key, value) 78 | 79 | klass.strategy = get_strategy(multi_gpu) 80 | with klass.strategy.scope(): 81 | klass.model_ = tf.keras.models.load_model( 82 | model_dir, 83 | custom_objects=custom_objects, 84 | compile=compile, 85 | ) 86 | return klass 87 | 88 | @classmethod 89 | def init_with_checkpoints( 90 | cls, 91 | model_name, 92 | checkpoint_filepath, 93 | multi_gpu=True, 94 | custom_objects=None, 95 | compile=False, 96 | model_args=None, 97 | ): 98 | """Initialize a model for training, either from the latest 99 | checkpoint found, or from scratch if no checkpoints are 100 | found. This is useful for long-running model fits that may be 101 | interrupted or preepmted during training and need to pick up 102 | where they left off. 103 | 104 | model_name: str or Module in nobrainer.models, the base model 105 | for this estimator. 106 | 107 | checkpoint_filepath: str, path to which checkpoints will be 108 | saved and loaded. Supports the epoch and block flormating 109 | parameters supported by tensorflows ModelCheckpoint, 110 | e.g. /{epoch:03d} 111 | 112 | """ 113 | from .checkpoint import CheckpointTracker 114 | 115 | checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath) 116 | estimator = checkpoint_tracker.load( 117 | multi_gpu=multi_gpu, 118 | custom_objects=custom_objects, 119 | compile=compile, 120 | ) 121 | if not estimator: 122 | estimator = cls(model_name, model_args=model_args) 123 | estimator.checkpoint_tracker = checkpoint_tracker 124 | checkpoint_tracker.estimator = estimator 125 | return estimator 126 | 127 | 128 | class TransformerMixin: 129 | """Mixin class for all transformers in scikit-learn.""" 130 | 131 | def fit_transform(self, X, y=None, **fit_params): 132 | """ 133 | Fit to data, then transform it. 134 | Fits transformer to `X` and `y` with optional parameters `fit_params` 135 | and returns a transformed version of `X`. 136 | Parameters 137 | ---------- 138 | X : array-like of shape (n_samples, n_features) 139 | Input samples. 140 | y : array-like of shape (n_samples,) or (n_samples, n_outputs), \ 141 | default=None 142 | Target values (None for unsupervised transformations). 143 | **fit_params : dict 144 | Additional fit parameters. 145 | Returns 146 | ------- 147 | X_new : ndarray array of shape (n_samples, n_features_new) 148 | Transformed array. 
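        Example
        -------
        A sketch using a hypothetical `Scaler` transformer; any class that
        defines `fit` and `transform` and mixes in this class will work.

        >>> import numpy as np
        >>> class Scaler(TransformerMixin):
        ...     def fit(self, X, y=None):
        ...         self.max_ = X.max()
        ...         return self
        ...     def transform(self, X):
        ...         return X / self.max_
        >>> Xt = Scaler().fit_transform(np.arange(1.0, 5.0))  # array([0.25, 0.5, 0.75, 1.0])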
149 | """ 150 | # non-optimized default implementation; override when a better 151 | # method is possible for a given clustering algorithm 152 | if y is None: 153 | # fit method of arity 1 (unsupervised transformation) 154 | return self.fit(X, **fit_params).transform(X) 155 | else: 156 | # fit method of arity 2 (supervised transformation) 157 | return self.fit(X, y, **fit_params).transform(X) 158 | -------------------------------------------------------------------------------- /nobrainer/processing/checkpoint.py: -------------------------------------------------------------------------------- 1 | """Checkpointing utils""" 2 | 3 | from glob import glob 4 | import logging 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | 10 | class CheckpointTracker(tf.keras.callbacks.ModelCheckpoint): 11 | """Class for saving/loading estimators at/from checkpoints.""" 12 | 13 | def __init__(self, estimator, file_path, **kwargs): 14 | """ 15 | estimator: BaseEstimator, instance of an estimator (e.g., Segmentation). 16 | file_path: str, directory to/from which to save or load. 17 | """ 18 | self.estimator = estimator 19 | super().__init__(file_path, **kwargs) 20 | 21 | def _save_model(self, epoch, batch, logs): 22 | """Save the current state of the estimator. This overrides the 23 | base class implementation to save `nobrainer` specific info. 24 | 25 | epoch: int, the index of the epoch that just finished. 26 | batch: int, the index of the batch that just finished. 27 | logs: dict, logging info passed into on_epoch_end or on_batch_end. 28 | """ 29 | self.save(self._get_file_path(epoch, batch, logs)) 30 | 31 | def save(self, directory): 32 | """Save the current state of the estimator. 33 | directory: str, path in which to save the model. 34 | """ 35 | logging.info(f"Saving to dir {directory}") 36 | self.estimator.save(directory) 37 | 38 | def load( 39 | self, 40 | multi_gpu=True, 41 | custom_objects=None, 42 | compile=False, 43 | model_args=None, 44 | ): 45 | """Loads the most-recently created checkpoint from the 46 | checkpoint directory. 47 | """ 48 | checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/")) 49 | if not checkpoints: 50 | self.last_epoch = 0 51 | return None 52 | 53 | # TODO, we should probably exclude non-checkpoint files here, 54 | # and maybe parse the filename for the epoch number 55 | self.last_epoch = len(checkpoints) 56 | 57 | latest = max(checkpoints, key=os.path.getctime) 58 | self.estimator = self.estimator.load( 59 | latest, 60 | multi_gpu=multi_gpu, 61 | custom_objects=custom_objects, 62 | compile=compile, 63 | ) 64 | logging.info(f"Loaded estimator from {latest}.") 65 | return self.estimator 66 | -------------------------------------------------------------------------------- /nobrainer/processing/generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | 6 | from .base import BaseEstimator 7 | from .. 
import losses 8 | from ..dataset import Dataset 9 | 10 | 11 | class ProgressiveGeneration(BaseEstimator): 12 | """Perform generation type operations""" 13 | 14 | state_variables = ["current_resolution_"] 15 | 16 | def __init__( 17 | self, 18 | latent_size=256, 19 | label_size=0, 20 | num_channels=1, 21 | dimensionality=3, 22 | g_fmap_base=1024, 23 | d_fmap_base=1024, 24 | multi_gpu=True, 25 | ): 26 | super().__init__(multi_gpu=multi_gpu) 27 | self.model_ = None 28 | self.latent_size = latent_size 29 | self.label_size = label_size 30 | self.g_fmap_base = g_fmap_base 31 | self.d_fmap_base = d_fmap_base 32 | self.num_channels = num_channels 33 | self.dimensionality = dimensionality 34 | self.current_resolution_ = 0 35 | 36 | def fit( 37 | self, 38 | dataset_train, 39 | epochs=2, 40 | checkpoint_dir=Path(os.getcwd()) / "temp", 41 | normalizer=None, 42 | # TODO: figure out whether optimizer args should be flattened 43 | g_optimizer=None, 44 | g_opt_args=None, 45 | d_optimizer=None, 46 | d_opt_args=None, 47 | g_loss=losses.Wasserstein, 48 | d_loss=losses.Wasserstein, 49 | warm_start=False, 50 | num_parallel_calls=None, 51 | save_freq=500, 52 | ): 53 | """Train a progressive gan model""" 54 | # TODO: check validity of datasets 55 | 56 | # create checkpoint sub-dirs 57 | checkpoint_dir = Path(checkpoint_dir) 58 | checkpoint_dir.mkdir(exist_ok=True) 59 | generated_dir = checkpoint_dir / "generated" 60 | model_dir = checkpoint_dir / "saved_models" 61 | log_dir = checkpoint_dir / "logs" 62 | 63 | generated_dir.mkdir(exist_ok=True) 64 | model_dir.mkdir(exist_ok=True) 65 | log_dir.mkdir(exist_ok=True) 66 | 67 | # set optimizers 68 | g_opt_args = g_opt_args or {} 69 | if g_optimizer is None: 70 | g_optimizer = tf.keras.optimizers.Adam 71 | g_opt_args_tmp = dict( 72 | learning_rate=1e-04, beta_1=0.0, beta_2=0.99, epsilon=1e-8 73 | ) 74 | g_opt_args_tmp.update(**g_opt_args) 75 | g_opt_args = g_opt_args_tmp 76 | 77 | d_opt_args = d_opt_args or {} 78 | if d_optimizer is None: 79 | d_optimizer = tf.keras.optimizers.Adam 80 | d_opt_args_tmp = dict( 81 | learning_rate=1e-04, beta_1=0.0, beta_2=0.99, epsilon=1e-8 82 | ) 83 | d_opt_args_tmp.update(**d_opt_args) 84 | d_opt_args = d_opt_args_tmp 85 | 86 | if warm_start: 87 | if self.model_ is None: 88 | raise ValueError("warm_start requested, but model is undefined") 89 | else: 90 | from ..models.progressivegan import progressivegan 91 | from ..training import ProgressiveGANTrainer 92 | 93 | # Instantiate the generator and discriminator 94 | with self.strategy.scope(): 95 | generator, discriminator = progressivegan( 96 | latent_size=self.latent_size, 97 | g_fmap_base=self.g_fmap_base, 98 | d_fmap_base=self.d_fmap_base, 99 | num_channels=self.num_channels, 100 | dimensionality=self.dimensionality, 101 | ) 102 | self.model_ = ProgressiveGANTrainer( 103 | generator=generator, 104 | discriminator=discriminator, 105 | gradient_penalty=True, 106 | ) 107 | self.current_resolution_ = 0 108 | 109 | # wrap the losses to work on multiple GPUs 110 | with self.strategy.scope(): 111 | d_loss_object = d_loss(reduction=tf.keras.losses.Reduction.NONE) 112 | 113 | def compute_d_loss(labels, predictions): 114 | per_example_loss = d_loss_object(labels, predictions) 115 | return tf.nn.compute_average_loss( 116 | per_example_loss, global_batch_size=batch_size 117 | ) 118 | 119 | g_loss_object = g_loss(reduction=tf.keras.losses.Reduction.NONE) 120 | 121 | def compute_g_loss(labels, predictions): 122 | per_example_loss = g_loss_object(labels, predictions) 123 | return 
tf.nn.compute_average_loss( 124 | per_example_loss, global_batch_size=batch_size 125 | ) 126 | 127 | d_loss = compute_d_loss 128 | g_loss = compute_g_loss 129 | 130 | # instantiate a progressive training helper and compile with loss and optimizer 131 | def _compile(): 132 | self.model_.compile( 133 | g_optimizer=g_optimizer(**g_opt_args), 134 | d_optimizer=d_optimizer(**d_opt_args), 135 | g_loss_fn=g_loss, 136 | d_loss_fn=d_loss, 137 | ) 138 | 139 | print(self.model_.generator.summary()) 140 | print(self.model_.discriminator.summary()) 141 | 142 | for resolution, info in dataset_train.items(): 143 | if resolution < self.current_resolution_: 144 | continue 145 | # create a train dataset with features for resolution 146 | batch_size = info.get("batch_size") 147 | if batch_size % self.strategy.num_replicas_in_sync: 148 | raise ValueError("batch size must be a multiple of the number of GPUs") 149 | 150 | dataset = Dataset.from_tfrecords( 151 | file_pattern=info.get("file_pattern"), 152 | num_parallel_calls=num_parallel_calls, 153 | volume_shape=(resolution, resolution, resolution), 154 | n_classes=1, 155 | scalar_labels=True, 156 | ) 157 | n_epochs = info.get("epochs") or epochs 158 | dataset.batch(batch_size).normalize( 159 | info.get("normalizer") or normalizer 160 | ).repeat(n_epochs) 161 | 162 | with self.strategy.scope(): 163 | # grow the networks by one (2^x) resolution 164 | if resolution > self.current_resolution_: 165 | self.model_.generator.add_resolution() 166 | self.model_.discriminator.add_resolution() 167 | _compile() 168 | 169 | steps_per_epoch = n_epochs // info.get("batch_size") 170 | 171 | # save_best_only is set to False as it is an adversarial loss 172 | model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 173 | str(model_dir), 174 | save_weights_only=True, 175 | save_best_only=False, 176 | save_freq=save_freq, 177 | verbose=False, 178 | ) 179 | 180 | # Train at resolution 181 | print("Resolution : {}".format(resolution)) 182 | 183 | print("Transition phase") 184 | self.model_.fit( 185 | dataset.dataset, 186 | phase="transition", 187 | resolution=resolution, 188 | steps_per_epoch=steps_per_epoch, # necessary for repeat dataset 189 | callbacks=[model_checkpoint_callback], 190 | ) 191 | 192 | print("Resolution phase") 193 | self.model_.fit( 194 | dataset.dataset, 195 | phase="resolution", 196 | resolution=resolution, 197 | steps_per_epoch=steps_per_epoch, 198 | callbacks=[model_checkpoint_callback], 199 | ) 200 | self.current_resolution_ = resolution 201 | # save the final weights 202 | self.model_.save_weights(model_dir) 203 | return self 204 | 205 | def generate(self, n_images=1, return_latents=False, data_type=None): 206 | """generate a synthetic image using the trained model""" 207 | if self.model_ is None: 208 | raise ValueError("Model is undefined. 
Please train or load a model") 209 | import nibabel as nib 210 | import numpy as np 211 | 212 | latents_all = [] 213 | img_all = [] 214 | for i in range(n_images): 215 | latents = tf.random.normal((1, self.latent_size)) 216 | img = self.model_.generator.generate(latents)["generated"] 217 | img = np.squeeze(img) 218 | if data_type is not None: 219 | img = np.round( 220 | np.iinfo(data_type).max 221 | * (img - img.min()) 222 | / (img.max() - img.min()) 223 | ).astype(data_type) 224 | img = nib.Nifti1Image(img, np.eye(4)) 225 | latents_all.append(latents) 226 | img_all.append(img) 227 | if return_latents: 228 | return img_all, latents_all 229 | else: 230 | return img_all 231 | -------------------------------------------------------------------------------- /nobrainer/processing/segmentation.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | 4 | import tensorflow as tf 5 | 6 | from .base import BaseEstimator 7 | from .. import losses, metrics 8 | from ..models import available_models, list_available_models 9 | 10 | logging.getLogger().setLevel(logging.INFO) 11 | 12 | 13 | class Segmentation(BaseEstimator): 14 | """Perform segmentation type operations""" 15 | 16 | state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"] 17 | 18 | def __init__( 19 | self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=True 20 | ): 21 | super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu) 22 | 23 | if not isinstance(base_model, str): 24 | self.base_model = base_model.__name__ 25 | else: 26 | self.base_model = base_model 27 | 28 | if self.base_model and self.base_model not in available_models(): 29 | raise ValueError( 30 | "Unknown model: '{}'. 
Available models are {}.".format( 31 | self.base_model, available_models() 32 | ) 33 | ) 34 | 35 | self.model_ = None 36 | self.model_args = model_args or {} 37 | self.block_shape_ = None 38 | self.volume_shape_ = None 39 | self.scalar_labels_ = None 40 | 41 | def add_model(self, base_model, model_args=None): 42 | """Add a segmentation model""" 43 | self.base_model = base_model 44 | self.model_args = model_args or {} 45 | 46 | def fit( 47 | self, 48 | dataset_train, 49 | dataset_validate=None, 50 | epochs=1, 51 | # TODO: figure out whether optimizer args should be flattened 52 | optimizer=None, 53 | opt_args=None, 54 | loss=losses.dice, 55 | metrics=metrics.dice, 56 | callbacks=None, 57 | verbose=1, 58 | ): 59 | """Train a segmentation model""" 60 | # TODO: check validity of datasets 61 | 62 | batch_size = dataset_train.batch_size 63 | self.block_shape_ = dataset_train.block_shape 64 | self.volume_shape_ = dataset_train.volume_shape 65 | self.scalar_labels_ = dataset_train.scalar_labels 66 | n_classes = dataset_train.n_classes 67 | opt_args = opt_args or {} 68 | if optimizer is None: 69 | optimizer = tf.keras.optimizers.Adam 70 | opt_args_tmp = dict(learning_rate=1e-04) 71 | opt_args_tmp.update(**opt_args) 72 | opt_args = opt_args_tmp 73 | 74 | def _create(base_model): 75 | # Instantiate and compile the model 76 | self.model_ = base_model( 77 | n_classes=n_classes, 78 | input_shape=(*self.block_shape_, 1), 79 | **self.model_args 80 | ) 81 | 82 | def _compile(): 83 | self.model_.compile( 84 | optimizer(**opt_args), 85 | loss=loss, 86 | metrics=metrics, 87 | ) 88 | 89 | if self.model is None: 90 | mod = importlib.import_module("..models", "nobrainer.processing") 91 | base_model = getattr(mod, self.base_model) 92 | if batch_size % self.strategy.num_replicas_in_sync: 93 | raise ValueError("batch size must be a multiple of the number of GPUs") 94 | 95 | with self.strategy.scope(): 96 | _create(base_model) 97 | with self.strategy.scope(): 98 | _compile() 99 | self.model_.summary() 100 | 101 | if callbacks is not None and not isinstance(callbacks, list): 102 | raise AttributeError("Callbacks must be either of type list or None") 103 | 104 | if callbacks is None: 105 | callbacks = [] 106 | 107 | if self.checkpoint_tracker: 108 | callbacks.append(self.checkpoint_tracker) 109 | self.model_.fit( 110 | dataset_train.dataset, 111 | epochs=epochs, 112 | steps_per_epoch=dataset_train.get_steps_per_epoch(), 113 | validation_data=dataset_validate.dataset if dataset_validate else None, 114 | validation_steps=( 115 | dataset_validate.get_steps_per_epoch() if dataset_validate else None 116 | ), 117 | callbacks=callbacks, 118 | verbose=verbose, 119 | ) 120 | 121 | return self 122 | 123 | def predict(self, x, batch_size=1, normalizer=None): 124 | """Makes a prediction using the trained model""" 125 | if self.model_ is None: 126 | raise ValueError("Model is undefined. 
Please train or load a model") 127 | from ..prediction import predict 128 | 129 | return predict( 130 | x, 131 | self.model_, 132 | block_shape=self.block_shape_, 133 | batch_size=batch_size, 134 | normalizer=normalizer, 135 | ) 136 | 137 | @classmethod 138 | def list_available_models(cls): 139 | list_available_models() 140 | -------------------------------------------------------------------------------- /nobrainer/spatial_transforms.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def centercrop(x, y=None, finesize=64, trans_xy=False): 5 | """Apply center crop to input and label. 6 | 7 | Usage: 8 | ```python 9 | >>> x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]) 10 | >>> finesize = 1 11 | >>> x_out = spatial_transforms.centercrop(x, finesize=finesize) 12 | >>> x_out 13 | 15 | ``` 16 | 17 | Parameters 18 | ---------- 19 | x: input is a tensor or numpy to have rank 3, 20 | y: label is a tensor or numpy to have rank 3, 21 | finesize: int, desired output size of the crop. Default = 64; 22 | trans_xy: Boolean, transforms both x and y (Default: False). 23 | If set True, function will require both x,y. 24 | 25 | Returns 26 | ---------- 27 | CenterCroped input and/or label tensor. 28 | """ 29 | if ~tf.is_tensor(x): 30 | x = tf.convert_to_tensor(x) 31 | x = tf.cast(x, tf.float32) 32 | if len(x.shape) != 3: 33 | raise ValueError("`volume` must be rank 3") 34 | w, h = x.shape[1], x.shape[0] 35 | th, tw = finesize, finesize 36 | x1 = int(round((w - tw) / 2.0)) 37 | y1 = int(round((h - th) / 2.0)) 38 | x = x[y1 : y1 + th, x1 : x1 + tw, :] 39 | 40 | if trans_xy: 41 | if y is None: 42 | raise ValueError("`LabelMap' should be assigned") 43 | if len(y.shape) != 3: 44 | raise ValueError("`LabelMap` must be equal or higher than rank 2") 45 | if ~tf.is_tensor(y): 46 | y = tf.convert_to_tensor(y) 47 | y = tf.cast(y, tf.float32) 48 | y = y[y1 : y1 + th, x1 : x1 + tw, :] 49 | if y is None: 50 | return x 51 | return x, y 52 | 53 | 54 | def spatialConstantPadding(x, y=None, trans_xy=False, padding_zyx=[1, 1, 1]): 55 | """Add constant padding to input and label. 56 | 57 | Usage: 58 | ```python 59 | >>> x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]) 60 | >>> x_out = spatial_transforms.spatialConstantPadding( 61 | x,padding_zyx=[0, 1, 1]) 62 | >>> x_out 63 | 69 | ``` 70 | 71 | Parameters 72 | ---------- 73 | x: input is a tensor or numpy to have rank 3, 74 | y: label is a tensor or numpy to have rank 3, 75 | padding_zyx: int or a list of desired padding in three dimensions. 76 | Default = 1; 77 | trans_xy: Boolean, transforms both x and y (Default: False). 78 | If set True, function will require both x,y. 79 | 80 | Returns 81 | ---------- 82 | Input and/or label tensor with spatial padding. 
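    A minimal shape-only example (an illustrative sketch assuming a NumPy input,
    as in the usage above): with padding_zyx=[0, 2, 2], two voxels of zero padding
    are added on both sides of the last two axes, so a (2, 3, 3) volume becomes
    (2, 7, 7).

    ```python
    >>> x = np.ones((2, 3, 3))
    >>> spatial_transforms.spatialConstantPadding(x, padding_zyx=[0, 2, 2]).shape
    TensorShape([2, 7, 7])
    ```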
83 | """ 84 | if ~tf.is_tensor(x): 85 | x = tf.convert_to_tensor(x) 86 | x = tf.cast(x, tf.float32) 87 | padz = padding_zyx[0] 88 | pady = padding_zyx[1] 89 | padx = padding_zyx[2] 90 | padding = tf.constant([[padz, padz], [pady, pady], [padx, padx]]) 91 | x = tf.pad(x, padding, "CONSTANT") 92 | if trans_xy: 93 | if y is None: 94 | raise ValueError("`LabelMap' should be assigned") 95 | if len(y.shape) != 3: 96 | raise ValueError("`LabelMap` must be equal or higher than rank 2") 97 | if ~tf.is_tensor(y): 98 | y = tf.convert_to_tensor(y) 99 | y = tf.cast(y, tf.float32) 100 | y = tf.pad(y, padding, "CONSTANT") 101 | if y is None: 102 | return x 103 | return x, y 104 | 105 | 106 | def randomCrop(x, y=None, trans_xy=False, cropsize=16): 107 | """Apply random crops to input and label. 108 | 109 | Usage: 110 | ```python 111 | >>> x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]) 112 | >>> x_out = spatial_transforms.randomCrop(x, cropsize=1) 113 | >>> x_out 114 | 116 | ``` 117 | 118 | Parameters 119 | ---------- 120 | x: input is a tensor or numpy to have rank 3, 121 | y: label is a tensor or numpy to have rank 3, 122 | cropsize: int, the size of the cropped output, 123 | finesize: int, desired output size of the crop. Default = 64; 124 | trans_xy: Boolean, transforms both x and y (Default: False). 125 | If set True, function will require both x,y. 126 | 127 | Returns 128 | ---------- 129 | Randomly croped input and/or label tensor. 130 | """ 131 | if ~tf.is_tensor(x): 132 | x = tf.convert_to_tensor(x) 133 | x = tf.cast(x, tf.float32) 134 | if trans_xy: 135 | if y is None: 136 | raise ValueError("`LabelMap' should be assigned") 137 | if len(y.shape) != 3: 138 | raise ValueError("`LabelMap` must be equal or higher than rank 2") 139 | if ~tf.is_tensor(y): 140 | y = tf.convert_to_tensor(y) 141 | y = tf.cast(y, tf.float32) 142 | stacked = tf.stack([x, y], axis=0) 143 | cropped = tf.image.random_crop(stacked, [2, cropsize, cropsize, x.shape[2]]) 144 | return cropped[0], cropped[1] 145 | if y is None: 146 | return tf.image.random_crop(x, [cropsize, cropsize, x.shape[2]]) 147 | return tf.image.random_crop(x, [cropsize, cropsize, x.shape[2]]), y 148 | 149 | 150 | def resize(x, y=None, trans_xy=False, size=[32, 32], mode="bicubic"): 151 | """Resize the input and label. 152 | 153 | Usage: 154 | ```python 155 | >>> x = np.array([[[1, 2, 3], [6, 2, 5], [3, 4, 9]]]) 156 | >>> x_out = spatial_transforms.resize(x,size=[2, 2]) 157 | >>> x_out 158 | 164 | ``` 165 | 166 | Parameters 167 | ---------- 168 | x: input is a tensor or numpy to have rank 3, 169 | y: label is a tensor or numpy to have rank 3, 170 | size: int or a list, the resize output, 171 | trans_xy: Boolean, transforms both x and y (Default: False). 172 | If set True, function will require both x,y. 173 | mode [options]: "bilinear", "lanczos3", 174 | "lanczos5", "bicubic", "gaussian" , "nearest". 175 | 176 | Returns 177 | ---------- 178 | Resized input and/or label tensor. 
179 | """ 180 | if ~tf.is_tensor(x): 181 | x = tf.convert_to_tensor(x) 182 | x = tf.cast(x, tf.float32) 183 | x = tf.image.resize(x, size, method=mode) 184 | if trans_xy: 185 | if y is None: 186 | raise ValueError("`LabelMap' should be assigned") 187 | if len(y.shape) != 3: 188 | raise ValueError("`LabelMap` must be equal or higher than rank 2") 189 | if ~tf.is_tensor(y): 190 | y = tf.convert_to_tensor(y) 191 | y = tf.cast(y, tf.float32) 192 | y = tf.image.resize(y, size, method=mode) 193 | if y is None: 194 | return x 195 | return x, y 196 | 197 | 198 | def randomflip_leftright(x, y=None, trans_xy=False): 199 | """Randomly flips the input and label. 200 | 201 | Usage: 202 | ```python 203 | >>> x = np.array([[[1, 2, 3], [6, 2, 5], [3, 4, 9]]]) 204 | >>> x_out = spatial_transforms.randomflip_leftright(x) 205 | >>> x_out 206 | 210 | ``` 211 | 212 | Parameters 213 | ---------- 214 | x: input is a tensor or numpy to have rank 3, 215 | y: label is a tensor or numpy to have rank 3, 216 | trans_xy: Boolean, transforms both x and y (Default: False). 217 | If set True, function will require both x,y. 218 | 219 | Returns 220 | ---------- 221 | Randomly flipped input and/or label tensor. 222 | """ 223 | if ~tf.is_tensor(x): 224 | x = tf.convert_to_tensor(x) 225 | x = tf.cast(x, tf.float32) 226 | if trans_xy: 227 | if y is None: 228 | raise ValueError("`LabelMap' should be assigned") 229 | if len(y.shape) != 3: 230 | raise ValueError("`LabelMap` must be equal or higher than rank 2") 231 | if ~tf.is_tensor(y): 232 | y = tf.convert_to_tensor(y) 233 | y = tf.cast(y, tf.float32) 234 | c = tf.concat([x, y], axis=0) 235 | c = tf.image.random_flip_left_right(c, seed=None) 236 | split_channel = int(c.shape[0] / 2) 237 | return c[0:split_channel, :, :], c[split_channel : c.shape[0], :, :] 238 | if y is None: 239 | return tf.image.random_flip_left_right(x, seed=None) 240 | return tf.image.random_flip_left_right(x, seed=None), y 241 | -------------------------------------------------------------------------------- /nobrainer/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuronets/nobrainer/976691d685824fd4bba836498abea4184cffd798/nobrainer/tests/__init__.py -------------------------------------------------------------------------------- /nobrainer/tests/checkpoint_test.py: -------------------------------------------------------------------------------- 1 | """Tests for `nobrainer.processing.checkpoint`.""" 2 | 3 | import os 4 | 5 | import numpy as np 6 | from numpy.testing import assert_allclose 7 | import tensorflow as tf 8 | 9 | from nobrainer.dataset import Dataset 10 | from nobrainer.models import meshnet 11 | from nobrainer.processing.segmentation import Segmentation 12 | 13 | 14 | def _get_toy_dataset(): 15 | data_shape = (8, 8, 8, 8, 1) 16 | train = tf.data.Dataset.from_tensors( 17 | (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape)) 18 | ) 19 | return Dataset(train, data_shape[0], data_shape[1:4], 1) 20 | 21 | 22 | def _assert_model_weights_allclose(model1, model2): 23 | for layer1, layer2 in zip(model1.model.layers, model2.model.layers): 24 | weights1 = layer1.get_weights() 25 | weights2 = layer2.get_weights() 26 | assert len(weights1) == len(weights2) 27 | for index in range(len(weights1)): 28 | assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08) 29 | 30 | 31 | def test_checkpoint(tmp_path): 32 | train = _get_toy_dataset() 33 | 34 | checkpoint_filepath = os.path.join(tmp_path, 
"checkpoint-epoch_{epoch:03d}") 35 | model1 = Segmentation.init_with_checkpoints( 36 | meshnet, 37 | checkpoint_filepath=checkpoint_filepath, 38 | ) 39 | model1.fit( 40 | dataset_train=train, 41 | epochs=2, 42 | ) 43 | 44 | model2 = Segmentation.init_with_checkpoints( 45 | meshnet, 46 | checkpoint_filepath=checkpoint_filepath, 47 | ) 48 | _assert_model_weights_allclose(model1, model2) 49 | model2.fit( 50 | dataset_train=train, 51 | epochs=3, 52 | ) 53 | 54 | model3 = Segmentation.init_with_checkpoints( 55 | meshnet, 56 | checkpoint_filepath=checkpoint_filepath, 57 | ) 58 | _assert_model_weights_allclose(model2, model3) 59 | 60 | 61 | def test_warm_start_workflow(tmp_path): 62 | train = _get_toy_dataset() 63 | 64 | checkpoint_dir = os.path.join(tmp_path, "checkpoints") 65 | checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}") 66 | if not os.path.exists(checkpoint_dir): 67 | os.mkdir(checkpoint_dir) 68 | 69 | for iteration in range(2): 70 | bem = Segmentation.init_with_checkpoints( 71 | meshnet, 72 | checkpoint_filepath=checkpoint_filepath, 73 | ) 74 | if iteration == 0: 75 | assert bem.model is None 76 | else: 77 | assert bem.model is not None 78 | for layer in bem.model.layers: 79 | for weight_array in layer.get_weights(): 80 | assert np.count_nonzero(weight_array) 81 | bem.fit( 82 | dataset_train=train, 83 | epochs=2, 84 | ) 85 | -------------------------------------------------------------------------------- /nobrainer/tests/dataset_test.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as op 3 | import shutil 4 | import tempfile 5 | 6 | import nibabel as nib 7 | import numpy as np 8 | from numpy.testing import assert_array_equal 9 | import pytest 10 | 11 | from .. 
import dataset, intensity_transforms, io, spatial_transforms, tfrecord, utils 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def tmp_data_filepaths(): 16 | temp_dir = tempfile.mkdtemp() 17 | csv_of_filepaths = utils.get_data(cache_dir=temp_dir) 18 | filepaths = io.read_csv(csv_of_filepaths) 19 | yield filepaths 20 | shutil.rmtree(temp_dir) 21 | 22 | 23 | def write_tfrecs(filepaths, outdir, examples_per_shard): 24 | tfrecord.write( 25 | features_labels=filepaths, 26 | filename_template=op.join(outdir, "data_shard-{shard:03d}.tfrec"), 27 | examples_per_shard=examples_per_shard, 28 | ) 29 | return op.join(outdir, "data_shard-*.tfrec") 30 | 31 | 32 | @pytest.mark.parametrize("examples_per_shard", [1, 3]) 33 | @pytest.mark.parametrize("batch_size", [1, 2]) 34 | @pytest.mark.parametrize("num_parallel_calls", [1, 2]) 35 | def test_get_dataset_maintains_order( 36 | tmp_data_filepaths, examples_per_shard, batch_size, num_parallel_calls 37 | ): 38 | filepaths = [(x, i) for i, (x, _) in enumerate(tmp_data_filepaths)] 39 | temp_dir = tempfile.mkdtemp() 40 | file_pattern = write_tfrecs( 41 | filepaths, temp_dir, examples_per_shard=examples_per_shard 42 | ) 43 | volume_shape = (256, 256, 256) 44 | dset = dataset.Dataset.from_tfrecords( 45 | file_pattern=file_pattern, 46 | n_volumes=10, 47 | volume_shape=volume_shape, 48 | scalar_labels=True, 49 | n_classes=1, 50 | num_parallel_calls=num_parallel_calls, 51 | ).batch(batch_size) 52 | 53 | y_orig = np.array([y for _, y in filepaths]) 54 | y_from_dset = ( 55 | np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()]) 56 | .flatten() 57 | .astype(int) 58 | ) 59 | assert_array_equal(y_orig, y_from_dset) 60 | shutil.rmtree(temp_dir) 61 | 62 | 63 | def create_dummy_niftis(shape, n_volumes, outdir): 64 | array_data = np.zeros(shape, dtype=np.int16) 65 | affine = np.diag([1, 2, 3, 1]) 66 | array_img = nib.Nifti1Image(array_data, affine) 67 | for volume in range(n_volumes): 68 | nib.save(array_img, op.join(outdir, f"image_{volume}.nii.gz")) 69 | 70 | return glob.glob(op.join(outdir, "image_*.nii.gz")) 71 | 72 | 73 | def test_get_dataset_errors(): 74 | temp_dir = tempfile.mkdtemp() 75 | file_pattern = op.join(temp_dir, "does_not_exist-*.tfrec") 76 | with pytest.raises(ValueError): 77 | dataset.Dataset.from_tfrecords( 78 | file_pattern, 79 | None, 80 | (256, 256, 256), 81 | n_classes=1, 82 | ) 83 | 84 | 85 | @pytest.mark.parametrize("batch_size", [1, 2]) 86 | @pytest.mark.parametrize("examples_per_shard", [1, 3]) 87 | @pytest.mark.parametrize("volume_shape", [(64, 64, 64), (64, 64, 64, 3)]) 88 | @pytest.mark.parametrize("num_parallel_calls", [1, 2]) 89 | def test_get_dataset_shapes( 90 | volume_shape, examples_per_shard, batch_size, num_parallel_calls 91 | ): 92 | temp_dir = tempfile.mkdtemp() 93 | nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir) 94 | filepaths = [(x, i) for i, x in enumerate(nifti_paths)] 95 | file_pattern = write_tfrecs( 96 | filepaths, temp_dir, examples_per_shard=examples_per_shard 97 | ) 98 | dset = dataset.Dataset.from_tfrecords( 99 | file_pattern=file_pattern, 100 | n_volumes=len(filepaths), 101 | volume_shape=volume_shape, 102 | scalar_labels=True, 103 | n_classes=1, 104 | num_parallel_calls=num_parallel_calls, 105 | ).batch(batch_size) 106 | 107 | output_volume_shape = volume_shape if len(volume_shape) > 3 else volume_shape + (1,) 108 | output_volume_shape = (batch_size,) + output_volume_shape 109 | shapes = [x.shape for x, _ in dset.dataset.as_numpy_iterator()] 110 | assert all([_shape == output_volume_shape for 
_shape in shapes]) 111 | shutil.rmtree(temp_dir) 112 | 113 | 114 | def test_get_dataset_errors_augmentation(): 115 | temp_dir = tempfile.mkdtemp() 116 | file_pattern = op.join(temp_dir, "does_not_exist-*.tfrec") 117 | with pytest.raises(ValueError): 118 | dataset.Dataset.from_tfrecords( 119 | file_pattern=file_pattern, 120 | n_volumes=10, 121 | volume_shape=(256, 256, 256), 122 | n_classes=1, 123 | ).augment = [ 124 | ( 125 | intensity_transforms.addGaussianNoise, 126 | {"noise_mean": 0.1, "noise_std": 0.5}, 127 | ), 128 | (spatial_transforms.randomflip_leftright), 129 | ] 130 | shutil.rmtree(temp_dir) 131 | 132 | 133 | # TODO: need to implement this soon. 134 | @pytest.mark.xfail 135 | def test_get_dataset(): 136 | assert False 137 | 138 | 139 | def test_get_steps_per_epoch(): 140 | volume_shape = (256, 256, 256) 141 | temp_dir = tempfile.mkdtemp() 142 | nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir) 143 | filepaths = [(x, i) for i, x in enumerate(nifti_paths)] 144 | file_pattern = write_tfrecs(filepaths, temp_dir, examples_per_shard=1) 145 | dset = dataset.Dataset.from_tfrecords( 146 | file_pattern=file_pattern.replace("*", "000"), 147 | n_volumes=1, 148 | volume_shape=volume_shape, 149 | block_shape=(64, 64, 64), 150 | scalar_labels=True, 151 | n_classes=1, 152 | ) 153 | assert dset.get_steps_per_epoch() == 64 154 | 155 | dset = dataset.Dataset.from_tfrecords( 156 | file_pattern=file_pattern.replace("*", "000"), 157 | n_volumes=1, 158 | volume_shape=volume_shape, 159 | block_shape=(64, 64, 64), 160 | scalar_labels=True, 161 | n_classes=1, 162 | ).batch(64) 163 | assert dset.get_steps_per_epoch() == 1 164 | 165 | dset = dataset.Dataset.from_tfrecords( 166 | file_pattern=file_pattern.replace("*", "000"), 167 | n_volumes=1, 168 | volume_shape=volume_shape, 169 | block_shape=(64, 64, 64), 170 | scalar_labels=True, 171 | n_classes=1, 172 | ).batch(63) 173 | assert dset.get_steps_per_epoch() == 2 174 | 175 | dset = dataset.Dataset.from_tfrecords( 176 | file_pattern=file_pattern, 177 | n_volumes=10, 178 | volume_shape=volume_shape, 179 | block_shape=(128, 128, 128), 180 | scalar_labels=True, 181 | n_classes=1, 182 | ).batch(4) 183 | assert dset.get_steps_per_epoch() == 20 184 | 185 | shutil.rmtree(temp_dir) 186 | -------------------------------------------------------------------------------- /nobrainer/tests/io_test.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import tempfile 3 | 4 | from fsspec.implementations.local import LocalFileSystem 5 | import nibabel as nib 6 | import numpy as np 7 | import pytest 8 | 9 | from .utils import csv_of_volumes # noqa: F401 10 | from .. 
import io 11 | 12 | 13 | def test_read_csv(): 14 | with tempfile.NamedTemporaryFile() as f: 15 | f.write("foo,bar\nbaz,boo".encode()) 16 | f.seek(0) 17 | assert [("foo", "bar"), ("baz", "boo")] == io.read_csv( 18 | f.name, skip_header=False 19 | ) 20 | 21 | with tempfile.NamedTemporaryFile() as f: 22 | f.write("foo,bar\nbaz,boo".encode()) 23 | f.seek(0) 24 | assert [("baz", "boo")] == io.read_csv(f.name, skip_header=True) 25 | 26 | with tempfile.NamedTemporaryFile() as f: 27 | f.write("foo,bar\nbaz,boo".encode()) 28 | f.seek(0) 29 | assert [("baz", "boo")] == io.read_csv(f.name) 30 | 31 | with tempfile.NamedTemporaryFile() as f: 32 | f.write("foo|bar\nbaz|boo".encode()) 33 | f.seek(0) 34 | assert [("baz", "boo")] == io.read_csv(f.name, delimiter="|") 35 | 36 | 37 | def test_read_mapping(): 38 | with tempfile.NamedTemporaryFile() as f: 39 | f.write("orig,new\n0,1\n20,10\n40,15".encode()) 40 | f.seek(0) 41 | assert {0: 1, 20: 10, 40: 15} == io.read_mapping(f.name, skip_header=True) 42 | # Header is non-integer. 43 | with pytest.raises(ValueError): 44 | io.read_mapping(f.name, skip_header=False) 45 | 46 | with tempfile.NamedTemporaryFile() as f: 47 | f.write("orig,new\n0,1\n20,10\n40".encode()) 48 | f.seek(0) 49 | # Last row only has one value. 50 | with pytest.raises(ValueError): 51 | io.read_mapping(f.name, skip_header=False) 52 | 53 | with tempfile.NamedTemporaryFile() as f: 54 | f.write("origFnew\n0F1\n20F10\n40F15".encode()) 55 | f.seek(0) 56 | assert {0: 1, 20: 10, 40: 15} == io.read_mapping( 57 | f.name, skip_header=True, delimiter="F" 58 | ) 59 | 60 | 61 | def test_read_volume(tmp_path): 62 | data = np.random.rand(8, 8, 8).astype(np.float32) 63 | affine = np.eye(4) 64 | 65 | filename = str(tmp_path / "foo.nii.gz") 66 | nib.save(nib.Nifti1Image(data, affine), filename) 67 | data_loaded = io.read_volume(filename) 68 | assert np.array_equal(data, data_loaded) 69 | 70 | data_loaded = io.read_volume(filename, dtype=data.dtype) 71 | assert data.dtype == data_loaded.dtype 72 | 73 | data_loaded, affine_loaded = io.read_volume(filename, return_affine=True) 74 | assert np.array_equal(data, data_loaded) 75 | assert np.array_equal(affine, affine_loaded) 76 | 77 | data = np.random.rand(8, 8, 8).astype(np.float32) 78 | affine = np.array([[1.5, 0, 1.2, 0], [0.8, 0.8, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 79 | filename = str(tmp_path / "foo_asr.nii.gz") 80 | nib.save(nib.Nifti1Image(data, affine), filename) 81 | data_loaded = io.read_volume(filename, to_ras=True) 82 | assert not np.array_equal(data, data_loaded) 83 | data_loaded = io.read_volume(filename, to_ras=False) 84 | assert np.array_equal(data, data_loaded) 85 | 86 | 87 | def test_verify_features_nonscalar_labels(csv_of_volumes): # noqa: F811 88 | files = io.read_csv(csv_of_volumes, skip_header=False) 89 | invalid = io.verify_features_labels( 90 | files, volume_shape=(8, 8, 8), num_parallel_calls=1 91 | ) 92 | assert not invalid 93 | # TODO: add more cases. 94 | 95 | 96 | def test_verify_features_scalar_labels(csv_of_volumes): # noqa: F811 97 | files = io.read_csv(csv_of_volumes, skip_header=False) 98 | # Int labels. 99 | files = [(x, 0) for (x, _) in files] 100 | invalid = io.verify_features_labels( 101 | files, volume_shape=(8, 8, 8), num_parallel_calls=1 102 | ) 103 | assert not invalid 104 | invalid = io.verify_features_labels( 105 | files, volume_shape=(12, 12, 8), num_parallel_calls=1 106 | ) 107 | assert all(invalid) 108 | # Float labels. 
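    # (Float labels are exercised the same way as the integer labels above: the
    # check should pass for the true (8, 8, 8) volume shape and flag every pair
    # when an incorrect shape is given.)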
109 | files = [(x, 1.0) for (x, _) in files] 110 | invalid = io.verify_features_labels( 111 | files, volume_shape=(8, 8, 8), num_parallel_calls=1 112 | ) 113 | assert not invalid 114 | invalid = io.verify_features_labels( 115 | files, volume_shape=(12, 12, 8), num_parallel_calls=1 116 | ) 117 | assert all(invalid) 118 | 119 | 120 | @pytest.mark.parametrize("filesys", [None, LocalFileSystem()]) 121 | def test_is_gzipped(tmp_path, filesys): 122 | filename = str(tmp_path / "test.gz") 123 | with gzip.GzipFile(filename, "w") as f: 124 | f.write("i'm more than a test!".encode()) 125 | assert io._is_gzipped(filename, filesys=filesys) 126 | 127 | with open(filename, "w") as f: 128 | f.write("i'm just a test...") 129 | assert not io._is_gzipped(filename, filesys=filesys) 130 | -------------------------------------------------------------------------------- /nobrainer/tests/losses_test.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | 3 | import numpy as np 4 | from numpy.testing import assert_allclose, assert_array_equal 5 | import pytest 6 | import scipy.spatial.distance 7 | import tensorflow as tf 8 | 9 | from .. import losses 10 | 11 | 12 | def test_dice(): 13 | x = np.zeros(4) 14 | y = np.zeros(4) 15 | out = losses.dice(x, y, axis=None).numpy() 16 | assert_allclose(out, 0) 17 | 18 | x = np.ones(4) 19 | y = np.ones(4) 20 | out = losses.dice(x, y, axis=None).numpy() 21 | assert_allclose(out, 0) 22 | 23 | x = [0.0, 0.0, 1.0, 1.0] 24 | y = [1.0, 1.0, 1.0, 1.0] 25 | out = losses.dice(x, y, axis=None).numpy() 26 | ref = scipy.spatial.distance.dice(x, y) 27 | assert_allclose(out, ref) 28 | 29 | x = [0.0, 0.0, 1.0, 1.0] 30 | y = [1.0, 1.0, 0.0, 0.0] 31 | out = losses.dice(x, y, axis=None).numpy() 32 | ref = scipy.spatial.distance.dice(x, y) 33 | assert_allclose(out, ref) 34 | assert_allclose(out, 1) 35 | 36 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float32) 37 | y = x.copy() 38 | x[:2, :10, 10:] = 0 39 | y[:2, :3, 20:] = 0 40 | y[3:, 10:] = 0 41 | dices = np.empty(x.shape[0]) 42 | for i in range(x.shape[0]): 43 | dices[i] = scipy.spatial.distance.dice(x[i].flatten(), y[i].flatten()) 44 | assert_allclose(losses.dice(x, y, axis=(1, 2, 3, 4)), dices, rtol=1e-05) 45 | assert_allclose(losses.Dice(axis=(1, 2, 3, 4))(x, y), dices.mean(), rtol=1e-05) 46 | assert_allclose(losses.Dice(axis=(1, 2, 3, 4))(y, x), dices.mean(), rtol=1e-05) 47 | 48 | 49 | def test_generalized_dice(): 50 | shape = (8, 32, 32, 32, 16) 51 | x = np.zeros(shape) 52 | y = np.zeros(shape) 53 | assert_array_equal(losses.generalized_dice(x, y), np.zeros(shape[0])) 54 | 55 | shape = (8, 32, 32, 32, 16) 56 | x = np.ones(shape) 57 | y = np.ones(shape) 58 | assert_array_equal(losses.generalized_dice(x, y), np.zeros(shape[0])) 59 | 60 | shape = (8, 32, 32, 32, 16) 61 | x = np.ones(shape) 62 | y = np.zeros(shape) 63 | # Why aren't the losses exactly one? Could it be the propagation of floating 64 | # point inaccuracies when summing? 65 | assert_allclose(losses.generalized_dice(x, y), np.ones(shape[0]), atol=1e-03) 66 | assert_allclose( 67 | losses.GeneralizedDice(axis=(1, 2, 3))(x, y), losses.generalized_dice(x, y) 68 | ) 69 | 70 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float64) 71 | y = x.copy() 72 | x[:2, :10, 10:] = 0 73 | y[:2, :3, 20:] = 0 74 | y[3:, 10:] = 0 75 | # Dice is similar to generalized Dice for one class. The weight factor 76 | # makes the generalized form slightly different from Dice. 
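    # (The generalized Dice loss is commonly defined with per-class weights
    # w_c = 1 / (sum of reference voxels in class c)**2, following Sudre et al.
    # (2017). Assuming that convention here, the weight cancels out of the ratio
    # when there is only one class, so the two losses should agree up to small
    # numerical effects of the weighting -- hence the loose tolerance below.)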
77 | gd = losses.generalized_dice(x, y, axis=(1, 2, 3)).numpy() 78 | dd = losses.dice(x, y, axis=(1, 2, 3, 4)).numpy() 79 | assert_allclose(gd, dd, rtol=1e-02) # is this close enough? 80 | 81 | 82 | def test_jaccard(): 83 | x = np.zeros(4) 84 | y = np.zeros(4) 85 | out = losses.jaccard(x, y, axis=None).numpy() 86 | assert_allclose(out, 0) 87 | 88 | x = np.ones(4) 89 | y = np.ones(4) 90 | out = losses.jaccard(x, y, axis=None).numpy() 91 | assert_allclose(out, 0) 92 | 93 | x = [0.0, 0.0, 1.0, 1.0] 94 | y = [1.0, 1.0, 1.0, 1.0] 95 | out = losses.jaccard(x, y, axis=None).numpy() 96 | ref = scipy.spatial.distance.jaccard(x, y) 97 | assert_allclose(out, ref) 98 | 99 | x = [0.0, 0.0, 1.0, 1.0] 100 | y = [1.0, 1.0, 0.0, 0.0] 101 | out = losses.jaccard(x, y, axis=None).numpy() 102 | ref = scipy.spatial.distance.jaccard(x, y) 103 | assert_allclose(out, ref) 104 | assert_allclose(out, 1) 105 | 106 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float32) 107 | y = x.copy() 108 | x[:2, :10, 10:] = 0 109 | y[:2, :3, 20:] = 0 110 | y[3:, 10:] = 0 111 | jaccards = np.empty(x.shape[0]) 112 | for i in range(x.shape[0]): 113 | jaccards[i] = scipy.spatial.distance.jaccard(x[i].flatten(), y[i].flatten()) 114 | assert_allclose(losses.jaccard(x, y, axis=(1, 2, 3, 4)), jaccards) 115 | assert_allclose(losses.Jaccard(axis=(1, 2, 3, 4))(x, y), jaccards.mean()) 116 | assert_allclose(losses.Jaccard(axis=(1, 2, 3, 4))(y, x), jaccards.mean()) 117 | 118 | 119 | @pytest.mark.xfail 120 | def test_tversky(): 121 | # TODO: write the test 122 | assert False 123 | 124 | 125 | @pytest.mark.xfail 126 | def test_elbo(): 127 | # TODO: write the test 128 | assert False 129 | 130 | 131 | def test_wasserstein(): 132 | x = np.zeros(4) 133 | y = np.zeros(4) 134 | out = losses.wasserstein(x, y) 135 | assert_allclose(out, 0) 136 | 137 | x = np.ones(4) 138 | y = np.ones(4) 139 | out = losses.wasserstein(x, y) 140 | assert_allclose(out, 1) 141 | 142 | x = np.array([0.0, -1.0, 1.0, -1.0]) 143 | y = np.array([1.0, -1.0, 1.0, 1.0]) 144 | out = losses.wasserstein(x, y) 145 | ref = [0.0, 1.0, 1.0, -1.0] 146 | assert_allclose(out, ref) 147 | 148 | x = np.array([0.0, 0.0, 1.0, 1.0]) 149 | y = np.array([1.0, 1.0, 0.0, 0.0]) 150 | out = losses.wasserstein(x, y) 151 | assert_allclose(out, 0) 152 | 153 | 154 | def test_gradient_penalty(): 155 | x = np.zeros(4) 156 | y = np.zeros(4) 157 | out = losses.gradient_penalty(x, y) 158 | assert_allclose(out, 10) 159 | 160 | x = np.ones(4) 161 | y = np.ones(4) 162 | out = losses.gradient_penalty(x, y) 163 | assert_allclose(out, 0.001) 164 | 165 | x = np.array([0.0, -1.0, 1.0, -1.0]) 166 | y = np.array([1.0, -1.0, 1.0, 1.0]) 167 | out = losses.gradient_penalty(x, y) 168 | ref = [1.0001e01, 1.0000e-03, 1.0000e-03, 1.0000e-03] 169 | assert_allclose(out, ref) 170 | 171 | x = np.array([0.0, 0.0, 1.0, 1.0]) 172 | y = np.array([1.0, 1.0, 0.0, 0.0]) 173 | out = losses.gradient_penalty(x, y) 174 | ref = [10.001, 10.001, 0.0, 0.0] 175 | assert_allclose(out, ref) 176 | 177 | 178 | def test_get(): 179 | if LooseVersion(tf.__version__) < LooseVersion("1.14.1-dev20190408"): 180 | assert losses.get("dice") is losses.dice 181 | assert losses.get("Dice") is losses.Dice 182 | assert losses.get("jaccard") is losses.jaccard 183 | assert losses.get("Jaccard") is losses.Jaccard 184 | assert losses.get("tversky") is losses.tversky 185 | assert losses.get("Tversky") is losses.Tversky 186 | assert losses.get("binary_crossentropy") 187 | else: 188 | assert losses.get("dice") is losses.dice 189 | assert isinstance(losses.get("Dice"), 
losses.Dice) 190 | assert losses.get("jaccard") is losses.jaccard 191 | assert isinstance(losses.get("Jaccard"), losses.Jaccard) 192 | assert losses.get("tversky") is losses.tversky 193 | assert isinstance(losses.get("Tversky"), losses.Tversky) 194 | assert losses.get("binary_crossentropy") 195 | assert losses.get("gradient_penalty") is losses.gradient_penalty 196 | assert losses.get("wasserstein") is losses.wasserstein 197 | assert isinstance(losses.get("Wasserstein"), losses.Wasserstein) 198 | -------------------------------------------------------------------------------- /nobrainer/tests/metrics_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_allclose, assert_array_equal 3 | import pytest 4 | import scipy.spatial.distance 5 | 6 | from .. import metrics 7 | 8 | 9 | def test_dice(): 10 | x = np.zeros(4) 11 | y = np.zeros(4) 12 | out = metrics.dice(x, y, axis=None).numpy() 13 | assert_allclose(out, 1) 14 | 15 | x = np.ones(4) 16 | y = np.ones(4) 17 | out = metrics.dice(x, y, axis=None).numpy() 18 | assert_allclose(out, 1) 19 | 20 | x = [0.0, 0.0, 1.0, 1.0] 21 | y = [1.0, 1.0, 1.0, 1.0] 22 | out = metrics.dice(x, y, axis=None).numpy() 23 | ref = 1.0 - scipy.spatial.distance.dice(x, y) 24 | assert_allclose(out, ref) 25 | jac_out = metrics.jaccard(x, y, axis=None).numpy() 26 | assert_allclose(out, 2.0 * jac_out / (1.0 + jac_out)) 27 | 28 | x = [0.0, 0.0, 1.0, 1.0] 29 | y = [1.0, 1.0, 0.0, 0.0] 30 | out = metrics.dice(x, y, axis=None).numpy() 31 | ref = 1.0 - scipy.spatial.distance.dice(x, y) 32 | assert_allclose(out, ref, atol=1e-07) 33 | assert_allclose(out, 0, atol=1e-07) 34 | 35 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float32) 36 | y = x.copy() 37 | x[:2, :10, 10:] = 0 38 | y[:2, :3, 20:] = 0 39 | y[3:, 10:] = 0 40 | dices = np.empty(x.shape[0]) 41 | for i in range(x.shape[0]): 42 | dices[i] = 1.0 - scipy.spatial.distance.dice(x[i].flatten(), y[i].flatten()) 43 | assert_allclose(metrics.dice(x, y, axis=(1, 2, 3, 4)), dices) 44 | 45 | 46 | def test_generalized_dice(): 47 | shape = (8, 32, 32, 32, 16) 48 | x = np.zeros(shape) 49 | y = np.zeros(shape) 50 | assert_array_equal(metrics.generalized_dice(x, y), np.ones(shape[0])) 51 | 52 | shape = (8, 32, 32, 32, 16) 53 | x = np.ones(shape) 54 | y = np.ones(shape) 55 | assert_array_equal(metrics.generalized_dice(x, y), np.ones(shape[0])) 56 | 57 | shape = (8, 32, 32, 32, 16) 58 | x = np.ones(shape) 59 | y = np.zeros(shape) 60 | # Why aren't the scores exactly zero? Could it be the propagation of floating 61 | # point inaccuracies when summing? 62 | assert_allclose(metrics.generalized_dice(x, y), np.zeros(shape[0]), atol=1e-03) 63 | 64 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float64) 65 | y = x.copy() 66 | x[:2, :10, 10:] = 0 67 | y[:2, :3, 20:] = 0 68 | y[3:, 10:] = 0 69 | # Dice is similar to generalized Dice for one class. The weight factor 70 | # makes the generalized form slightly different from Dice. 71 | gd = metrics.generalized_dice(x, y, axis=(1, 2, 3)).numpy() 72 | dd = metrics.dice(x, y, axis=(1, 2, 3, 4)).numpy() 73 | assert_allclose(gd, dd, rtol=1e-02) # is this close enough? 
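# Dice and Jaccard are related by D = 2J / (1 + J), or equivalently J = D / (2 - D);
# test_dice above and test_jaccard below use these identities to cross-check the
# two metrics against each other.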
74 | 75 | 76 | def test_jaccard(): 77 | x = np.zeros(4) 78 | y = np.zeros(4) 79 | out = metrics.jaccard(x, y, axis=None).numpy() 80 | assert_allclose(out, 1) 81 | 82 | x = np.ones(4) 83 | y = np.ones(4) 84 | out = metrics.jaccard(x, y, axis=None).numpy() 85 | assert_allclose(out, 1) 86 | 87 | x = [0.0, 0.0, 1.0, 1.0] 88 | y = [1.0, 1.0, 1.0, 1.0] 89 | out = metrics.jaccard(x, y, axis=None).numpy() 90 | ref = 1.0 - scipy.spatial.distance.jaccard(x, y) 91 | assert_allclose(out, ref) 92 | dice_out = metrics.dice(x, y, axis=None).numpy() 93 | assert_allclose(out, dice_out / (2.0 - dice_out)) 94 | 95 | x = [0.0, 0.0, 1.0, 1.0] 96 | y = [1.0, 1.0, 0.0, 0.0] 97 | out = metrics.jaccard(x, y, axis=None).numpy() 98 | ref = 1.0 - scipy.spatial.distance.jaccard(x, y) 99 | assert_allclose(out, ref, atol=1e-07) 100 | assert_allclose(out, 0, atol=1e-07) 101 | 102 | x = np.ones((4, 32, 32, 32, 1), dtype=np.float32) 103 | y = x.copy() 104 | x[:2, :10, 10:] = 0 105 | y[:2, :3, 20:] = 0 106 | y[3:, 10:] = 0 107 | jaccards = np.empty(x.shape[0]) 108 | for i in range(x.shape[0]): 109 | jaccards[i] = 1.0 - scipy.spatial.distance.jaccard( 110 | x[i].flatten(), y[i].flatten() 111 | ) 112 | assert_allclose(metrics.jaccard(x, y, axis=(1, 2, 3, 4)), jaccards) 113 | 114 | 115 | def test_tversky(): 116 | shape = (4, 32, 32, 32, 1) 117 | y_pred = np.random.rand(*shape).astype(np.float64) 118 | y_true = np.random.randint(2, size=shape).astype(np.float64) 119 | 120 | # Test that tversky and dice are same when alpha = beta = 0.5 121 | dice = metrics.dice(y_true, y_pred).numpy() 122 | tversky = metrics.tversky( 123 | y_true, y_pred, axis=(1, 2, 3), alpha=0.5, beta=0.5 124 | ).numpy() 125 | assert_allclose(dice, tversky) 126 | 127 | # Test that tversky and jaccard are same when alpha = beta = 1.0 128 | jaccard = metrics.jaccard(y_true, y_pred).numpy() 129 | tversky = metrics.tversky( 130 | y_true, y_pred, axis=(1, 2, 3), alpha=1.0, beta=1.0 131 | ).numpy() 132 | assert_allclose(jaccard, tversky) 133 | 134 | with pytest.raises(ValueError): 135 | metrics.tversky([0.0, 0.0, 1.0], [1.0, 0.0, 1.0], axis=0) 136 | -------------------------------------------------------------------------------- /nobrainer/tests/prediction_test.py: -------------------------------------------------------------------------------- 1 | """Tests for `nobrainer.prediction`.""" 2 | 3 | import nibabel as nib 4 | import numpy as np 5 | from numpy.testing import assert_array_equal 6 | import pytest 7 | import tensorflow as tf 8 | 9 | from .. import prediction 10 | from ..models.bayesian_meshnet import variational_meshnet 11 | from ..models.meshnet import meshnet 12 | 13 | 14 | def test_predict(tmp_path): 15 | x = np.ones((4, 4, 4)) 16 | img = nib.Nifti1Image(x, affine=np.eye(4)) 17 | path = str(tmp_path / "features.nii.gz") 18 | img.to_filename(path) 19 | 20 | x2 = x * -50 21 | img2 = nib.Nifti1Image(x2, affine=np.eye(4)) 22 | path2 = str(tmp_path / "features2.nii.gz") 23 | img2.to_filename(path2) 24 | 25 | model = meshnet(1, (*x.shape, 1), receptive_field=37) 26 | 27 | # From array. 28 | y_ = prediction.predict_from_array(x, model=model, block_shape=None) 29 | y_blocks = prediction.predict_from_array(x, model=model, block_shape=x.shape) 30 | y_other = prediction.predict(x, model=model, block_shape=None) 31 | assert isinstance(y_, np.ndarray) 32 | assert_array_equal(y_.shape, *x.shape) 33 | assert_array_equal(y_, y_other) 34 | assert_array_equal(y_, y_blocks) 35 | 36 | # From image. 
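    # (predict_from_img and the generic predict dispatcher are expected to return
    # a nibabel spatial image whose voxel data matches the array-based result
    # above.)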
37 | y_img = prediction.predict_from_img(img, model=model, block_shape=None) 38 | y_img_other = prediction.predict(img, model=model, block_shape=None) 39 | assert isinstance(y_img, nib.spatialimages.SpatialImage) 40 | assert_array_equal(y_img.shape, x.shape) 41 | assert_array_equal(y_img.get_fdata(caching="unchanged"), y_) 42 | assert_array_equal( 43 | y_img.get_fdata(caching="unchanged"), y_img_other.get_fdata(caching="unchanged") 44 | ) 45 | 46 | # From filepath 47 | y_img2 = prediction.predict_from_filepath(path, model=model, block_shape=None) 48 | y_img2_other = prediction.predict(path, model=model, block_shape=None) 49 | assert isinstance(y_img, nib.spatialimages.SpatialImage) 50 | assert_array_equal(y_img.shape, x.shape) 51 | assert_array_equal(y_img.get_fdata(caching="unchanged"), y_) 52 | assert_array_equal( 53 | y_img2.get_fdata(caching="unchanged"), 54 | y_img2_other.get_fdata(caching="unchanged"), 55 | ) 56 | 57 | # From filepaths 58 | gen = prediction.predict_from_filepaths( 59 | [path, path2], model=model, block_shape=None 60 | ) 61 | y_img3 = next(gen) 62 | y_img4 = next(gen) 63 | gen_other = prediction.predict([path, path2], model=model, block_shape=None) 64 | y_img3_other = next(gen_other) 65 | y_img4_other = next(gen_other) 66 | 67 | assert_array_equal( 68 | y_img2.get_fdata(caching="unchanged"), y_img3.get_fdata(caching="unchanged") 69 | ) 70 | assert_array_equal( 71 | y_img3.get_fdata(caching="unchanged"), 72 | y_img3_other.get_fdata(caching="unchanged"), 73 | ) 74 | assert_array_equal( 75 | y_img4.get_fdata(caching="unchanged"), 76 | y_img4_other.get_fdata(caching="unchanged"), 77 | ) 78 | assert_array_equal(y_img3.shape, x.shape) 79 | 80 | 81 | def test_variational_predict(tmp_path): 82 | x = np.ones((4, 4, 4)) 83 | img = nib.Nifti1Image(x, affine=np.eye(4)) 84 | path = str(tmp_path / "features.nii.gz") 85 | img.to_filename(path) 86 | 87 | x2 = x * -50 88 | img2 = nib.Nifti1Image(x2, affine=np.eye(4)) 89 | path2 = str(tmp_path / "features2.nii.gz") 90 | img2.to_filename(path2) 91 | 92 | model = variational_meshnet(1, (*x.shape, 1), receptive_field=37) 93 | 94 | # From array. 
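    # (For a variational model, predict_from_array draws n_samples stochastic
    # forward passes; with return_variance and return_entropy set it is expected
    # to return per-voxel mean, variance, and entropy arrays, unpacked below.)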
95 | mean, var, entropy = prediction.predict_from_array( 96 | x, 97 | model=model, 98 | block_shape=None, 99 | n_samples=2, 100 | return_variance=True, 101 | return_entropy=True, 102 | ) 103 | # y_blocks = prediction.predict_from_array(x, model=model, block_shape=x.shape) 104 | # y_other = prediction.predict(x, model=model, block_shape=None) 105 | assert isinstance(mean, np.ndarray) 106 | assert isinstance(var, np.ndarray) 107 | assert isinstance(entropy, np.ndarray) 108 | assert_array_equal(mean.shape, *x.shape) 109 | assert_array_equal(var.shape, *x.shape) 110 | assert_array_equal(entropy.shape, *x.shape) 111 | # assert_array_equal(y_, y_other) 112 | # assert_array_equal(y_, y_blocks) 113 | 114 | 115 | def test_get_model(tmp_path): 116 | model = meshnet(3, (10, 10, 10, 1), receptive_field=37) 117 | path = str(tmp_path / "model.h5") 118 | model.save(path) 119 | assert isinstance(prediction._get_model(path), tf.keras.Model) 120 | assert model is prediction._get_model(model) 121 | with pytest.raises(ValueError): 122 | prediction._get_model("not a model") 123 | 124 | 125 | @pytest.mark.xfail 126 | def test_transform_and_predict(tmp_path): 127 | assert False 128 | -------------------------------------------------------------------------------- /nobrainer/tests/test_intensity_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from nobrainer import intensity_transforms 5 | 6 | 7 | def test_addGaussianNoise(): 8 | shape = (10, 10, 10) 9 | x = np.ones(shape).astype(np.float32) 10 | y = np.random.randint(0, 2, size=shape).astype(np.float32) 11 | x_out = intensity_transforms.addGaussianNoise(x, noise_mean=0.0, noise_std=1) 12 | x_out = x_out.numpy() 13 | assert x_out.shape == x.shape 14 | assert np.sum(x_out - x) != 0 15 | 16 | # test if x and y undergoes same noiseshift 17 | x_out, y_out = intensity_transforms.addGaussianNoise( 18 | x, y, trans_xy=True, noise_mean=0.0, noise_std=1 19 | ) 20 | x_out = x_out.numpy() 21 | y_out = y_out.numpy() 22 | noise_y = y_out - y 23 | noise_x = x_out - x 24 | assert x_out.shape == x.shape 25 | assert y_out.shape == y.shape 26 | assert np.sum(noise_x - noise_y) < 1e-5 27 | 28 | # test sending y, but not transforming it 29 | x_out, y_out = intensity_transforms.addGaussianNoise( 30 | x, y, trans_xy=False, noise_mean=0.0, noise_std=1 31 | ) 32 | x_out = x_out.numpy() 33 | assert x_out.shape == x.shape 34 | assert np.sum(x_out - x) != 0 35 | assert y_out.shape == y.shape 36 | np.testing.assert_array_equal(y_out, y) 37 | 38 | 39 | def test_minmaxIntensityScaling(): 40 | x = np.random.rand(10, 10, 10).astype(np.float32) 41 | y = np.random.randint(0, 2, size=(10, 10, 10)).astype(np.float32) 42 | x_out, y_out = intensity_transforms.minmaxIntensityScaling(x, y, trans_xy=True) 43 | x_out = x_out.numpy() 44 | y_out = y_out.numpy() 45 | assert x_out.min() - 0.0 < 1e-5 46 | assert y_out.min() - 0.0 < 1e-5 47 | assert 1 - x_out.max() < 1e-5 48 | assert 1 - y_out.max() < 1e-5 49 | 50 | x_out, y_out = intensity_transforms.minmaxIntensityScaling(x, y, trans_xy=False) 51 | x_out = x_out.numpy() 52 | assert x_out.min() - 0.0 < 1e-5 53 | assert 1 - x_out.max() < 1e-5 54 | np.testing.assert_array_equal(y_out, y) 55 | 56 | 57 | def test_customIntensityScaling(): 58 | x = np.random.rand(10, 10, 10).astype(np.float32) 59 | y = np.random.randint(0, 2, size=(10, 10, 10)).astype(np.float32) 60 | x_out, y_out = intensity_transforms.customIntensityScaling( 61 | x, y, trans_xy=True, 
scale_x=[0, 100], scale_y=[0, 3] 62 | ) 63 | x_out = x_out.numpy() 64 | y_out = y_out.numpy() 65 | assert x_out.min() - 0.0 < 1e-5 66 | assert y_out.min() - 0.0 < 1e-5 67 | assert 100 - x_out.max() < 1e-5 68 | assert 3 - y_out.max() < 1e-5 69 | 70 | x_out, y_out = intensity_transforms.customIntensityScaling( 71 | x, y, trans_xy=False, scale_x=[0, 100], scale_y=[0, 3] 72 | ) 73 | x_out = x_out.numpy() 74 | assert x_out.min() - 0.0 < 1e-5 75 | assert 100 - x_out.max() < 1e-5 76 | np.testing.assert_array_equal(y_out, y) 77 | 78 | 79 | def test_intensityMasking(): 80 | mask_x = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]) 81 | x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]) 82 | expected = np.array( 83 | [[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]] 84 | ) 85 | results = intensity_transforms.intensityMasking(x, mask_x=mask_x) 86 | results = tf.squeeze(results) 87 | np.testing.assert_allclose(results.numpy(), expected) 88 | 89 | y = np.random.rand(*x.shape) 90 | results, y_out = intensity_transforms.intensityMasking( 91 | x, mask_x=mask_x, y=y, trans_xy=False 92 | ) 93 | results = tf.squeeze(results) 94 | np.testing.assert_allclose(results.numpy(), expected) 95 | np.testing.assert_array_equal(y_out, y) 96 | 97 | 98 | def test_contrastAdjust(): 99 | gamma = 1.5 100 | epsilon = 1e-7 101 | x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]) 102 | x_range = x.max() - x.min() 103 | expected = ( 104 | np.power(((x - x.min()) / float(x_range + epsilon)), gamma) * x_range + x.min() 105 | ) 106 | results = intensity_transforms.contrastAdjust(x, gamma=1.5) 107 | np.testing.assert_allclose(expected, results.numpy(), rtol=1e-05) 108 | 109 | y = np.random.rand(*x.shape) 110 | results, y_out = intensity_transforms.contrastAdjust( 111 | x, y, trans_xy=False, gamma=1.5 112 | ) 113 | np.testing.assert_allclose(expected, results.numpy(), rtol=1e-05) 114 | np.testing.assert_array_equal(y_out, y) 115 | -------------------------------------------------------------------------------- /nobrainer/tests/test_spatial_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from nobrainer import spatial_transforms as transformations 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def test_centercrop(): 9 | # Test for inputs 10 | shape = (10, 10, 10) 11 | x = np.ones(shape).astype(np.float32) 12 | y = np.random.randint(0, 2, size=shape).astype(np.float32) 13 | fine = int(x.shape[1]) 14 | x = transformations.centercrop(x, finesize=fine) 15 | x = x.numpy() 16 | # Test for output shapes 17 | assert x.shape[1] == fine & x.shape[0] == fine & x.shape[2] == shape[2] 18 | assert y.shape[1] == fine & y.shape[0] == fine & y.shape[2] == shape[2] 19 | 20 | # Test for passing y but not transoforming it 21 | shape = (10, 10, 10) 22 | x = np.ones(shape).astype(np.float32) 23 | y = np.random.randint(0, 2, size=shape).astype(np.float32) 24 | fine = int(x.shape[1]) 25 | x, y_out = transformations.centercrop(x, y, trans_xy=False, finesize=fine) 26 | x = x.numpy() 27 | # Test for output shapes 28 | assert x.shape[1] == fine & x.shape[0] == fine & x.shape[2] == shape[2] 29 | np.testing.assert_array_equal(y_out, y) 30 | 31 | # test for both x and y 32 | shape = (10, 10, 10) 33 | x = np.ones(shape).astype(np.float32) 34 | y = np.random.randint(0, 2, size=shape).astype(np.float32) 35 | fine = int(x.shape[1]) 36 | x, y = transformations.centercrop(x, y, fine, 
trans_xy=True) 37 | x = x.numpy() 38 | y = y.numpy() 39 | # Test for output shapes 40 | assert x.shape[1] == fine & x.shape[0] == fine & x.shape[2] == shape[2] 41 | assert y.shape[1] == fine & y.shape[0] == fine & y.shape[2] == shape[2] 42 | 43 | # Test for varying finesize 44 | shape = (10, 10, 10) 45 | x = np.ones(shape).astype(np.float32) 46 | y = np.random.randint(0, 2, size=shape).astype(np.float32) 47 | finesize = [128, 1] 48 | x1, y1 = transformations.centercrop(x, y, finesize[0], trans_xy=True) 49 | x2, y2 = transformations.centercrop(x, y, finesize[1], trans_xy=True) 50 | x1 = x1.numpy() 51 | x2 = x2.numpy() 52 | y1 = y1.numpy() 53 | y2 = y2.numpy() 54 | assert ( 55 | x1.shape[1] 56 | == min(shape[1], finesize[0]) & x1.shape[0] 57 | == min(shape[0], finesize[0]) & x1.shape[2] 58 | == shape[2] 59 | ) 60 | assert ( 61 | y1.shape[1] 62 | == min(shape[1], finesize[0]) & y1.shape[0] 63 | == min(shape[0], finesize[0]) & y1.shape[2] 64 | == shape[2] 65 | ) 66 | 67 | assert ( 68 | x2.shape[1] 69 | == min(shape[1], finesize[1]) & x2.shape[0] 70 | == min(shape[0], finesize[1]) 71 | ) 72 | assert ( 73 | y2.shape[1] 74 | == min(shape[1], finesize[1]) & y2.shape[0] 75 | == min(shape[0], finesize[1]) 76 | ) 77 | assert y2.shape[2] == shape[2] & x2.shape[2] == shape[2] 78 | 79 | 80 | def test_spatialConstantPadding(): 81 | x = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]) 82 | y = np.array([[[1, 0, 1], [0, 2, 2], [3, 3, 0]], [[4, 1, 4], [5, 0, 0], [0, 0, 0]]]) 83 | x_expected = np.array( 84 | [ 85 | [ 86 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 87 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 88 | [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], 89 | [0.0, 0.0, 2.0, 2.0, 2.0, 0.0, 0.0], 90 | [0.0, 0.0, 3.0, 3.0, 3.0, 0.0, 0.0], 91 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 92 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 93 | ], 94 | [ 95 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 96 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 97 | [0.0, 0.0, 4.0, 4.0, 4.0, 0.0, 0.0], 98 | [0.0, 0.0, 5.0, 5.0, 5.0, 0.0, 0.0], 99 | [0.0, 0.0, 6.0, 6.0, 6.0, 0.0, 0.0], 100 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 101 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 102 | ], 103 | ] 104 | ) 105 | y_expected = np.array( 106 | [ 107 | [ 108 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 109 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 110 | [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], 111 | [0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0], 112 | [0.0, 0.0, 3.0, 3.0, 0.0, 0.0, 0.0], 113 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 114 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 115 | ], 116 | [ 117 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 118 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 119 | [0.0, 0.0, 4.0, 1.0, 4.0, 0.0, 0.0], 120 | [0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0], 121 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 122 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 123 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 124 | ], 125 | ] 126 | ) 127 | resultx, resulty = transformations.spatialConstantPadding( 128 | x, y, trans_xy=True, padding_zyx=[0, 2, 2] 129 | ) 130 | np.testing.assert_allclose(x_expected, resultx.numpy()) 131 | np.testing.assert_allclose(y_expected, resulty.numpy()) 132 | 133 | resultx, resulty = transformations.spatialConstantPadding( 134 | x, y, trans_xy=False, padding_zyx=[0, 2, 2] 135 | ) 136 | np.testing.assert_allclose(x_expected, resultx.numpy()) 137 | np.testing.assert_array_equal(resulty, y) 138 | 139 | 140 | def test_randomCrop(): 141 | x = np.random.rand(10, 10, 10).astype(np.float32) 142 | y = np.random.randint(0, 2, size=(10, 10, 10)).astype(np.float32) 143 | expected_shape = 
(3, 3, 10) 144 | res_x, res_y = transformations.randomCrop(x, y, trans_xy=True, cropsize=3) 145 | assert np.shape(res_x.numpy()) == expected_shape 146 | assert np.shape(res_y.numpy()) == expected_shape 147 | assert np.all(np.in1d(np.ravel(res_x), np.ravel(x))) 148 | assert np.all(np.in1d(np.ravel(res_y), np.ravel(y))) 149 | 150 | res_x, res_y = transformations.randomCrop(x, y, trans_xy=False, cropsize=3) 151 | assert np.shape(res_x.numpy()) == expected_shape 152 | assert np.all(np.in1d(np.ravel(res_x), np.ravel(x))) 153 | np.testing.assert_array_equal(res_y, y) 154 | 155 | 156 | def test_resize(): 157 | x = np.random.rand(10, 10, 10).astype(np.float32) 158 | y = np.random.randint(0, 2, size=(10, 10, 10)).astype(np.float32) 159 | expected_shape = (5, 5, 10) 160 | results_x, results_y = transformations.resize( 161 | x, y, trans_xy=True, size=[5, 5], mode="bicubic" 162 | ) 163 | assert np.shape(results_x.numpy()) == expected_shape 164 | assert np.shape(results_y.numpy()) == expected_shape 165 | 166 | results_x, results_y = transformations.resize( 167 | x, y, trans_xy=False, size=[5, 5], mode="bicubic" 168 | ) 169 | assert np.shape(results_x.numpy()) == expected_shape 170 | np.testing.assert_array_equal(results_y, y) 171 | 172 | 173 | def test_randomflip_leftright(): 174 | x = np.random.rand(3, 3, 3).astype(np.float32) 175 | y = np.random.randint(0, 2, size=(3, 3, 3)).astype(np.float32) 176 | res_x, res_y = transformations.randomflip_leftright(x, y, trans_xy=True) 177 | expected_shape = (3, 3, 3) 178 | assert np.shape(res_x.numpy()) == expected_shape 179 | assert np.shape(res_y.numpy()) == expected_shape 180 | 181 | res_x, res_y = transformations.randomflip_leftright(x, y, trans_xy=False) 182 | expected_shape = (3, 3, 3) 183 | assert np.shape(res_x.numpy()) == expected_shape 184 | np.testing.assert_array_equal(res_y, y) 185 | -------------------------------------------------------------------------------- /nobrainer/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from numpy.testing import assert_allclose 5 | 6 | from .. import utils as nbutils 7 | from ..io import read_csv 8 | 9 | 10 | def test_get_data(): 11 | csv_path = nbutils.get_data() 12 | assert Path(csv_path).is_file() 13 | 14 | files = read_csv(csv_path) 15 | assert len(files) == 10 16 | assert all(len(r) == 2 for r in files) 17 | for x, y in files: 18 | assert Path(x).is_file() 19 | assert Path(y).is_file() 20 | 21 | 22 | def test_streaming_stats(): 23 | # TODO: add entropy 24 | ss = nbutils.StreamingStats() 25 | xs = np.random.random_sample((100)) 26 | for x in xs: 27 | ss.update(x) 28 | assert_allclose(xs.mean(), ss.mean()) 29 | assert_allclose(xs.std(), ss.std()) 30 | assert_allclose(xs.var(), ss.var()) 31 | 32 | ss = nbutils.StreamingStats() 33 | xs = np.random.random_sample((10, 5, 5, 5)) 34 | for x in xs: 35 | ss.update(x) 36 | assert_allclose(xs.mean(0), ss.mean()) 37 | assert_allclose(xs.std(0), ss.std()) 38 | assert_allclose(xs.var(0), ss.var()) 39 | -------------------------------------------------------------------------------- /nobrainer/tests/tfrecord_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from numpy.testing import assert_array_equal 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from .utils import csv_of_volumes # noqa: F401 9 | from .. 
import io, tfrecord 10 | 11 | 12 | def test_write_read_volume_labels(csv_of_volumes, tmp_path): # noqa: F811 13 | files = io.read_csv(csv_of_volumes, skip_header=False) 14 | filename_template = str(tmp_path / "data-{shard:03d}.tfrecords") 15 | examples_per_shard = 12 16 | tfrecord.write( 17 | files, 18 | filename_template=filename_template, 19 | examples_per_shard=examples_per_shard, 20 | processes=1, 21 | ) 22 | 23 | paths = list(tmp_path.glob("data-*.tfrecords")) 24 | paths = sorted(paths) 25 | assert len(paths) == 9 26 | assert (tmp_path / "data-008.tfrecords").is_file() 27 | 28 | dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") 29 | dset = dset.map( 30 | tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=False) 31 | ) 32 | 33 | for ref, test in zip(files, dset): 34 | x, y = ref 35 | x, y = io.read_volume(x), io.read_volume(y) 36 | assert_array_equal(x, test[0]) 37 | assert_array_equal(y, test[1]) 38 | 39 | with pytest.raises(ValueError): 40 | tfrecord.write( 41 | files, filename_template="data/foobar-{}.tfrecords", examples_per_shard=4 42 | ) 43 | 44 | 45 | def test_write_read_volume_labels_all_processes(csv_of_volumes, tmp_path): # noqa: F811 46 | files = io.read_csv(csv_of_volumes, skip_header=False) 47 | filename_template = str(tmp_path / "data-{shard:03d}.tfrecords") 48 | examples_per_shard = 12 49 | tfrecord.write( 50 | files, 51 | filename_template=filename_template, 52 | examples_per_shard=examples_per_shard, 53 | processes=None, 54 | ) 55 | 56 | paths = list(tmp_path.glob("data-*.tfrecords")) 57 | paths = sorted(paths) 58 | assert len(paths) == 9 59 | assert (tmp_path / "data-008.tfrecords").is_file() 60 | 61 | dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") 62 | dset = dset.map( 63 | tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=False) 64 | ) 65 | 66 | for ref, test in zip(files, dset): 67 | x, y = ref 68 | x, y = io.read_volume(x), io.read_volume(y) 69 | assert_array_equal(x, test[0]) 70 | assert_array_equal(y, test[1]) 71 | 72 | with pytest.raises(ValueError): 73 | tfrecord.write( 74 | files, filename_template="data/foobar-{}.tfrecords", examples_per_shard=4 75 | ) 76 | 77 | 78 | def test_write_read_float_labels(csv_of_volumes, tmp_path): # noqa: F811 79 | files = io.read_csv(csv_of_volumes, skip_header=False) 80 | files = [(x, random.random()) for x, _ in files] 81 | filename_template = str(tmp_path / "data-{shard:03d}.tfrecords") 82 | examples_per_shard = 12 83 | tfrecord.write( 84 | files, 85 | filename_template=filename_template, 86 | examples_per_shard=examples_per_shard, 87 | processes=1, 88 | ) 89 | 90 | paths = list(tmp_path.glob("data-*.tfrecords")) 91 | paths = sorted(paths) 92 | assert len(paths) == 9 93 | assert (tmp_path / "data-008.tfrecords").is_file() 94 | 95 | dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") 96 | dset = dset.map( 97 | tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=True) 98 | ) 99 | 100 | for ref, test in zip(files, dset): 101 | x, y = ref 102 | x = io.read_volume(x) 103 | assert_array_equal(x, test[0]) 104 | assert_array_equal(y, test[1]) 105 | 106 | 107 | def test_write_read_int_labels(csv_of_volumes, tmp_path): # noqa: F811 108 | files = io.read_csv(csv_of_volumes, skip_header=False) 109 | files = [(x, random.randint(0, 9)) for x, _ in files] 110 | filename_template = str(tmp_path / "data-{shard:03d}.tfrecords") 111 | examples_per_shard = 12 112 | tfrecord.write( 113 | files, 114 | 
filename_template=filename_template, 115 | examples_per_shard=examples_per_shard, 116 | processes=1, 117 | ) 118 | 119 | paths = list(tmp_path.glob("data-*.tfrecords")) 120 | paths = sorted(paths) 121 | assert len(paths) == 9 122 | assert (tmp_path / "data-008.tfrecords").is_file() 123 | 124 | dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") 125 | dset = dset.map( 126 | tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=True) 127 | ) 128 | 129 | for ref, test in zip(files, dset): 130 | x, y = ref 131 | x = io.read_volume(x) 132 | assert_array_equal(x, test[0]) 133 | assert_array_equal(y, test[1]) 134 | 135 | 136 | def test__is_int_or_float(): 137 | assert tfrecord._is_int_or_float(10) 138 | assert tfrecord._is_int_or_float(10.0) 139 | assert tfrecord._is_int_or_float("10") 140 | assert tfrecord._is_int_or_float("10.00") 141 | assert not tfrecord._is_int_or_float("foobar") 142 | assert tfrecord._is_int_or_float(np.ones(1)) 143 | assert not tfrecord._is_int_or_float(np.ones(10)) 144 | 145 | 146 | def test__dtype_to_bytes(): 147 | np_tf_dt = [ 148 | (np.uint8, tf.uint8, b"uint8"), 149 | (np.uint16, tf.uint16, b"uint16"), 150 | (np.uint32, tf.uint32, b"uint32"), 151 | (np.uint64, tf.uint64, b"uint64"), 152 | (np.int8, tf.int8, b"int8"), 153 | (np.int16, tf.int16, b"int16"), 154 | (np.int32, tf.int32, b"int32"), 155 | (np.int64, tf.int64, b"int64"), 156 | (np.float16, tf.float16, b"float16"), 157 | (np.float32, tf.float32, b"float32"), 158 | (np.float64, tf.float64, b"float64"), 159 | ] 160 | 161 | for npd, tfd, dt in np_tf_dt: 162 | npd = np.dtype(npd) 163 | assert tfrecord._dtype_to_bytes(npd) == dt 164 | assert tfrecord._dtype_to_bytes(tfd) == dt 165 | 166 | assert tfrecord._dtype_to_bytes("float32") == b"float32" 167 | assert tfrecord._dtype_to_bytes("foobar") == b"foobar" 168 | -------------------------------------------------------------------------------- /nobrainer/tests/transform_test.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from numpy.testing import assert_array_equal 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from .. 
import transform 9 | 10 | 11 | @pytest.mark.parametrize("volume_shape", [(64, 64, 64), (64, 64, 64, 3)]) 12 | def test_get_affine_smoke(volume_shape): 13 | affine = transform.get_affine(volume_shape) 14 | 15 | assert_array_equal(affine, np.eye(4)) 16 | 17 | 18 | def test_get_affine_errors(): 19 | with pytest.raises(ValueError): 20 | transform.get_affine(volume_shape=(64, 64)) 21 | 22 | with pytest.raises(ValueError): 23 | transform.get_affine(volume_shape=(64, 64, 64), rotation=[0, 0]) 24 | 25 | with pytest.raises(ValueError): 26 | transform.get_affine(volume_shape=(64, 64, 64), translation=[0, 0]) 27 | 28 | 29 | @pytest.mark.parametrize("volume_shape", [(2, 2, 2), (2, 2, 2, 3)]) 30 | def test_get_coordinates(volume_shape): 31 | coords = transform._get_coordinates(volume_shape=volume_shape) 32 | coords_ref = [ 33 | list(element) for element in list(itertools.product([0, 1], repeat=3)) 34 | ] 35 | assert_array_equal(coords, coords_ref) 36 | 37 | 38 | def test_get_coordinates_errors(): 39 | with pytest.raises(ValueError): 40 | transform._get_coordinates(volume_shape=(64, 64)) 41 | 42 | 43 | @pytest.mark.parametrize("volume_shape", [(8, 8, 8), (8, 8, 8, 3)]) 44 | def test_trilinear_interpolation_smoke(volume_shape): 45 | volume = np.arange(np.prod(volume_shape)).reshape(volume_shape) 46 | coords = transform._get_coordinates(volume_shape=volume_shape) 47 | x = transform._trilinear_interpolation(volume=volume, coords=coords) 48 | assert_array_equal(x, volume) 49 | 50 | 51 | @pytest.mark.parametrize("volume_shape", [(8, 8, 8), (8, 8, 8, 3)]) 52 | def test_get_voxels(volume_shape): 53 | volume = np.arange(np.prod(volume_shape)).reshape(volume_shape) 54 | coords = transform._get_coordinates(volume_shape=volume_shape) 55 | voxels = transform._get_voxels(volume=volume, coords=coords) 56 | 57 | if len(volume_shape) == 3: 58 | assert_array_equal(voxels, np.arange(np.prod(volume_shape))) 59 | else: 60 | assert_array_equal( 61 | voxels, 62 | np.arange(np.prod(volume_shape)).reshape((np.prod(volume_shape[:3]), -1)), 63 | ) 64 | 65 | 66 | def test_get_voxels_errors(): 67 | volume = np.zeros((8, 8)) 68 | coords = transform._get_coordinates(volume_shape=(8, 8, 8)) 69 | with pytest.raises(ValueError): 70 | transform._get_voxels(volume=volume, coords=coords) 71 | 72 | volume = np.zeros((8, 8, 8)) 73 | coords = np.zeros((8, 8, 8)) 74 | with pytest.raises(ValueError): 75 | transform._get_voxels(volume=volume, coords=coords) 76 | 77 | coords = np.zeros((8, 2)) 78 | with pytest.raises(ValueError): 79 | transform._get_voxels(volume=volume, coords=coords) 80 | 81 | 82 | @pytest.mark.parametrize("shape", [(10, 10, 10), (10, 10, 10, 3)]) 83 | @pytest.mark.parametrize("scalar_labels", [True, False]) 84 | def test_apply_random_transform(shape, scalar_labels): 85 | x = np.ones(shape).astype(np.float32) 86 | transform_func = transform.apply_random_transform 87 | if scalar_labels: 88 | y_shape = (1,) 89 | kwargs = {"trans_xy": False} 90 | else: 91 | y_shape = shape 92 | kwargs = {} 93 | 94 | y_in = np.random.randint(0, 2, size=y_shape).astype(np.float32) 95 | x, y = transform_func(x, y_in, **kwargs) 96 | x = x.numpy() 97 | y = y.numpy() 98 | 99 | # Test that values were not changed in the labels. 
100 | if scalar_labels: 101 | assert_array_equal(y, y_in) 102 | else: 103 | assert_array_equal(np.unique(y), [0, 1]) 104 | assert x.shape == shape 105 | assert y.shape == y_shape 106 | 107 | with pytest.raises(ValueError): 108 | inconsistent_shape = tuple([sh + 1 for sh in shape]) 109 | x, y = transform_func(np.ones(shape), np.ones(inconsistent_shape), **kwargs) 110 | 111 | with pytest.raises(ValueError): 112 | y_shape = (1,) if scalar_labels else (10, 10) 113 | x, y = transform_func(np.ones((10, 10)), np.ones(y_shape), **kwargs) 114 | 115 | x = np.random.randn(*shape).astype(np.float32) 116 | y_shape = (1,) if scalar_labels else shape 117 | y = np.random.randint(0, 2, size=y_shape).astype(np.float32) 118 | x0, y0 = transform_func(x, y, **kwargs) 119 | x1, y1 = transform_func(x, y, **kwargs) 120 | assert not np.array_equal(x, x0) 121 | assert not np.array_equal(x, x1) 122 | assert not np.array_equal(x0, x1) 123 | 124 | if scalar_labels: 125 | assert np.array_equal(y, y0) 126 | assert np.array_equal(y, y1) 127 | assert np.array_equal(y0, y1) 128 | else: 129 | assert not np.array_equal(y, y0) 130 | assert not np.array_equal(y, y1) 131 | assert not np.array_equal(y0, y1) 132 | 133 | # Test that new iterations yield different augmentations. 134 | x = np.arange(64).reshape(1, 4, 4, 4).astype(np.float32) 135 | y_shape = (1, 1) if scalar_labels else x.shape 136 | y = np.random.randint(0, 2, size=y_shape).astype(np.float32) 137 | dataset = tf.data.Dataset.from_tensor_slices((x, y)) 138 | # sanity check 139 | x0, y0 = next(iter(dataset)) 140 | x1, y1 = next(iter(dataset)) 141 | assert_array_equal(x[0], x0) 142 | assert_array_equal(x0, x1) 143 | assert_array_equal(y[0], y0) 144 | assert_array_equal(y0, y1) 145 | # Need to reset the seed, because it is set in other tests. 146 | tf.random.set_seed(None) 147 | dataset = dataset.map(lambda x_l, y_l: transform_func(x_l, y_l, **kwargs)) 148 | x0, y0 = next(iter(dataset)) 149 | x1, y1 = next(iter(dataset)) 150 | assert not np.array_equal(x0, x1) 151 | if scalar_labels: 152 | assert_array_equal(y0, y1) 153 | else: 154 | assert not np.array_equal(y0, y1) 155 | assert_array_equal(np.unique(y0), [0, 1]) 156 | assert_array_equal(np.unique(y1), [0, 1]) 157 | # Naive test that features were interpolated without nearest neighbor. 158 | assert np.any(x0 % 1) 159 | assert np.any(x1 % 1) 160 | -------------------------------------------------------------------------------- /nobrainer/tests/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import nibabel as nib 4 | import numpy as np 5 | from numpy.testing import assert_array_equal 6 | import pytest 7 | from scipy.stats import entropy 8 | 9 | from ..utils import StreamingStats 10 | 11 | 12 | @pytest.fixture(scope="session") 13 | def csv_of_volumes(tmpdir_factory): 14 | """Create random Nifti volumes for use in testing, and return filepath to 15 | CSV, which contains rows of filepaths to `(features, labels)`. 
16 | """ 17 | savedir = tmpdir_factory.mktemp("data") 18 | volume_shape = (8, 8, 8) 19 | n_volumes = 100 20 | 21 | features = np.random.rand(n_volumes, *volume_shape).astype(np.float32) * 10 22 | labels = np.random.randint(0, 1, size=(n_volumes, *volume_shape)) 23 | labels = labels.astype(np.int32) 24 | affine = np.eye(4) 25 | list_of_filepaths = [] 26 | 27 | for idx in range(n_volumes): 28 | fpf = str(savedir.join("{}f.nii.gz")).format(idx) 29 | fpl = str(savedir.join("{}l.nii.gz")).format(idx) 30 | 31 | nib.save(nib.Nifti1Image(features[idx], affine), fpf) 32 | nib.save(nib.Nifti1Image(labels[idx], affine), fpl) 33 | list_of_filepaths.append((fpf, fpl)) 34 | 35 | filepath = savedir.join("features_labels.csv") 36 | with open(filepath, mode="w", newline="") as f: 37 | writer = csv.writer(f, delimiter=",") 38 | writer.writerows(list_of_filepaths) 39 | 40 | return str(filepath) 41 | 42 | 43 | def test_stream_stat(): 44 | s1 = np.array([[0.5, 0.1, 0.4]]) 45 | s2 = np.array([[0.5, 0.2, 0.3]]) 46 | s3 = np.array([[0.4, 0.2, 0.4]]) 47 | 48 | st = np.concatenate((s1, s2, s3), axis=0) 49 | 50 | s = StreamingStats() 51 | s.update(s1).update(s2).update(s3) 52 | 53 | assert_array_equal(s.mean(), np.mean(st, axis=0)) 54 | assert_array_equal(s.var(), np.var(st, axis=0)) 55 | assert_array_equal(s.std(), np.std(st, axis=0)) 56 | assert_array_equal(np.sum(s.entropy()), entropy(np.mean(st, axis=0), axis=0)) 57 | -------------------------------------------------------------------------------- /nobrainer/tests/volume_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_equal 3 | import pytest 4 | import tensorflow as tf 5 | 6 | from .. import volume 7 | 8 | 9 | def test_binarize(): 10 | x = [ 11 | 0.49671415, 12 | -0.1382643, 13 | 0.64768854, 14 | 1.52302986, 15 | -0.23415337, 16 | -0.23413696, 17 | 1.57921282, 18 | 0.76743473, 19 | ] 20 | x = np.asarray(x, dtype="float64") 21 | expected = np.array([True, False, True, True, False, False, True, True]) 22 | result = volume.binarize(x) 23 | assert_array_equal(expected, result) 24 | assert result.dtype == tf.float64 25 | result = volume.binarize(x.astype(np.float32)) 26 | assert_array_equal(expected, result) 27 | assert result.dtype == tf.float32 28 | 29 | x = np.asarray([-2, 0, 2, 0, 2, -2, -1, 1], dtype=np.int32) 30 | expected = np.array([False, False, True, False, True, False, False, True]) 31 | result = volume.binarize(x) 32 | assert_array_equal(expected, result) 33 | assert result.dtype == tf.int32 34 | result = volume.binarize(x.astype(np.int64)) 35 | assert_array_equal(expected, result) 36 | assert result.dtype == tf.int64 37 | 38 | 39 | @pytest.mark.parametrize("replace_func", [volume.replace, volume.replace_in_numpy]) 40 | def test_replace(replace_func): 41 | data = np.arange(5) 42 | mapping = {0: 10, 1: 20, 2: 30, 3: 40, 4: 30} 43 | output = replace_func(data, mapping) 44 | assert_array_equal(output, [10, 20, 30, 40, 30]) 45 | 46 | # Test that overlapping keys and values gives correct result. 
47 | data = np.arange(5) 48 | mapping = {0: 1, 1: 2, 2: 3, 3: 4} 49 | output = replace_func(data, mapping) 50 | assert_array_equal(output, [1, 2, 3, 4, 0]) 51 | 52 | data = np.arange(8).reshape(2, 2, 2) 53 | mapping = {0: 100, 100: 10, 10: 5, 3: 5} 54 | outputs = replace_func(data, mapping, zero=False) 55 | expected = data.copy() 56 | expected[0, 0, 0] = 100 57 | expected[0, 1, 1] = 5 58 | assert_array_equal(outputs, expected) 59 | 60 | # Zero values not in mapping values. 61 | outputs = replace_func(data, mapping, zero=True) 62 | expected = np.zeros_like(data) 63 | expected[0, 0, 0] = 100 64 | expected[0, 1, 1] = 5 65 | assert_array_equal(outputs, expected) 66 | 67 | 68 | @pytest.mark.parametrize("std_func", [volume.standardize, volume.standardize_numpy]) 69 | def test_standardize(std_func): 70 | x = np.random.randn(10, 10, 10).astype(np.float32) 71 | outputs = np.array(std_func(x)) 72 | assert np.allclose(outputs.mean(), 0, atol=1e-07) 73 | assert np.allclose(outputs.std(), 1, atol=1e-07) 74 | 75 | if std_func == volume.standardize: 76 | x = np.random.randn(10, 10, 10).astype(np.float64) 77 | outputs = np.array(std_func(x)) 78 | assert outputs.dtype == np.float32 79 | 80 | 81 | @pytest.mark.parametrize("norm_func", [volume.normalize, volume.normalize_numpy]) 82 | def test_normalize(norm_func): 83 | x = np.random.randn(10, 10, 10).astype(np.float32) 84 | outputs = np.array(norm_func(x)) 85 | assert np.allclose(outputs.min(), 0, atol=1e-07) 86 | assert np.allclose(outputs.max(), 1, atol=1e-07) 87 | 88 | if norm_func == volume.normalize: 89 | x = np.random.randn(10, 10, 10).astype(np.float64) 90 | outputs = np.array(norm_func(x)) 91 | assert outputs.dtype == np.float32 92 | 93 | 94 | def _stack_channels(_in): 95 | return np.stack([_in, 2 * _in, 3 * _in], axis=-1) 96 | 97 | 98 | @pytest.mark.parametrize("multichannel", [True, False]) 99 | @pytest.mark.parametrize("to_blocks_func", [volume.to_blocks, volume.to_blocks_numpy]) 100 | def test_to_blocks(multichannel, to_blocks_func): 101 | x = np.arange(8).reshape(2, 2, 2) 102 | block_shape = (1, 1, 1) 103 | if multichannel: 104 | x = _stack_channels(x) 105 | block_shape = (1, 1, 1, 3) 106 | outputs = np.array(to_blocks_func(x, block_shape)) 107 | expected = np.array( 108 | [[[[0]]], [[[1]]], [[[2]]], [[[3]]], [[[4]]], [[[5]]], [[[6]]], [[[7]]]] 109 | ) 110 | if multichannel: 111 | expected = _stack_channels(expected) 112 | assert_array_equal(outputs, expected) 113 | 114 | block_shape = 2 115 | if multichannel: 116 | block_shape = (2, 2, 2, 3) 117 | outputs = np.array(to_blocks_func(x, block_shape)) 118 | assert_array_equal(outputs, x[None]) 119 | 120 | block_shape = (3, 3, 3) 121 | if multichannel: 122 | block_shape = (3, 3, 3, 3) 123 | with pytest.raises((tf.errors.InvalidArgumentError, ValueError)): 124 | to_blocks_func(x, block_shape) 125 | 126 | block_shape = (3, 3) 127 | with pytest.raises(ValueError): 128 | to_blocks_func(x, block_shape) 129 | 130 | 131 | @pytest.mark.parametrize("multichannel", [True, False]) 132 | @pytest.mark.parametrize( 133 | "from_blocks_func", [volume.from_blocks, volume.from_blocks_numpy] 134 | ) 135 | def test_from_blocks(multichannel, from_blocks_func): 136 | x = np.arange(64).reshape(4, 4, 4) 137 | block_shape = (2, 2, 2) 138 | if multichannel: 139 | x = _stack_channels(x) 140 | block_shape = (2, 2, 2, 3) 141 | 142 | outputs = from_blocks_func(volume.to_blocks(x, block_shape), x.shape) 143 | assert_array_equal(outputs, x) 144 | 145 | with pytest.raises(ValueError): 146 | x = np.arange(80).reshape(10, 2, 2, 2) 
147 | outputs = from_blocks_func(x, (4, 4, 4)) 148 | 149 | 150 | def test_blocks_numpy_value_errors(): 151 | with pytest.raises(ValueError): 152 | x = np.random.rand(4, 4) 153 | output_shape = (4, 4, 4) 154 | volume.to_blocks_numpy(x, output_shape) 155 | 156 | with pytest.raises(ValueError): 157 | x = np.random.rand(4, 4, 4) 158 | output_shape = (4, 4, 4) 159 | volume.from_blocks_numpy(x, output_shape) 160 | 161 | with pytest.raises(ValueError): 162 | x = np.random.rand(4, 4, 4, 4) 163 | output_shape = (4, 4) 164 | volume.from_blocks_numpy(x, output_shape) 165 | -------------------------------------------------------------------------------- /nobrainer/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for Nobrainer.""" 2 | 3 | from collections import namedtuple 4 | import csv 5 | import os 6 | import tempfile 7 | 8 | import numpy as np 9 | import psutil 10 | import tensorflow as tf 11 | 12 | _cache_dir = os.path.join(tempfile.gettempdir(), "nobrainer-data") 13 | 14 | 15 | def get_data(cache_dir=_cache_dir): 16 | """Download sample features and labels. The features are T1-weighted MGZ 17 | files, and the labels are the corresponding aparc+aseg MGZ files, created 18 | with FreeSurfer. This will download 46 megabytes of data. 19 | 20 | These data can be found at 21 | https://datasets.datalad.org/workshops/nih-2017/ds000114/. 22 | 23 | Parameters 24 | ---------- 25 | cache_dir: str, directory where to save the data. By default, saves to a 26 | temporary directory. 27 | 28 | Returns 29 | ------- 30 | List of `(features, labels)`. 31 | """ 32 | 33 | os.makedirs(cache_dir, exist_ok=True) 34 | URLHashPair = namedtuple("URLHashPair", "sub x_hash y_hash") 35 | hashes = [ 36 | URLHashPair( 37 | sub="sub-01", 38 | x_hash="67d0053f021d1d137bc99715e4e3ebb763364c8ce04311b1032d4253fc149f52", 39 | y_hash="7a85b628653f24e2b71cbef6dda86ab24a1743c5f6dbd996bdde258414e780b5", 40 | ), 41 | URLHashPair( 42 | sub="sub-02", 43 | x_hash="c0fee669a34bf3b43c8e4aecc88204512ef4e83f2e414640a5abc076b435990c", 44 | y_hash="c92357c2571da72d15332b2b4838b94d442d4abd3dbddc4b54202d68f0e19380", 45 | ), 46 | URLHashPair( 47 | sub="sub-03", 48 | x_hash="e2bba954e37f5791260f0ec573456e3293bbd40dba139bb1af417eaaeabe63e6", 49 | y_hash="e9204f0d50f06a89dd1870911f7ef5e9808e222227799a5384dceeb941ee8f9d", 50 | ), 51 | URLHashPair( 52 | sub="sub-04", 53 | x_hash="deec5245a2a5948f7e1053ace8d8a31396b14a96d520c6a52305434e75abe1e8", 54 | y_hash="c50e33a3f87aca351414e729b7c25404af364dfe5dd1de5fe380a460cbe9f891", 55 | ), 56 | URLHashPair( 57 | sub="sub-05", 58 | x_hash="8a7fe84918f3f80b87903a1e8f7bd20792c0ebc7528fb98513be373258dfd6c0", 59 | y_hash="682f52633633551d6fda71ede65aa41e16c332ebf42b4df042bc312200b0337c", 60 | ), 61 | URLHashPair( 62 | sub="sub-06", 63 | x_hash="f9a0c40bcd62d7b7e88015867ab5d926009b097ac3235499a541ac9072dd90c8", 64 | y_hash="31c842969af9ac178361fa8c13f656a47d27d95357abaf3e7f3521671aa17929", 65 | ), 66 | URLHashPair( 67 | sub="sub-07", 68 | x_hash="9de3b7392f5383e7391c5fcd9266d6b7ab6b57bc7ab203cc9ad2a29a2d31a85b", 69 | y_hash="b2e48bbfc4185261785643fc8ab066be5f97215b5a9b029ade1ffb12d54d616e", 70 | ), 71 | URLHashPair( 72 | sub="sub-08", 73 | x_hash="361098fc69c280970bb0b0d7ea6aba80d383c12e3ccfe5899693bc35b68efbe4", 74 | y_hash="0c980ef851b1391f580d91fc87c10d6d30315527cc0749c1010f2b7d5819a009", 75 | ), 76 | URLHashPair( 77 | sub="sub-09", 78 | x_hash="1456b35112297df5caacb9d33cb047aa85a3a5b4db3b4b5f9a5c2e189a684e1a", 79 | 
y_hash="696f1e9fef512193b71580292e0edc5835f396d2c8d63909c13668ef7bed433b", 80 | ), 81 | URLHashPair( 82 | sub="sub-10", 83 | x_hash="97447f17402e0f9990cd0917f281704893b52a9b61a3241b23a112a0a143d26e", 84 | y_hash="97a7947ba1a28963714c9f5c82520d9ef803d005695a0b4109d5a73d7e8a537b", 85 | ), 86 | ] 87 | x_filename = "t1.mgz" 88 | y_filename = "aparc+aseg.mgz" 89 | url_template = ( 90 | "https://datasets.datalad.org/workshops/nih-2017/ds000114/derivatives/" 91 | "freesurfer/{sub}/mri/{fname}" 92 | ) 93 | output = [("features", "labels")] 94 | for h in hashes: 95 | x_origin = url_template.format(sub=h.sub, fname=x_filename) 96 | y_origin = url_template.format(sub=h.sub, fname=y_filename) 97 | x_fname = h.sub + "_" + x_origin.rsplit("/", 1)[-1] 98 | y_fname = h.sub + "_" + y_origin.rsplit("/", 1)[-1] 99 | x_out = tf.keras.utils.get_file( 100 | fname=x_fname, origin=x_origin, file_hash=h.x_hash, cache_dir=cache_dir 101 | ) 102 | y_out = tf.keras.utils.get_file( 103 | fname=y_fname, origin=y_origin, file_hash=h.y_hash, cache_dir=cache_dir 104 | ) 105 | output.append((x_out, y_out)) 106 | 107 | csvpath = os.path.join(cache_dir, "filepaths.csv") 108 | with open(csvpath, "w", newline="") as f: 109 | writer = csv.writer(f) 110 | writer.writerows(output) 111 | 112 | return csvpath 113 | 114 | 115 | class StreamingStats: 116 | """Object to calculate statistics on streaming data. 117 | 118 | Compatible with scalars and n-dimensional arrays. 119 | 120 | Examples 121 | -------- 122 | 123 | ```python 124 | >>> s = StreamingStats() 125 | >>> s.update(10).update(20) 126 | >>> s.mean() 127 | 15.0 128 | ``` 129 | 130 | ```python 131 | >>> import numpy as np 132 | >>> a = np.array([[0, 2], [4, 8]]) 133 | >>> b = np.array([[2, 4], [8, 16]]) 134 | >>> s = StreamingStats() 135 | >>> s.update(a).update(b) 136 | >>> s.mean() 137 | array([[ 1., 3.], 138 | [ 6., 12.]]) 139 | ``` 140 | """ 141 | 142 | def __init__(self): 143 | self._n_samples = 0 144 | self._current_mean = 0.0 145 | self._M = 0.0 146 | 147 | def update(self, value): 148 | """Update the statistics with the next value. 149 | 150 | Parameters 151 | ---------- 152 | value: scalar, array-like 153 | 154 | Returns 155 | ------- 156 | Modified instance. 157 | """ 158 | if self._n_samples == 0: 159 | self._current_mean = value 160 | else: 161 | prev_mean = self._current_mean 162 | curr_mean = prev_mean + (value - prev_mean) / (self._n_samples + 1) 163 | _M = self._M + (prev_mean - value) * (curr_mean - value) 164 | # Set the instance attributes after computation in case there are 165 | # errors during computation. 166 | self._current_mean = curr_mean 167 | self._M = _M 168 | self._n_samples += 1 169 | return self 170 | 171 | def mean(self): 172 | """Return current mean of streaming data.""" 173 | return self._current_mean 174 | 175 | def var(self): 176 | """Return current variance of streaming data.""" 177 | return self._M / self._n_samples 178 | 179 | def std(self): 180 | """Return current standard deviation of streaming data.""" 181 | return self.var() ** 0.5 182 | 183 | def entropy(self): 184 | """Return current entropy of streaming data.""" 185 | eps = 1e-07 186 | mult = np.multiply(np.log(self.mean() + eps), self.mean()) 187 | return -mult 188 | # return -np.sum(mult, axis=axis) 189 | 190 | 191 | def get_num_parallel(): 192 | # Get number of processes allocated to the current process. 193 | # Note the difference from `os.cpu_count()`. 
194 | try: 195 | num_parallel_calls = len(psutil.Process().cpu_affinity()) 196 | except AttributeError: 197 | num_parallel_calls = psutil.cpu_count() 198 | return num_parallel_calls 199 | -------------------------------------------------------------------------------- /nobrainer/validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | import nibabel as nib 6 | import numpy as np 7 | 8 | from .io import read_mapping, read_volume 9 | from .metrics import dice as dice_numpy 10 | from .prediction import predict as _predict 11 | from .volume import normalize_numpy, replace 12 | 13 | DT_X = "float32" 14 | 15 | 16 | def validate_from_filepath( 17 | filepath, 18 | predictor, 19 | block_shape, 20 | n_classes, 21 | mapping_y, 22 | return_variance=False, 23 | return_entropy=False, 24 | return_array_from_images=False, 25 | n_samples=1, 26 | normalizer=normalize_numpy, 27 | batch_size=4, 28 | ): 29 | """Computes dice for a prediction compared to a ground truth image. 30 | 31 | Args: 32 | filepath: tuple, tuple of paths to existing neuroimaging volume (index 0) 33 | and ground truth (index 1). 34 | predictor: TensorFlow Predictor object, predictor from previously 35 | trained model. 36 | n_classes: int, number of classifications the model is trained to output. 37 | mapping_y: path-like, path to csv mapping file per command line argument. 38 | block_shape: tuple of len 3, shape of blocks on which to predict. 39 | return_variance: Boolean. If set True, it returns the running population 40 | variance along with mean. Note, if the n_samples is smaller or equal to 1, 41 | the variance will not be returned; instead it will return None 42 | return_entropy: Boolean. If set True, it returns the running entropy. 43 | along with mean. 44 | return_array_from_images: Boolean. If set True and the given input is either 45 | image, filepath, or filepaths, it will return arrays of [mean, variance, 46 | entropy] instead of images of them. Also, if the input is array, it will 47 | simply return array, whether or not this flag is True or False. 48 | n_samples: The number of sampling. If set as 1, it will just return the 49 | single prediction value. 50 | normalizer: callable, function that accepts an ndarray and returns an 51 | ndarray. Called before separating volume into blocks. 52 | batch_size: int, number of sub-volumes per batch for prediction. 53 | dtype: str or dtype object, dtype of features. 54 | 55 | Returns: 56 | `nibabel.spatialimages.SpatialImage` or arrays of predictions of 57 | mean, variance(optional), and entropy (optional). 58 | """ 59 | if not Path(filepath[0]).is_file(): 60 | raise FileNotFoundError("could not find file {}".format(filepath[0])) 61 | img = nib.load(filepath[0]) 62 | y = read_volume(filepath[1], dtype=np.int32) 63 | 64 | outputs = _predict( 65 | inputs=img, 66 | predictor=predictor, 67 | block_shape=block_shape, 68 | return_variance=return_variance, 69 | return_entropy=return_entropy, 70 | return_array_from_images=return_array_from_images, 71 | n_samples=n_samples, 72 | normalizer=normalizer, 73 | batch_size=batch_size, 74 | ) 75 | prediction_image = outputs[0].get_data() 76 | y = replace(y, read_mapping(mapping_y)) 77 | dice = get_dice_for_images(prediction_image, y, n_classes) 78 | return outputs, dice 79 | 80 | 81 | def get_dice_for_images(pred, gt, n_classes): 82 | """Computes dice for a prediction compared to a ground truth image. 
83 | 84 | Args: 85 | pred: array-like, predicted segmentation labels. 86 | gt: array-like, ground-truth segmentation labels. 87 | n_classes: int, number of classes to score. 88 | 89 | Returns: 90 | Numpy array of length `n_classes` with the Dice coefficient per class. 91 | """ 92 | dice = np.zeros(n_classes) 93 | for i in range(n_classes): 94 | u = np.equal(pred, i) 95 | v = np.equal(gt, i) 96 | dice[i] = dice_numpy(u, v) 97 | 98 | return dice 99 | 100 | 101 | def validate_from_filepaths( 102 | filepaths, 103 | predictor, 104 | block_shape, 105 | n_classes, 106 | mapping_y, 107 | output_path, 108 | return_variance=False, 109 | return_entropy=False, 110 | return_array_from_images=False, 111 | n_samples=1, 112 | normalizer=normalize_numpy, 113 | batch_size=4, 114 | dtype=DT_X, 115 | ): 116 | """Compute and save predictions and Dice scores for each filepath pair. 117 | 118 | Args: 119 | filepaths: list, tuples of `(features, labels)` filepaths on which to predict. 120 | predictor: TensorFlow Predictor object, predictor from previously 121 | trained model. 122 | block_shape: tuple of len 3, shape of blocks on which to predict. 123 | n_classes: int, number of classes the model is trained to output. 124 | mapping_y: path-like, path to csv mapping file per command line argument. 125 | output_path: path-like, directory in which to save predictions and Dice scores. 126 | normalizer: callable, function that accepts an ndarray and returns 127 | an ndarray. Called before separating volume into blocks. 128 | batch_size: int, number of sub-volumes per batch for prediction. 129 | dtype: str or dtype object, dtype of features (currently unused). 130 | 131 | Returns: 132 | None 133 | """ 134 | for filepath in filepaths: 135 | # `validate_from_filepath` does not accept a `dtype` argument, so it is not forwarded. 136 | outputs, dice = validate_from_filepath( 137 | filepath=filepath, 138 | predictor=predictor, 139 | n_classes=n_classes, 140 | mapping_y=mapping_y, 141 | block_shape=block_shape, 142 | return_variance=return_variance, 143 | return_entropy=return_entropy, 144 | return_array_from_images=return_array_from_images, 145 | n_samples=n_samples, 146 | normalizer=normalizer, 147 | batch_size=batch_size, 148 | ) 149 | 150 | outpath = Path(filepath[0]) 151 | output_path = Path(output_path) 152 | suffixes = "".join(s for s in outpath.suffixes) 153 | mean_path = output_path / (outpath.stem + "_mean" + suffixes) 154 | variance_path = output_path / (outpath.stem + "_variance" + suffixes) 155 | entropy_path = output_path / (outpath.stem + "_entropy" + suffixes) 156 | dice_path = output_path / (outpath.stem + "_dice.npy") 157 | # if mean_path.is_file() or variance_path.is_file() or entropy_path.is_file(): 158 | # raise Exception(str(mean_path) + " or " + str(variance_path) + 159 | # " or " + str(entropy_path) + " already exists.") 160 | 161 | nib.save(outputs[0], mean_path.as_posix()) # fix 162 | if not return_array_from_images: 163 | include_variance = (n_samples > 1) and (return_variance) 164 | include_entropy = (n_samples > 1) and (return_entropy) 165 | if include_variance and include_entropy: 166 | nib.save(outputs[1], str(variance_path)) 167 | nib.save(outputs[2], str(entropy_path)) 168 | elif include_variance: 169 | nib.save(outputs[1], str(variance_path)) 170 | elif include_entropy: 171 | nib.save(outputs[1], str(entropy_path)) 172 | 173 | print(filepath[0]) 174 | print("Dice: " + str(np.mean(dice))) 175 | np.save(dice_path, dice) 176 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Setuptools version should match setup.py; wheel 
because pip will insert it noisily 3 | requires = ["setuptools >= 38.3.0", "wheel"] 4 | build-backend = 'setuptools.build_meta' 5 | 6 | [tool.black] 7 | exclude='\.eggs|\.git|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist|_version\.py|versioneer\.py' 8 | 9 | [tool.isort] 10 | profile = "black" 11 | force_sort_within_sections = true 12 | reverse_relative = true 13 | sort_relative_in_force_sorted_sections = true 14 | known_first_party = ["nobrainer"] 15 | 16 | [tool.codespell] 17 | ignore-words-list = "nd" 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = nobrainer 3 | url = https://neuronets.github.io 4 | author = Nobrainer Developers 5 | author_email = jakub.kaczmarzyk@gmail.com 6 | description = A framework for developing neural network models for 3D image processing. 7 | long_description = file:README.md 8 | long_description_content_type = text/markdown; charset=UTF-8 9 | license = Apache License, 2.0 10 | license_file = LICENSE 11 | classifiers = 12 | Development Status :: 3 - Alpha 13 | Environment :: Console 14 | Intended Audience :: Developers 15 | Intended Audience :: Education 16 | Intended Audience :: Healthcare Industry 17 | Intended Audience :: Science/Research 18 | License :: OSI Approved :: Apache Software License 19 | Operating System :: OS Independent 20 | Programming Language :: Python :: 3 21 | Programming Language :: Python :: 3 :: Only 22 | Programming Language :: Python :: 3.9 23 | Programming Language :: Python :: 3.10 24 | Programming Language :: Python :: 3.11 25 | Programming Language :: Python :: 3.12 26 | Topic :: Scientific/Engineering :: Artificial Intelligence 27 | Topic :: Software Development 28 | Topic :: Software Development :: Libraries :: Python Modules 29 | project_urls = 30 | Source Code = https://github.com/neuronets/nobrainer 31 | 32 | [options] 33 | python_requires = >= 3.9 34 | install_requires = 35 | click 36 | fsspec 37 | joblib 38 | nibabel 39 | numpy 40 | scikit-image 41 | tensorflow-probability < 0.24 42 | tensorflow >=2.13, < 2.16 43 | tensorflow-addons ~= 0.23.0 44 | psutil 45 | zip_safe = False 46 | packages = find: 47 | include_package_data = True 48 | 49 | [options.entry_points] 50 | console_scripts = 51 | nobrainer=nobrainer.cli.main:cli 52 | 53 | [options.extras_require] 54 | and-cuda = 55 | tensorflow[and-cuda] >=2.13, < 2.16 56 | dev = 57 | pre-commit 58 | pytest 59 | pytest-cov 60 | scipy 61 | 62 | [tool:pytest] 63 | addopts = --verbose --cov=nobrainer --cov-config=setup.cfg 64 | 65 | [coverage:run] 66 | branch = True 67 | omit = 68 | nobrainer/_version.py 69 | */tests* 70 | 71 | [coverage:report] 72 | exclude_lines = 73 | pragma: no cover 74 | raise NotImplementedError 75 | if __name__ == .__main__.: 76 | ignore_errors = True 77 | 78 | [flake8] 79 | max-line-length = 100 80 | exclude = 81 | .git/ 82 | __pycache__/ 83 | build/ 84 | dist/ 85 | versioneer.py 86 | _version.py 87 | ignore = 88 | E203 89 | W503 90 | 91 | [versioneer] 92 | VCS = git 93 | style = pep440 94 | versionfile_source = nobrainer/_version.py 95 | versionfile_build = nobrainer/_version.py 96 | tag_prefix = 97 | parentdir_prefix = 98 | 99 | [codespell] 100 | skip = nobrainer/_version.py,versioneer.py 101 | # Don't warn about "[l]ist" in the abbrev_prompt() docstring: 102 | ignore-regex = \[\w\]\w+ 103 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | """Setup script for nobrainer. 2 | 3 | To install, run `python3 setup.py install`. 4 | """ 5 | 6 | import os 7 | import sys 8 | 9 | from setuptools import setup 10 | 11 | # This is needed for versioneer to be importable when building with PEP 517. 12 | # See and links 13 | # therein for more information. 14 | sys.path.append(os.path.dirname(__file__)) 15 | 16 | try: 17 | import versioneer 18 | 19 | setup_kw = { 20 | "version": versioneer.get_version(), 21 | "cmdclass": versioneer.get_cmdclass(), 22 | } 23 | except ImportError: 24 | # see https://github.com/warner/python-versioneer/issues/192 25 | print("WARNING: failed to import versioneer, falling back to no version for now") 26 | setup_kw = {} 27 | 28 | setup(name="nobrainer", **setup_kw) 29 | --------------------------------------------------------------------------------
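As a quick, hedged illustration of the streaming-statistics helper defined in `nobrainer/utils.py` above: the sketch below assumes only that the package is installed (per `setup.cfg`) and uses made-up array shapes; it mirrors the behaviour exercised in `nobrainer/tests/test_utils.py`, and is a usage sketch rather than part of the repository.

```python
# Minimal usage sketch for nobrainer.utils.StreamingStats (see utils.py above).
# The data here is synthetic and purely illustrative.
import numpy as np

from nobrainer.utils import StreamingStats

volumes = np.random.rand(10, 8, 8, 8).astype(np.float32)  # ten toy 8x8x8 volumes

ss = StreamingStats()
for vol in volumes:
    ss.update(vol)  # update() returns the instance, so calls may also be chained

# The running statistics match batch statistics computed along the first axis.
np.testing.assert_allclose(ss.mean(), volumes.mean(axis=0), rtol=1e-5)
np.testing.assert_allclose(ss.std(), volumes.std(axis=0), rtol=1e-5)
print(ss.mean().shape)             # (8, 8, 8)
print(float(ss.entropy().mean()))  # elementwise -p*log(p), averaged over voxels
```

Note that `entropy()` here is the per-voxel `-p * log(p + eps)` term of the running mean, as implemented in the class above, not the entropy of a full probability distribution.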