├── .readthedocs.yml
├── tests
│   ├── test_data
│   │   ├── tif
│   │   │   └── test.tif
│   │   └── nii
│   │       └── test.nii.gz
│   └── test_dataset.py
├── docs
│   ├── source
│   │   ├── dataset.rst
│   │   ├── transforms.rst
│   │   ├── index.rst
│   │   └── conf.py
│   ├── Makefile
│   └── make.bat
├── niftidataset
│   ├── __init__.py
│   ├── errors.py
│   ├── utils.py
│   ├── dataset.py
│   └── transforms.py
├── LICENSE
├── .travis.yml
├── setup.py
├── .gitignore
└── README.md

/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | build:
2 |   image: latest
3 | 
4 | python:
5 |   version: 3.6
--------------------------------------------------------------------------------
/tests/test_data/tif/test.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcreinhold/niftidataset/HEAD/tests/test_data/tif/test.tif
--------------------------------------------------------------------------------
/tests/test_data/nii/test.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcreinhold/niftidataset/HEAD/tests/test_data/nii/test.nii.gz
--------------------------------------------------------------------------------
/docs/source/dataset.rst:
--------------------------------------------------------------------------------
1 | Dataset
2 | ===================================
3 | 
4 | This module holds all of the PyTorch dataset classes.
5 | 
6 | .. automodule:: niftidataset.dataset
7 |    :members:
--------------------------------------------------------------------------------
/docs/source/transforms.rst:
--------------------------------------------------------------------------------
1 | Transforms
2 | ===================================
3 | 
4 | This module holds all of the transforms associated with NIfTI datasets.
5 | 
6 | .. automodule:: niftidataset.transforms
7 |    :members:
--------------------------------------------------------------------------------
/niftidataset/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = """Jacob C Reinhold"""
2 | __email__ = 'jcreinhold@gmail.com'
3 | __version__ = '0.2.1'
4 | 
5 | from niftidataset.errors import *
6 | from niftidataset.utils import *
7 | from niftidataset.dataset import *
8 | from niftidataset.transforms import *
--------------------------------------------------------------------------------
/niftidataset/errors.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # -*- coding: utf-8 -*-
 3 | """
 4 | niftidataset.errors
 5 | 
 6 | This module holds project-defined errors
 7 | 
 8 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
 9 | 
10 | Created on: Oct 24, 2018
11 | """
12 | 
13 | 
14 | class NiftiDatasetError(Exception):
15 |     pass
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright (c) 2018, IACL Contributors
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License");
 4 | you may not use this file except in compliance with the License.
 5 | You may obtain a copy of the License at
 6 | 
 7 |     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
 1 | .. niftidataset documentation master file, created by
 2 |    sphinx-quickstart on Thu Oct 24 07:08:04 2018.
 3 |    You can adapt this file completely to your liking, but it should at least
 4 |    contain the root `toctree` directive.
 5 | 
 6 | niftidataset documentation
 7 | ===================================================
 8 | 
 9 | This package provides a simple dataset class by which to extract and
10 | process NIfTI files for PyTorch neural networks.
11 | 
12 | 
13 | .. toctree::
14 |    :maxdepth: 1
15 |    :caption: Contents:
16 | 
17 |    dataset
18 |    transforms
19 | 
20 | Indices and tables
21 | ==================
22 | 
23 | * :ref:`genindex`
24 | * :ref:`modindex`
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
 1 | # Minimal makefile for Sphinx documentation
 2 | #
 3 | 
 4 | # You can set these variables from the command line.
 5 | SPHINXOPTS    =
 6 | SPHINXBUILD   = sphinx-build
 7 | SPHINXPROJ    = niftidataset
 8 | SOURCEDIR     = source
 9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
 1 | @ECHO OFF
 2 | 
 3 | pushd %~dp0
 4 | 
 5 | REM Command file for Sphinx documentation
 6 | 
 7 | if "%SPHINXBUILD%" == "" (
 8 | 	set SPHINXBUILD=sphinx-build
 9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=niftidataset
13 | 
14 | if "%1" == "" goto help
15 | 
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | 	echo.
19 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | 	echo.installed, then set the SPHINXBUILD environment variable to point
21 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | 	echo.may add the Sphinx directory to PATH.
23 | 	echo.
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /niftidataset/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | niftidataset.utils 5 | 6 | assortment of input/output utilities for the project 7 | 8 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu) 9 | 10 | Created on: Oct 24, 2018 11 | """ 12 | 13 | __all__ = ['split_filename', 14 | 'glob_imgs'] 15 | 16 | from typing import List, Tuple 17 | 18 | from glob import glob 19 | import os 20 | 21 | 22 | def split_filename(filepath: str) -> Tuple[str, str, str]: 23 | """ split a filepath into the directory, base, and extension """ 24 | path = os.path.dirname(filepath) 25 | filename = os.path.basename(filepath) 26 | base, ext = os.path.splitext(filename) 27 | if ext == '.gz': 28 | base, ext2 = os.path.splitext(base) 29 | ext = ext2 + ext 30 | return path, base, ext 31 | 32 | 33 | def glob_imgs(path: str, ext='*.nii*') -> List[str]: 34 | """ grab all `ext` files in a directory and sort them for consistency """ 35 | fns = sorted(glob(os.path.join(path, ext))) 36 | return fns 37 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | services: 3 | - xvfb 4 | cache: apt 5 | language: python 6 | sudo: false 7 | addons: 8 | apt: 9 | packages: 10 | - libatlas-dev 11 | - libatlas-base-dev 12 | - liblapack-dev 13 | - gfortran 14 | python: 15 | - 3.6 16 | before_install: 17 | - wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 18 | - bash miniconda.sh -b -p $HOME/miniconda 19 | - export PATH="$HOME/miniconda/bin:$PATH" 20 | - conda update --yes conda 21 | install: 22 | - python --version 23 | - which python 24 | - pip --version 25 | # coveralls requires coverage==4.0.3, recent versions of pytest-cov require coverage>=4.4 26 | - travis_retry conda install numpy --yes 27 | - travis_retry conda install pytorch torchvision -c pytorch --yes 28 | - travis_retry conda install nibabel -c conda-forge --yes 29 | - travis_retry pip install coverage nose pytest-pep8 pytest-cov coveralls 30 | - travis_retry python setup.py install 31 | script: 32 | - nosetests -v --with-coverage --cover-tests --cover-package=niftidataset tests 33 | after_success: 34 | - coveralls 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | setup 5 | 6 | Module installs the niftidataset package 7 | Can be run via command: python setup.py install (or develop) 8 | 9 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu) 10 | 11 | Created on: Oct 24, 2018 12 | """ 13 | 14 | from setuptools import setup, find_packages 15 | 16 | 17 | with open('README.md') as f: 18 | readme = f.read() 19 | 20 | with open('LICENSE') as f: 21 | license = f.read() 22 | 23 | args = dict( 24 | name='niftidataset', 25 | version='0.2.1', 26 | description="dataset and transforms classes for NIfTI data in pytorch", 27 | 
long_description=readme, 28 | author='Jacob Reinhold', 29 | author_email='jacob.reinhold@jhu.edu', 30 | url='https://github.com/jcreinhold/niftidataset', 31 | license=license, 32 | packages=find_packages(exclude=('tests', 'docs', 'tutorials')), 33 | keywords="nifti dataset", 34 | ) 35 | 36 | setup(install_requires=['nibabel>=2.3.1', 37 | 'numpy>=1.15.4', 38 | 'pillow>=5.3.0', 39 | 'torch>=1.0.0', 40 | 'torchvision>=0.2.1'], **args) 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac 2 | .DS_Store 3 | 4 | # Latex documentation 5 | *.aux 6 | *.fls 7 | *.fdb_latexmk 8 | *.log 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *,cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | _build 74 | _static 75 | _templates 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # IPython Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv/ 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # pycharm 104 | .idea 105 | 106 | # misc 107 | *.pth 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | niftidataset 2 | ======================= 3 | 4 | [![Build Status](https://travis-ci.org/jcreinhold/niftidataset.svg?branch=master)](https://travis-ci.org/jcreinhold/niftidataset) 5 | [![Coverage Status](https://coveralls.io/repos/github/jcreinhold/niftidataset/badge.svg?branch=master)](https://coveralls.io/github/jcreinhold/niftidataset?branch=master) 6 | [![Documentation Status](https://readthedocs.org/projects/niftidataset/badge/?version=latest)](http://niftidataset.readthedocs.io/en/latest/) 7 | [![Python Versions](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/) 8 | 9 | **This package is deprecated in favor of [torchio](https://torchio.readthedocs.io/) or [MONAI](https://monai.io/) and will no longer be supported** 10 | 11 | This package simply provides appropriate `dataset` and `transforms` classes for NIfTI images 12 | for use with PyTorch or PyTorch wrappers. 13 | 14 | ** Note that this is an **alpha** release. 
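If you have feedback or problems, please submit an issue (it is very much appreciated) **

Quick Start
-----------

A minimal sketch of the intended usage (the directory paths below are placeholders):

    import torchvision.transforms as torch_tfms
    from niftidataset import NiftiDataset, RandomCrop3D, ToTensor

    tfm = torch_tfms.Compose([RandomCrop3D(32), ToTensor()])
    ds = NiftiDataset.setup_from_dir('path/to/source/', 'path/to/target/', tfm)
    src_patch, tgt_patch = ds[0]  # paired (1, 32, 32, 32) tensors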
15 | 
16 | Requirements
17 | ------------
18 | 
19 | - nibabel >= 2.3.1
20 | - numpy >= 1.15.4
21 | - pillow >= 5.3.0
22 | - torch >= 1.0.0
23 | - torchvision >= 0.2.1
24 | 
25 | Installation
26 | ------------
27 | 
28 |     pip install git+https://github.com/jcreinhold/niftidataset.git
29 | 
30 | Tutorial
31 | --------
32 | 
33 | [5 minute Overview](https://github.com/jcreinhold/niftidataset/blob/master/tutorials/5min-tutorial.ipynb)
34 | 
35 | In addition to the above small tutorial, there is consolidated documentation [here](https://niftidataset.readthedocs.io/en/latest/).
36 | 
37 | Test Package
38 | ------------
39 | 
40 | Unit tests can be run from the main directory as follows:
41 | 
42 |     nosetests -v tests
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | #
 3 | # Configuration file for the Sphinx documentation builder.
 4 | #
 5 | # This file does only contain a selection of the most common options. For a
 6 | # full list see the documentation:
 7 | # http://www.sphinx-doc.org/en/master/config
 8 | 
 9 | # -- Path setup --------------------------------------------------------------
10 | 
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | 
16 | import mock
17 | import os
18 | import sys
19 | 
20 | MOCK_MODULES = ['torch', 'torch.utils', 'torch.utils.data', 'torch.utils.data.dataset', 'numpy', 'nibabel',
21 |                 'fastai', 'fastai.vision', 'PIL', 'matplotlib', 'matplotlib.pyplot', 'torchvision',
22 |                 'torchvision.transforms', 'torchvision.transforms.functional']
23 | 
24 | for mod_name in MOCK_MODULES:
25 |     sys.modules[mod_name] = mock.Mock()
26 | 
27 | # this should not be needed with the above, but mock both ways to be safe
28 | autodoc_mock_imports = ['nibabel', 'numpy', 'torch', 'fastai', 'PIL', 'matplotlib', 'torchvision']
29 | 
30 | on_rtd = os.environ.get('READTHEDOCS') == 'True'
31 | 
32 | # add package to local path
33 | local_path = os.path.abspath('../../')
34 | print('path defined as: {}'.format(local_path))
35 | sys.path.insert(0, local_path)
36 | 
37 | 
38 | # -- Project information -----------------------------------------------------
39 | 
40 | project = 'niftidataset'
41 | copyright = '2020, Jacob Reinhold'
42 | author = 'Jacob Reinhold'
43 | 
44 | # The short X.Y version
45 | version = '0.2'
46 | # The full version, including alpha/beta/rc tags
47 | release = '0.2.1'
48 | 
49 | 
50 | # -- General configuration ---------------------------------------------------
51 | 
52 | # If your documentation needs a minimal Sphinx version, state it here.
53 | #
54 | # needs_sphinx = '1.0'
55 | 
56 | # Add any Sphinx extension module names here, as strings. They can be
57 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
58 | # ones.
59 | extensions = [
60 |     'sphinx.ext.autodoc',
61 |     'sphinx.ext.napoleon',
62 |     'sphinx.ext.mathjax'
63 | ]
64 | 
65 | napoleon_google_docstring = True
66 | 
67 | # Add any paths that contain templates here, relative to this directory.
68 | templates_path = ['_templates']
69 | 
70 | # The suffix(es) of source filenames.
71 | # You can specify multiple suffix as a list of string: 72 | # 73 | # source_suffix = ['.rst', '5min_tutorial.md'] 74 | source_suffix = '.rst' 75 | 76 | # The master toctree document. 77 | master_doc = 'index' 78 | 79 | # The language for content autogenerated by Sphinx. Refer to documentation 80 | # for a list of supported languages. 81 | # 82 | # This is also used if you do content translation via gettext catalogs. 83 | # Usually you set "language" from the command line for these cases. 84 | language = None 85 | 86 | # List of patterns, relative to source directory, that match files and 87 | # directories to ignore when looking for source files. 88 | # This pattern also affects html_static_path and html_extra_path . 89 | exclude_patterns = [] 90 | 91 | # The name of the Pygments (syntax highlighting) style to use. 92 | pygments_style = 'sphinx' 93 | 94 | 95 | # -- Options for HTML output ------------------------------------------------- 96 | 97 | # The theme to use for HTML and HTML Help pages. See the documentation for 98 | # a list of builtin themes. 99 | # 100 | html_theme = 'alabaster' 101 | 102 | # Theme options are theme-specific and customize the look and feel of a theme 103 | # further. For a list of options available for each theme, see the 104 | # documentation. 105 | # 106 | html_theme_options = {'github_user': 'jcreinhold', 107 | 'github_repo': 'niftidataset'} 108 | 109 | # Add any paths that contain custom static files (such as style sheets) here, 110 | # relative to this directory. They are copied after the builtin static files, 111 | # so a file named "default.css" will overwrite the builtin "default.css". 112 | html_static_path = ['_static'] 113 | 114 | # Custom sidebar templates, must be a dictionary that maps document names 115 | # to template names. 116 | # 117 | # The default sidebars (for documents that don't match any pattern) are 118 | # defined by theme itself. Builtin themes are using these templates by 119 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 120 | # 'searchbox.html']``. 121 | # 122 | # html_sidebars = {} 123 | 124 | 125 | # -- Options for HTMLHelp output --------------------------------------------- 126 | 127 | # Output file base name for HTML help builder. 128 | htmlhelp_basename = 'niftidatasetdoc' 129 | 130 | 131 | # -- Options for LaTeX output ------------------------------------------------ 132 | 133 | latex_elements = { 134 | # The paper size ('letterpaper' or 'a4paper'). 135 | # 136 | # 'papersize': 'letterpaper', 137 | 138 | # The font size ('10pt', '11pt' or '12pt'). 139 | # 140 | # 'pointsize': '10pt', 141 | 142 | # Additional stuff for the LaTeX preamble. 143 | # 144 | # 'preamble': '', 145 | 146 | # Latex figure (float) alignment 147 | # 148 | # 'figure_align': 'htbp', 149 | } 150 | 151 | # Grouping the document tree into LaTeX files. List of tuples 152 | # (source start file, target name, title, 153 | # author, documentclass [howto, manual, or own class]). 154 | latex_documents = [ 155 | (master_doc, 'niftidataset.tex', 'niftidataset Documentation', 156 | 'Jacob Reinhold', 'manual'), 157 | ] 158 | 159 | 160 | # -- Options for manual page output ------------------------------------------ 161 | 162 | # One entry per manual page. List of tuples 163 | # (source start file, name, description, authors, manual section). 
164 | man_pages = [
165 |     (master_doc, 'niftidataset', 'niftidataset Documentation',
166 |      [author], 1)
167 | ]
168 | 
169 | 
170 | # -- Options for Texinfo output ----------------------------------------------
171 | 
172 | # Grouping the document tree into Texinfo files. List of tuples
173 | # (source start file, target name, title, author,
174 | #  dir menu entry, description, category)
175 | texinfo_documents = [
176 |     (master_doc, 'niftidataset', 'niftidataset Documentation',
177 |      author, 'niftidataset', 'dataset and transforms classes for NIfTI data in pytorch',
178 |      'Miscellaneous'),
179 | ]
180 | 
181 | 
182 | # -- Extension configuration -------------------------------------------------
183 | 
--------------------------------------------------------------------------------
/niftidataset/dataset.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # -*- coding: utf-8 -*-
 3 | """
 4 | niftidataset.dataset
 5 | 
 6 | the actual dataset classes of niftidataset
 7 | 
 8 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
 9 | 
10 | Created on: Oct 24, 2018
11 | """
12 | 
13 | __all__ = ['NiftiDataset',
14 |            'MultimodalNiftiDataset',
15 |            'MultimodalNifti2p5DDataset',
16 |            'MultimodalImageDataset',
17 |            'train_val_split']
18 | 
19 | from typing import Callable, List, Optional
20 | 
21 | import nibabel as nib
22 | import numpy as np
23 | from PIL import Image
24 | from torch.utils.data.dataset import Dataset
25 | 
26 | from niftidataset.errors import NiftiDatasetError
27 | from niftidataset.utils import glob_imgs
28 | 
29 | 
30 | class NiftiDataset(Dataset):
31 |     """
32 |     create a dataset class in PyTorch for reading NIfTI files
33 | 
34 |     Args:
35 |         source_fns (List[str]): list of paths to source images
36 |         target_fns (List[str]): list of paths to target images
37 |         transform (Callable): transform to apply to both source and target images
38 |         preload (bool): load all data when initializing the dataset
39 |     """
40 | 
41 |     def __init__(self, source_fns: List[str], target_fns: List[str], transform: Optional[Callable] = None,
42 |                  preload: bool = False):
43 |         self.source_fns, self.target_fns = source_fns, target_fns
44 |         self.transform = transform
45 |         self.preload = preload
46 |         if len(self.source_fns) != len(self.target_fns) or len(self.source_fns) == 0:
47 |             raise ValueError('Number of source and target images must be equal and non-zero')
48 |         if preload:
49 |             self.imgs = []
50 |             for s, t in zip(self.source_fns, self.target_fns):
51 |                 sdata = nib.load(s).get_fdata(dtype=np.float32)
52 |                 tdata = nib.load(t).get_fdata(dtype=np.float32)
53 |                 self.imgs.append((sdata, tdata))
54 |                 if sdata.size == 0 and tdata.size == 0:
55 |                     raise NiftiDatasetError(f'Both {s} and {t} are empty. Data needs to be non-empty.')
56 |                 elif sdata.size == 0:
57 |                     raise NiftiDatasetError(f'Source {s} is empty. Data needs to be non-empty.')
58 |                 elif tdata.size == 0:
59 |                     raise NiftiDatasetError(f'Target {t} is empty. 
Data needs to be non-empty.') 60 | 61 | @classmethod 62 | def setup_from_dir(cls, source_dir: str, target_dir: str, transform: Optional[Callable] = None, 63 | preload: bool = False): 64 | source_fns, target_fns = glob_imgs(source_dir), glob_imgs(target_dir) 65 | return cls(source_fns, target_fns, transform, preload) 66 | 67 | def __len__(self): 68 | return len(self.source_fns) 69 | 70 | def __getitem__(self, idx: int): 71 | if not self.preload: 72 | src_fn, tgt_fn = self.source_fns[idx], self.target_fns[idx] 73 | sample = (nib.load(src_fn).get_fdata(dtype=np.float32), nib.load(tgt_fn).get_fdata(dtype=np.float32)) 74 | else: 75 | sample = self.imgs[idx] 76 | if self.transform is not None: 77 | sample = self.transform(sample) 78 | return sample 79 | 80 | 81 | class MultimodalDataset(Dataset): 82 | """ base class for Multimodal*Dataset """ 83 | 84 | def __init__(self, source_fns: List[List[str]], target_fns: List[List[str]], 85 | transform: Optional[Callable] = None, 86 | segmentation: bool = False, preload: bool = False, **kwargs): 87 | self.source_fns, self.target_fns = source_fns, target_fns 88 | self.transform = transform 89 | self.segmentation = segmentation 90 | self.preload = preload 91 | if any([len(self.source_fns[0]) != len(sfn) for sfn in self.source_fns]) or \ 92 | any([len(self.target_fns[0]) != len(tfn) for tfn in self.target_fns]) or \ 93 | len(self.source_fns[0]) != len(self.target_fns[0]) or \ 94 | len(self.source_fns[0]) == 0: 95 | raise ValueError(f'Number of source and target images must be equal and non-zero') 96 | if preload: 97 | self.imgs = [] 98 | for idx in range(len(self.source_fns[0])): 99 | src_fns = [sfns[idx] for sfns in self.source_fns] 100 | tgt_fns = [tfns[idx] for tfns in self.target_fns] 101 | src, tgt = [], [] 102 | for s in src_fns: 103 | sdata = self.get_data(s) 104 | if sdata.size == 0: 105 | raise NiftiDatasetError(f'Source {s} is empty. Data needs to be non-empty.') 106 | src.append(sdata) 107 | for t in tgt_fns: 108 | tdata = self.get_data(t) 109 | if tdata.size == 0: 110 | raise NiftiDatasetError(f'Target {t} is empty. 
Data needs to be non-empty.') 111 | tgt.append(tdata) 112 | self.imgs.append((self.stack(src), self.stack(tgt))) 113 | 114 | @classmethod 115 | def setup_from_dir(cls, source_dirs: List[str], target_dirs: List[str], 116 | transform: Optional[Callable] = None, segmentation: bool = False, 117 | preload: bool = False, ext: str = '*.nii*', **kwargs): 118 | source_fns = [glob_imgs(sd, ext) for sd in source_dirs] 119 | target_fns = [glob_imgs(td, ext) for td in target_dirs] 120 | return cls(source_fns, target_fns, transform, segmentation, preload, **kwargs) 121 | 122 | def __len__(self): 123 | return len(self.source_fns[0]) 124 | 125 | def __getitem__(self, idx: int): 126 | if not self.preload: 127 | src_fns, tgt_fns = [sfns[idx] for sfns in self.source_fns], [tfns[idx] for tfns in self.target_fns] 128 | sample = (self.stack([self.get_data(s) for s in src_fns]), 129 | self.stack([self.get_data(t) for t in tgt_fns])) 130 | else: 131 | sample = self.imgs[idx] 132 | if self.transform is not None: 133 | sample = self.transform(sample) 134 | if self.segmentation: 135 | sample = (sample[0], sample[1].squeeze().long()) # for segmentation, loss expects no channel dim 136 | return sample 137 | 138 | def get_data(self, fn): 139 | raise NotImplementedError 140 | 141 | def stack(self, imgs): 142 | raise NotImplementedError 143 | 144 | 145 | class MultimodalNiftiDataset(MultimodalDataset): 146 | """ 147 | create a dataset class in PyTorch for reading N types of NIfTI files to M types of output NIfTI files 148 | 149 | ** note that all images must have the same dimensions! ** 150 | 151 | Args: 152 | source_dirs (List[str]): paths to source images 153 | target_dirs (List[str]): paths to target images 154 | transform (Callable): transform to apply to both source and target images 155 | """ 156 | 157 | def get_data(self, fn): return nib.load(fn).get_fdata(dtype=np.float32) 158 | 159 | def stack(self, imgs): return np.stack(imgs) 160 | 161 | 162 | class MultimodalNifti2p5DDataset(MultimodalNiftiDataset): 163 | """ 164 | create a dataset class in PyTorch for reading N types of NIfTI files to M types of output NIfTI files 165 | 2.5D dataset, so return images stacked in the channel dimension for processing with a 2D CNN 166 | 167 | ** note that all images must have the same dimensions! ** 168 | 169 | Args: 170 | source_dirs (List[str]): paths to source images 171 | target_dirs (List[str]): paths to target images 172 | transform (Callable): transform to apply to both source and target images 173 | """ 174 | 175 | def __init__(self, source_dirs: List[str], target_dirs: List[str], transform: Optional[Callable] = None, 176 | segmentation: bool = False, preload: bool = False, axis: int = 0): 177 | self.axis = axis 178 | super().__init__(source_dirs, target_dirs, transform, segmentation, preload) 179 | 180 | def stack(self, imgs): return np.swapaxes(np.concatenate(imgs, axis=self.axis), 0, self.axis) 181 | 182 | 183 | class MultimodalImageDataset(MultimodalDataset): 184 | """ 185 | create a dataset class in PyTorch for reading N types of (no channel) image files to M types of output image files. 186 | can use whatever PIL can open. 187 | 188 | ** note that all images must have the same dimensions! 
**
189 | 
190 |     There is no implementation of a plain ImageDataset because the standard PyTorch image
191 |     dataloaders are sufficient for that use case.
192 | 
193 |     Args:
194 |         source_dirs (List[str]): paths to source images
195 |         target_dirs (List[str]): paths to target images
196 |         transform (Callable): transform to apply to both source and target images
197 |         color (bool): images are color, i.e., 3 channels
198 |     """
199 | 
200 |     def __init__(self, source_dirs: List[str], target_dirs: List[str], transform: Optional[Callable] = None,
201 |                  segmentation: bool = False, color: bool = False, preload: bool = False):
202 |         self.color = color
203 |         super().__init__(source_dirs, target_dirs, transform, segmentation, preload)
204 | 
205 |     @classmethod
206 |     def setup_from_dir(cls, source_dirs: List[str], target_dirs: List[str],
207 |                        transform: Optional[Callable] = None,
208 |                        segmentation: bool = False,
209 |                        color: bool = False, preload: bool = False,
210 |                        ext: str = '*.tif*', **kwargs):
211 |         source_fns = [glob_imgs(sd, ext) for sd in source_dirs]
212 |         target_fns = [glob_imgs(td, ext) for td in target_dirs]
213 |         return cls(source_fns, target_fns, transform, segmentation, color, preload)
214 | 
215 |     def get_data(self, fn):
216 |         data = np.asarray(Image.open(fn), dtype=np.float32)
217 |         if self.color: data = data.transpose((2, 0, 1))
218 |         return data
219 | 
220 |     def stack(self, imgs):
221 |         data = np.stack(imgs)
222 |         if self.color: data = data.squeeze()
223 |         return data
224 | 
225 | 
226 | def train_val_split(source_dir: str, target_dir: str, valid_pct: float = 0.2,
227 |                     transform: Optional[Callable] = None, preload: bool = False):
228 |     """
229 |     create two separate NiftiDatasets in PyTorch for working with NIfTI files. If one directory contains the
230 |     source files and another contains the target files, and you do not have a separate validation directory,
231 |     this function randomly splits the data into two NiftiDatasets with the given percentage.
232 | 
233 |     Args:
234 |         source_dir (str): path to source images.
235 |         target_dir (str): path to target images.
236 |         valid_pct (float): percent of validation set from data.
237 |         transform (Callable): transform to apply to both source and target images.
238 |         preload (bool): load all data when initializing the dataset
239 |     Returns:
240 |         Tuple: (train_dataset, validation_dataset).
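    Example:
        a minimal sketch of the intended usage (the directory paths are placeholders)::

            tr, val = train_val_split('path/to/source', 'path/to/target', valid_pct=0.2)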
241 | """ 242 | if not (0 < valid_pct < 1): 243 | raise ValueError(f'valid_pct must be between 0 and 1') 244 | source_fns, target_fns = glob_imgs(source_dir), glob_imgs(target_dir) 245 | rand_idx = np.random.permutation(list(range(len(source_fns)))) 246 | cut = int(valid_pct * len(source_fns)) 247 | return (NiftiDataset(source_fns=[source_fns[i] for i in rand_idx[cut:]], 248 | target_fns=[target_fns[i] for i in rand_idx[cut:]], 249 | transform=transform, preload=preload), 250 | NiftiDataset(source_fns=[source_fns[i] for i in rand_idx[:cut]], 251 | target_fns=[target_fns[i] for i in rand_idx[:cut]], 252 | transform=transform, preload=preload)) 253 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | tests.test_dataset 5 | 6 | test submodules for runtime errors 7 | 8 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu) 9 | 10 | Created on: Oct 24, 2018 11 | """ 12 | 13 | import os 14 | import shutil 15 | import tempfile 16 | import unittest 17 | 18 | import torchvision.transforms as torch_tfms 19 | from typing import Tuple 20 | import torch 21 | from niftidataset import * 22 | 23 | 24 | class TestUtilities(unittest.TestCase): 25 | 26 | def setUp(self): 27 | wd = os.path.dirname(os.path.abspath(__file__)) 28 | self.nii_dir = os.path.join(wd, 'test_data', 'nii') 29 | self.tif_dir = os.path.join(wd, 'test_data', 'tif') 30 | self.out_dir = tempfile.mkdtemp() 31 | self.train_dir = os.path.join(self.out_dir, 'train') 32 | os.mkdir(self.train_dir) 33 | os.mkdir(os.path.join(self.train_dir, '1')) 34 | os.mkdir(os.path.join(self.train_dir, '2')) 35 | nii = glob_imgs(self.nii_dir)[0] 36 | tif = os.path.join(self.tif_dir, 'test.tif') 37 | path, base, ext = split_filename(nii) 38 | for i in range(4): 39 | shutil.copy(nii, os.path.join(self.train_dir, base + str(i) + ext)) 40 | shutil.copy(tif, os.path.join(self.train_dir, '1', base + str(i) + '.tif')) 41 | shutil.copy(tif, os.path.join(self.train_dir, '2', base + str(i) + '.tif')) 42 | 43 | def test_niftidataset_2d(self): 44 | composed = torch_tfms.Compose([RandomCrop2D(10, 0), 45 | ToTensor(), 46 | FixIntensityRange()]) 47 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed) 48 | self.assertEqual(myds[0][0].shape, (1,10,10)) 49 | 50 | def test_niftidataset_2d_slice(self): 51 | composed = torch_tfms.Compose([RandomSlice(0), 52 | ToTensor()]) 53 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed) 54 | self.assertEqual(myds[0][0].shape, (1,64,64)) 55 | 56 | def test_niftidataset_3d(self): 57 | composed = torch_tfms.Compose([RandomCrop3D(10), 58 | ToTensor()]) 59 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed) 60 | self.assertEqual(myds[0][0].shape, (1,10,10,10)) 61 | 62 | def test_niftidataset_preload(self): 63 | composed = torch_tfms.Compose([RandomCrop3D(10), 64 | ToTensor(), 65 | AddChannel()]) 66 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed, preload=True) 67 | self.assertEqual(myds[0][0].shape, (1,1,10,10,10)) 68 | 69 | def test_multimodalnifti_2d(self): 70 | composed = torch_tfms.Compose([RandomCrop2D(10, 0), 71 | ToTensor(), 72 | FixIntensityRange()]) 73 | sd, td = [self.train_dir] * 3, [self.train_dir] 74 | myds = MultimodalNiftiDataset.setup_from_dir(sd, td, composed) 75 | self.assertEqual(myds[0][0].shape, (3,10,10)) 76 | 
self.assertEqual(myds[0][1].shape, (1,10,10)) 77 | 78 | def test_multimodalnifti_slice(self): 79 | composed = torch_tfms.Compose([RandomSlice(0), 80 | ToTensor()]) 81 | sd, td = [self.train_dir] * 2, [self.train_dir] * 4 82 | myds = MultimodalNiftiDataset.setup_from_dir(sd, td, composed) 83 | self.assertEqual(myds[0][0].shape, (2,64,64)) 84 | self.assertEqual(myds[0][1].shape, (4,64,64)) 85 | 86 | def test_multimodalnifti_3d(self): 87 | composed = torch_tfms.Compose([RandomCrop3D(10), 88 | ToTensor()]) 89 | sd, td = [self.train_dir] * 3, [self.train_dir] * 2 90 | myds = MultimodalNiftiDataset.setup_from_dir(sd, td, composed) 91 | self.assertEqual(myds[0][0].shape, (3,10,10,10)) 92 | self.assertEqual(myds[0][1].shape, (2,10,10,10)) 93 | 94 | def test_multimodalnifti_2p5D(self): 95 | composed = torch_tfms.Compose([ToTensor()]) 96 | sd, td = [self.train_dir] * 3, [self.train_dir] * 2 97 | myds = MultimodalNifti2p5DDataset.setup_from_dir(sd, td, composed, axis=0) 98 | self.assertEqual(myds[0][0].shape, (3*51,64,64)) 99 | self.assertEqual(myds[0][1].shape, (2*51,64,64)) 100 | 101 | def test_multimodalnifti_preload(self): 102 | composed = torch_tfms.Compose([RandomCrop3D(10), 103 | ToTensor()]) 104 | sd, td = [self.train_dir] * 3, [self.train_dir] * 2 105 | myds = MultimodalNiftiDataset.setup_from_dir(sd, td, composed, preload=True) 106 | self.assertEqual(myds[0][0].shape, (3,10,10,10)) 107 | self.assertEqual(myds[0][1].shape, (2,10,10,10)) 108 | 109 | def test_multimodaltiff(self): 110 | composed = torch_tfms.Compose([ToTensor()]) 111 | sd, td = [self.train_dir+'/1/'] * 3, [self.train_dir+'/2/'] * 2 112 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 113 | self.assertEqual(myds[0][0].shape, (3,256,256)) 114 | self.assertEqual(myds[0][1].shape, (2,256,256)) 115 | 116 | def test_multimodaltiff_preload(self): 117 | composed = torch_tfms.Compose([ToTensor()]) 118 | sd, td = [self.train_dir+'/1/'] * 3, [self.train_dir+'/2/'] * 2 119 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed, preload=True) 120 | self.assertEqual(myds[0][0].shape, (3,256,256)) 121 | self.assertEqual(myds[0][1].shape, (2,256,256)) 122 | 123 | def test_multimodaltiff_seg(self): 124 | composed = torch_tfms.Compose([ToTensor()]) 125 | sd, td = [self.train_dir + '/1/'] * 3, [self.train_dir + '/2/'] 126 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed, segmentation=True) 127 | self.assertEqual(myds[0][0].shape, (3, 256, 256)) 128 | self.assertEqual(myds[0][1].shape, (256, 256)) 129 | 130 | def test_multimodaltiff_crop(self): 131 | composed = torch_tfms.Compose([ToTensor(), RandomCrop(32)]) 132 | sd, td = [self.train_dir+'/1/'] * 3, [self.train_dir+'/2/'] * 2 133 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 134 | self.assertEqual(myds[0][0].shape, (3,32,32)) 135 | self.assertEqual(myds[0][1].shape, (2,32,32)) 136 | 137 | def test_aug_affine_2d(self): 138 | composed = torch_tfms.Compose([ToPILImage(), 139 | RandomAffine(1, 15, 0.1, 0.1), 140 | ToTensor()]) 141 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 142 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 143 | self.assertEqual(myds[0][0].shape, (1,256,256)) 144 | self.assertEqual(myds[0][1].shape, (1,256,256)) 145 | 146 | def test_aug_flip_2d(self): 147 | composed = torch_tfms.Compose([ToPILImage(), 148 | RandomFlip(1, True, True), 149 | ToTensor()]) 150 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 151 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 152 | 
self.assertEqual(myds[0][0].shape, (1,256,256)) 153 | self.assertEqual(myds[0][1].shape, (1,256,256)) 154 | 155 | def test_aug_intensity_2d(self): 156 | composed = torch_tfms.Compose([ToTensor(), 157 | RandomGamma(1, True, 0.1, 0.1)]) 158 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 159 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 160 | self.assertEqual(myds[0][0].shape, (1,256,256)) 161 | self.assertEqual(myds[0][1].shape, (1,256,256)) 162 | 163 | def test_aug_noise_2d(self): 164 | composed = torch_tfms.Compose([ToTensor(), 165 | RandomNoise(1, True, True, 1)]) 166 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 167 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 168 | self.assertEqual(myds[0][0].shape, (1,256,256)) 169 | self.assertEqual(myds[0][1].shape, (1,256,256)) 170 | 171 | def test_aug_digitize_2d(self): 172 | composed = torch_tfms.Compose([Digitize(True, True, (1,100), 1), 173 | ToTensor()]) 174 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 175 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 176 | self.assertEqual(myds[0][0].shape, (1,256,256)) 177 | self.assertEqual(myds[0][1].shape, (1,256,256)) 178 | 179 | def test_aug_block_2d(self): 180 | composed = torch_tfms.Compose([ToTensor(), 181 | RandomBlock(1, (1,100))]) 182 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 183 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 184 | self.assertEqual(myds[0][0].shape, (1,256,256)) 185 | self.assertEqual(myds[0][1].shape, (1,256,256)) 186 | 187 | def test_aug_block_3d(self): 188 | composed = torch_tfms.Compose([RandomCrop3D(10), 189 | ToTensor(), 190 | RandomBlock(1, (1,4), thresh=0, is_3d=True)]) 191 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed) 192 | self.assertEqual(myds[0][0].shape, (1,10,10,10)) 193 | 194 | def test_aug_block_3d_oblong(self): 195 | composed = torch_tfms.Compose([RandomCrop3D(10), 196 | ToTensor(), 197 | RandomBlock(1, ((1,4),(1,2),(1,3)), thresh=0, is_3d=True)]) 198 | myds = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed) 199 | self.assertEqual(myds[0][0].shape, (1,10,10,10)) 200 | 201 | def test_get_transform_2d(self): 202 | composed = torch_tfms.Compose(get_transforms([1,1,1,1,1],True,True,15,0.1,0.1,True,True,0.1,0.1,1,(3,4),None,False,(1,),(1,))) 203 | sd, td = [self.train_dir+'/1/'], [self.train_dir+'/2/'] 204 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed) 205 | self.assertEqual(myds[0][0].shape, (1,256,256)) 206 | self.assertEqual(myds[0][1].shape, (1,256,256)) 207 | 208 | def test_get_transform_2d_multi(self): 209 | composed = torch_tfms.Compose(get_transforms([1,1,1,1,1],True,True,15,0.1,0.1,True,True,0.1,0.1,1,(3,4),None,False,(1,),(1,))) 210 | sd, td = [self.train_dir+'/1/'] * 2, [self.train_dir+'/2/'] 211 | myds = MultimodalImageDataset.setup_from_dir(sd, td, composed, segmentation=True) 212 | self.assertEqual(myds[0][0].shape, (2,256,256)) 213 | self.assertEqual(myds[0][1].shape, (256,256)) 214 | 215 | def test_get_transform_3d(self): 216 | composed = torch_tfms.Compose(get_transforms([0,0,1,1,1],True,True,0,0,0,False,False,0.1,0.1,1,(3,4),None,True,(1,),(1,))) 217 | sd, td = [self.train_dir], [self.train_dir] 218 | myds = MultimodalNiftiDataset.setup_from_dir(sd, td, composed) 219 | self.assertEqual(myds[0][0].shape, (1, 51, 64, 64)) 220 | self.assertEqual(myds[0][1].shape, (1, 51, 64, 64)) 221 | 222 | def test_train_val_split(self): 223 | composed = 
torch_tfms.Compose([ToTensor()]) 224 | tr, val = train_val_split(self.train_dir, self.train_dir, 225 | valid_pct=0.25, transform=composed) 226 | self.assertEqual(len(tr), 3) 227 | self.assertEqual(len(val), 1) 228 | self.assertEqual(torch.all(torch.eq(val[0][0], tr[0][0])), torch.tensor(True)) 229 | 230 | def test_trim_intensity(self): 231 | import numpy as np 232 | composed = torch_tfms.Compose([ToTensor()]) 233 | src, tgt = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed)[0] 234 | maxim = np.max(src.numpy()) 235 | minim = np.min(src.numpy()) 236 | composed2 = torch_tfms.Compose([ToTensor(), 237 | TrimIntensity(max_val=maxim-1000, min_val=minim-1000)]) 238 | src, tgt = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed2)[0] 239 | self.assertEqual(np.max(src.numpy()), 1.) 240 | self.assertTrue(np.min(src.numpy()) > -1.) 241 | 242 | def test_normalize_without_std_and_mean(self): 243 | composed = torch_tfms.Compose([ToTensor(), Normalize(is_3d=False)]) 244 | src, tgt = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed)[0] 245 | mean = src.numpy().mean() 246 | std = src.numpy().std() 247 | self.assertTrue(-1 < mean < 1) 248 | self.assertTrue(0 < std < 2) 249 | composed2 = torch_tfms.Compose([ToTensor(), AddFakeChannel(), 250 | Normalize(is_3d=True)]) 251 | src2, tgt2 = NiftiDataset.setup_from_dir(self.train_dir, self.train_dir, composed2)[0] 252 | mean2 = src2.numpy().mean(axis=(1, 2, 3)) 253 | std2 = src2.numpy().std() 254 | self.assertIsNotNone(mean2) 255 | self.assertIsNotNone(std2) 256 | 257 | def tearDown(self): 258 | shutil.rmtree(self.out_dir) 259 | 260 | 261 | class AddFakeChannel: 262 | def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 263 | src, tgt = sample 264 | src = torch.Tensor([src.numpy(), src.numpy()]) 265 | tgt = torch.Tensor([tgt.numpy(), tgt.numpy()]) 266 | return src, tgt 267 | 268 | 269 | if __name__ == '__main__': 270 | unittest.main() 271 | -------------------------------------------------------------------------------- /niftidataset/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | niftidataset.transforms 5 | 6 | transformations to apply to images in dataset 7 | 8 | Author: Jacob Reinhold (jacob.reinhold@jhu.edu) 9 | 10 | Created on: Oct 24, 2018 11 | """ 12 | 13 | __all__ = ['RandomCrop2D', 14 | 'RandomCrop3D', 15 | 'RandomCrop', 16 | 'RandomSlice', 17 | 'ToTensor', 18 | 'ToPILImage', 19 | 'AddChannel', 20 | 'FixIntensityRange', 21 | 'Normalize', 22 | 'Digitize', 23 | 'MedianFilter', 24 | 'RandomAffine', 25 | 'RandomBlock', 26 | 'RandomFlip', 27 | 'RandomGamma', 28 | 'RandomNoise', 29 | 'TrimIntensity', 30 | 'get_transforms'] 31 | 32 | import random 33 | from typing import Optional, Tuple, Union 34 | 35 | import numpy as np 36 | from PIL import Image 37 | import torch 38 | import torchvision as tv 39 | import torchvision.transforms.functional as TF 40 | 41 | from niftidataset.errors import NiftiDatasetError 42 | 43 | PILImage = type(Image) 44 | 45 | 46 | class BaseTransform: 47 | def __repr__(self): return f'{self.__class__.__name__}' 48 | 49 | 50 | class CropBase(BaseTransform): 51 | """ base class for crop transform """ 52 | 53 | def __init__(self, out_dim: int, output_size: Union[tuple, int, list], threshold: Optional[float] = None, 54 | pct: Tuple[float, float] = (0., 1.), axis=0): 55 | """ provide the common functionality for RandomCrop2D and 
RandomCrop3D """ 56 | assert isinstance(output_size, (int, tuple, list)) 57 | if isinstance(output_size, int): 58 | self.output_size = (output_size,) 59 | for _ in range(out_dim - 1): 60 | self.output_size += (output_size,) 61 | else: 62 | assert len(output_size) == out_dim 63 | self.output_size = output_size 64 | self.out_dim = out_dim 65 | self.thresh = threshold 66 | self.pct = pct 67 | self.axis = axis 68 | 69 | def _get_sample_idxs(self, img: np.ndarray) -> Tuple[int, int, int]: 70 | """ get the set of indices from which to sample (foreground) """ 71 | mask = np.where(img >= (img.mean() if self.thresh is None else self.thresh)) # returns a tuple of length 3 72 | c = np.random.randint(0, len(mask[0])) # choose the set of idxs to use 73 | h, w, d = [m[c] for m in mask] # pull out the chosen idxs 74 | return h, w, d 75 | 76 | def _offset_by_pct(self, h, w, d): 77 | s = (h, w, d) 78 | hml = wml = dml = 0 79 | hmh = wmh = dmh = 0 80 | i0, i1 = int(s[self.axis] * self.pct[0]), int(s[self.axis] * (1. - self.pct[1])) 81 | if self.axis == 0: 82 | hml += i0 83 | hmh += i1 84 | elif self.axis == 1: 85 | wml += i0 86 | wmh += i1 87 | else: 88 | dml += i0 89 | dmh += i1 90 | return (hml, wml, dml), (hmh, wmh, dmh) 91 | 92 | def __repr__(self): 93 | s = '{name}(output_size={output_size}, threshold={thresh})' 94 | d = dict(self.__dict__) 95 | return s.format(name=self.__class__.__name__, **d) 96 | 97 | 98 | class RandomCrop2D(CropBase): 99 | """ 100 | Randomly crop a 2d slice/patch from a 3d image 101 | 102 | Args: 103 | output_size (tuple or int): Desired output size. 104 | If int, cube crop is made. 105 | axis (int or None): along which axis should the patch/slice be extracted 106 | provide None for random axis 107 | include_neighbors (bool): extract 3 neighboring slices instead of just 1 108 | """ 109 | 110 | def __init__(self, output_size: Union[int, tuple, list], axis: Optional[int] = 0, 111 | include_neighbors: bool = False, threshold: Optional[float] = None) -> None: 112 | if axis is not None: 113 | assert 0 <= axis <= 2 114 | super().__init__(2, output_size, threshold) 115 | self.axis = axis 116 | self.include_neighbors = include_neighbors 117 | 118 | def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 119 | axis = self.axis if self.axis is not None else np.random.randint(0, 3) 120 | src, tgt = sample 121 | *cs, h, w, d = src.shape 122 | *ct, _, _, _ = src.shape 123 | new_h, new_w = self.output_size 124 | max_idxs = (np.inf, w - new_h // 2, d - new_w // 2) if axis == 0 else \ 125 | (h - new_h // 2, np.inf, d - new_w // 2) if axis == 1 else \ 126 | (h - new_h // 2, w - new_w // 2, np.inf) 127 | min_idxs = (-np.inf, new_h // 2, new_w // 2) if axis == 0 else \ 128 | (new_h // 2, -np.inf, new_w // 2) if axis == 1 else \ 129 | (new_h // 2, new_w // 2, -np.inf) 130 | s = src[0] if len(cs) > 0 else src # use the first image to determine sampling if multimodal 131 | s_idxs = super()._get_sample_idxs(s) 132 | idxs = [i if min_i <= i <= max_i else max_i if i > max_i else min_i 133 | for max_i, min_i, i in zip(max_idxs, min_idxs, s_idxs)] 134 | s = self._get_slice(src, idxs, axis).squeeze() 135 | t = self._get_slice(tgt, idxs, axis).squeeze() 136 | if len(cs) == 0 or s.ndim == 2: s = s[np.newaxis, ...] # add channel axis if empty 137 | if len(ct) == 0 or t.ndim == 2: t = t[np.newaxis, ...] 
138 | return s, t 139 | 140 | def _get_slice(self, img: np.ndarray, idxs: Tuple[int, int, int], axis: int) -> np.ndarray: 141 | h, w = self.output_size 142 | n = 1 if self.include_neighbors else 0 143 | oh = 0 if h % 2 == 0 else 1 144 | ow = 0 if w % 2 == 0 else 1 145 | i, j, k = idxs 146 | s = img[..., i - n:i + 1 + n, j - h // 2:j + h // 2 + oh, k - w // 2:k + w // 2 + ow] if axis == 0 else \ 147 | img[..., i - h // 2:i + h // 2 + oh, j - n:j + 1 + n, k - w // 2:k + w // 2 + ow] if axis == 1 else \ 148 | img[..., i - h // 2:i + h // 2 + oh, j - w // 2:j + w // 2 + ow, k - n:k + 1 + n] 149 | if self.include_neighbors: 150 | s = np.transpose(s, (0, 1, 2)) if axis == 0 else \ 151 | np.transpose(s, (1, 0, 2)) if axis == 1 else \ 152 | np.transpose(s, (2, 0, 1)) 153 | return s 154 | 155 | 156 | class RandomCrop3D(CropBase): 157 | """ 158 | Randomly crop a 3d patch from a (pair of) 3d image 159 | 160 | Args: 161 | output_size (tuple or int): Desired output size. 162 | If int, cube crop is made. 163 | """ 164 | 165 | def __init__(self, output_size: Union[tuple, int, list], threshold: Optional[float] = None, 166 | pct: Tuple[float, float] = (0., 1.), axis=0): 167 | super().__init__(3, output_size, threshold, pct, axis) 168 | 169 | def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 170 | src, tgt = sample 171 | *cs, h, w, d = src.shape 172 | *ct, _, _, _ = tgt.shape 173 | hh, ww, dd = self.output_size 174 | (hml, wml, dml), (hmh, wmh, dmh) = self._offset_by_pct(h, w, d) 175 | max_idxs = (h - hmh - hh // 2, w - wmh - ww // 2, d - dmh - dd // 2) 176 | min_idxs = (hml + hh // 2, wml + ww // 2, dml + dd // 2) 177 | s = src[0] if len(cs) > 0 else src # use the first image to determine sampling if multimodal 178 | s_idxs = self._get_sample_idxs(s) 179 | i, j, k = [i if min_i <= i <= max_i else max_i if i > max_i else min_i 180 | for max_i, min_i, i in zip(max_idxs, min_idxs, s_idxs)] 181 | oh = 0 if hh % 2 == 0 else 1 182 | ow = 0 if ww % 2 == 0 else 1 183 | od = 0 if dd % 2 == 0 else 1 184 | s = src[..., i - hh // 2:i + hh // 2 + oh, j - ww // 2:j + ww // 2 + ow, k - dd // 2:k + dd // 2 + od] 185 | t = tgt[..., i - hh // 2:i + hh // 2 + oh, j - ww // 2:j + ww // 2 + ow, k - dd // 2:k + dd // 2 + od] 186 | if len(cs) == 0: s = s[np.newaxis, ...] # add channel axis if empty 187 | if len(ct) == 0: t = t[np.newaxis, ...] 188 | return s, t 189 | 190 | 191 | class RandomCrop: 192 | """ 193 | Randomly crop a 2d patch from a 2d image 194 | 195 | Args: 196 | output_size (tuple or int): Desired output size. 197 | If int, square crop is made. 
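    Example:
        a minimal sketch (``src`` and ``tgt`` are assumed to be 2d numpy arrays)::

            crop = RandomCrop(32)
            src_patch, tgt_patch = crop((src, tgt))  # each patch has shape (1, 32, 32)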
198 | """ 199 | 200 | def __init__(self, output_size: Union[tuple, int], threshold: Optional[float] = None): 201 | self.output_size = (output_size, output_size) if isinstance(output_size, int) else output_size 202 | self.thresh = threshold 203 | 204 | def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 205 | src, tgt = sample 206 | *cs, h, w = src.shape 207 | *ct, _, _ = tgt.shape 208 | hh, ww = self.output_size 209 | max_idxs = (h - hh // 2, w - ww // 2) 210 | min_idxs = (hh // 2, ww // 2) 211 | s = src[0] if len(cs) > 0 else src # use the first image to determine sampling if multimodal 212 | mask = np.where(s >= (s.mean() if self.thresh is None else self.thresh)) 213 | c = np.random.randint(0, len(mask[0])) # choose the set of idxs to use 214 | s_idxs = [m[c] for m in mask] # pull out the chosen idxs 215 | i, j = [i if min_i <= i <= max_i else max_i if i > max_i else min_i 216 | for max_i, min_i, i in zip(max_idxs, min_idxs, s_idxs)] 217 | oh = 0 if hh % 2 == 0 else 1 218 | ow = 0 if ww % 2 == 0 else 1 219 | s = src[..., i - hh // 2:i + hh // 2 + oh, j - ww // 2:j + ww // 2 + ow] 220 | t = tgt[..., i - hh // 2:i + hh // 2 + oh, j - ww // 2:j + ww // 2 + ow] 221 | if len(cs) == 0: s = s[np.newaxis, ...] # add channel axis if empty 222 | if len(ct) == 0: t = t[np.newaxis, ...] 223 | return s, t 224 | 225 | def __repr__(self): 226 | s = '{name}(output_size={output_size}, threshold={thresh})' 227 | d = dict(self.__dict__) 228 | return s.format(name=self.__class__.__name__, **d) 229 | 230 | 231 | class RandomSlice(BaseTransform): 232 | """ 233 | take a random 2d slice from an image given a sample axis 234 | 235 | Args: 236 | axis (int): axis on which to take a slice 237 | div (float): divide the mean by this value in the calculation of mask 238 | the higher this value, the more background will be "valid" 239 | """ 240 | 241 | def __init__(self, axis: int = 0, div: float = 2): 242 | assert 0 <= axis <= 2 243 | self.axis = axis 244 | self.div = div 245 | 246 | def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 247 | src, tgt = sample 248 | *cs, _, _, _ = src.shape 249 | *ct, _, _, _ = tgt.shape 250 | s = src[0] if len(cs) > 0 else src # use the first image to determine sampling if multimodal 251 | idx = np.random.choice(self._valid_idxs(s)[self.axis]) 252 | s = self._get_slice(src, idx) 253 | t = self._get_slice(tgt, idx) 254 | if len(cs) == 0: s = s[np.newaxis, ...] # add channel axis if empty 255 | if len(ct) == 0: t = t[np.newaxis, ...] 
256 |         return s, t
257 | 
258 |     def _get_slice(self, img: np.ndarray, idx: int):
259 |         s = img[..., idx, :, :] if self.axis == 0 else \
260 |             img[..., :, idx, :] if self.axis == 1 else \
261 |             img[..., :, :, idx]
262 |         return s
263 | 
264 |     def _valid_idxs(self, img: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
265 |         """ get the set of indices from which to sample (foreground) """
266 |         mask = np.where(img > img.mean() / self.div)  # returns a tuple of length 3
267 |         h, w, d = [np.arange(np.min(m), np.max(m) + 1) for m in mask]  # pull out the valid idx ranges
268 |         return h, w, d
269 | 
270 | 
271 | class ToTensor(BaseTransform):
272 |     """ Convert images in sample to Tensors """
273 | 
274 |     def __init__(self, color=False):
275 |         self.color = color
276 | 
277 |     def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[torch.Tensor, torch.Tensor]:
278 |         src, tgt = sample
279 |         if isinstance(src, np.ndarray) and isinstance(tgt, np.ndarray):
280 |             return torch.from_numpy(src), torch.from_numpy(tgt)
281 |         if isinstance(src, list): src = np.stack(src)
282 |         if isinstance(tgt, list): tgt = np.stack(tgt)
283 |         # handle PIL images
284 |         src, tgt = np.asarray(src), np.asarray(tgt)
285 |         if src.ndim == 3 and self.color: src = src.transpose((2, 0, 1)).astype(np.float32)
286 |         if tgt.ndim == 3 and self.color: tgt = tgt.transpose((2, 0, 1)).astype(np.float32)
287 |         if src.ndim == 2: src = src[None, ...]  # add channel dimension
288 |         if tgt.ndim == 2: tgt = tgt[None, ...]
289 |         return torch.from_numpy(src.copy()), torch.from_numpy(tgt.copy())
290 | 
291 | 
292 | class ToPILImage(BaseTransform):
293 |     """ convert 2D image to PIL image """
294 | 
295 |     def __init__(self, color=False):
296 |         self.color = color
297 | 
298 |     def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
299 |         src, tgt = sample
300 |         src, tgt = np.squeeze(src), np.squeeze(tgt)
301 |         if src.ndim == 3 and self.color:
302 |             src = Image.fromarray(src.transpose((1, 2, 0)).astype(np.uint8))
303 |         elif src.ndim == 2:
304 |             src = Image.fromarray(src)
305 |         else:
306 |             src = [Image.fromarray(s) for s in src]
307 |         if tgt.ndim == 3 and self.color:
308 |             tgt = Image.fromarray(tgt.transpose((1, 2, 0)).astype(np.uint8))
309 |         elif tgt.ndim == 2:
310 |             tgt = Image.fromarray(tgt)
311 |         else:
312 |             tgt = [Image.fromarray(t) for t in tgt]
313 |         return src, tgt
314 | 
315 | 
316 | class RandomAffine(tv.transforms.RandomAffine):
317 |     """ apply random affine transformations to a sample of images """
318 | 
319 |     def __init__(self, p: float, degrees: float, translate: float = 0, scale: float = 0,
320 |                  interpolation: TF.InterpolationMode = TF.InterpolationMode.BILINEAR,
321 |                  segmentation=False):
322 |         self.p = p
323 |         self.degrees, self.translate, self.scale = (-degrees, degrees), (translate, translate), (1 - scale, 1 + scale)
324 |         self.shear, self.fill = None, 0
325 |         self.resample = self.interpolation = interpolation
326 |         self.segmentation = segmentation
327 | 
328 |     def affine(self, x, params, interpolation=TF.InterpolationMode.BILINEAR):
329 |         return TF.affine(x, *params, interpolation=interpolation, fill=0)
330 | 
331 |     def __call__(self, sample: Tuple[PILImage, PILImage]):
332 |         src, tgt = sample
333 |         ret = self.get_params(self.degrees, self.translate, self.scale, None, tgt.size)
334 |         if self.degrees[1] > 0 and random.random() < self.p:
335 |             if not isinstance(src, list):
336 |                 src = self.affine(src, ret, self.resample)
337 |             else:
338 |                 src = [self.affine(s, ret, self.resample) for s in src]
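            # nearest-neighbor resampling for segmentation targets (label maps) avoids interpolating label values
339 |             resample = Image.NEAREST if self.segmentation 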
else self.resample 340 | if not isinstance(tgt, list): 341 | tgt = self.affine(tgt, ret, resample) 342 | else: 343 | tgt = [self.affine(t, ret, resample) for t in tgt] 344 | return src, tgt 345 | 346 | 347 | class RandomFlip: 348 | def __init__(self, p: float, vflip: bool = False, hflip: bool = False): 349 | self.p = p 350 | self.vflip, self.hflip = vflip, hflip 351 | 352 | def __call__(self, sample: Tuple[PILImage, PILImage]): 353 | src, tgt = sample 354 | if self.vflip and random.random() < self.p: 355 | if not isinstance(src, list): 356 | src = TF.vflip(src) 357 | else: 358 | src = [TF.vflip(s) for s in src] 359 | if not isinstance(tgt, list): 360 | tgt = TF.vflip(tgt) 361 | else: 362 | tgt = [TF.vflip(t) for t in tgt] 363 | if self.hflip and random.random() < self.p: 364 | if not isinstance(src, list): 365 | src = TF.hflip(src) 366 | else: 367 | src = [TF.hflip(s) for s in src] 368 | if not isinstance(tgt, list): 369 | tgt = TF.hflip(tgt) 370 | else: 371 | tgt = [TF.hflip(t) for t in tgt] 372 | return src, tgt 373 | 374 | def __repr__(self): 375 | s = '{name}(p={p}, vflip={vflip}, hflip={hflip})' 376 | d = dict(self.__dict__) 377 | return s.format(name=self.__class__.__name__, **d) 378 | 379 | 380 | class RandomGamma: 381 | """ apply random gamma transformations to a sample of images """ 382 | 383 | def __init__(self, p, tfm_y=False, gamma: float = 0., gain: float = 0.): 384 | self.p, self.tfm_y = p, tfm_y 385 | self.gamma, self.gain = (max(1 - gamma, 0), 1 + gamma), (max(1 - gain, 0), 1 + gain) 386 | 387 | @staticmethod 388 | def _make_pos(x): 389 | return x.min(), x - x.min() 390 | 391 | def _gamma(self, x, gain, gamma): 392 | is_pos = torch.all(x >= 0) 393 | if not is_pos: m, x = self._make_pos(x) 394 | x = gain * x ** gamma 395 | if not is_pos: x = x + m 396 | return x 397 | 398 | def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]): 399 | src, tgt = sample 400 | if random.random() < self.p: 401 | gamma = random.uniform(self.gamma[0], self.gamma[1]) 402 | gain = random.uniform(self.gain[0], self.gain[1]) 403 | src = self._gamma(src, gain, gamma) 404 | if self.tfm_y: tgt = self._gamma(tgt, gain, gamma) 405 | return src, tgt 406 | 407 | def __repr__(self): 408 | s = '{name}(p={p}, tfm_y={tfm_y}, gamma={gamma}, gain={gain})' 409 | d = dict(self.__dict__) 410 | return s.format(name=self.__class__.__name__, **d) 411 | 412 | 413 | class RandomNoise: 414 | """ add random gaussian noise to a sample of images """ 415 | 416 | def __init__(self, p, tfm_x=True, tfm_y=False, std: float = 0): 417 | self.p, self.tfm_x, self.tfm_y, self.std = p, tfm_x, tfm_y, std 418 | 419 | def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]): 420 | src, tgt = sample 421 | if self.std > 0 and random.random() < self.p: 422 | if self.tfm_x: src = src + torch.randn_like(src).mul(self.std) 423 | if self.tfm_y: tgt = tgt + torch.randn_like(tgt).mul(self.std) 424 | return src, tgt 425 | 426 | def __repr__(self): 427 | s = '{name}(p={p}, tfm_x={tfm_x}, tfm_y={tfm_y}, std={std})' 428 | d = dict(self.__dict__) 429 | return s.format(name=self.__class__.__name__, **d) 430 | 431 | 432 | class RandomBlock: 433 | """ add random blocks of random intensity to a sample of images """ 434 | 435 | def __init__(self, p, sz_range, thresh=None, int_range=None, tfm_x=True, tfm_y=False, is_3d=False): 436 | self.p, self.int, self.tfm_x, self.tfm_y = p, int_range, tfm_x, tfm_y 437 | self.sz = sz_range if all([isinstance(szr, (tuple, list)) for szr in sz_range]) else \ 438 | (sz_range, sz_range, sz_range) if is_3d else 

class RandomGamma:
    """ apply random gamma transformations to a sample of images """

    def __init__(self, p, tfm_y=False, gamma: float = 0., gain: float = 0.):
        self.p, self.tfm_y = p, tfm_y
        self.gamma, self.gain = (max(1 - gamma, 0), 1 + gamma), (max(1 - gain, 0), 1 + gain)

    @staticmethod
    def _make_pos(x):
        return x.min(), x - x.min()

    def _gamma(self, x, gain, gamma):
        is_pos = torch.all(x >= 0)
        if not is_pos: m, x = self._make_pos(x)  # gamma requires non-negative values
        x = gain * x ** gamma
        if not is_pos: x = x + m  # restore the original offset
        return x

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        if random.random() < self.p:
            gamma = random.uniform(self.gamma[0], self.gamma[1])
            gain = random.uniform(self.gain[0], self.gain[1])
            src = self._gamma(src, gain, gamma)
            if self.tfm_y: tgt = self._gamma(tgt, gain, gamma)
        return src, tgt

    def __repr__(self):
        s = '{name}(p={p}, tfm_y={tfm_y}, gamma={gamma}, gain={gain})'
        d = dict(self.__dict__)
        return s.format(name=self.__class__.__name__, **d)


class RandomNoise:
    """ add random gaussian noise to a sample of images """

    def __init__(self, p, tfm_x=True, tfm_y=False, std: float = 0):
        self.p, self.tfm_x, self.tfm_y, self.std = p, tfm_x, tfm_y, std

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        if self.std > 0 and random.random() < self.p:
            if self.tfm_x: src = src + torch.randn_like(src).mul(self.std)
            if self.tfm_y: tgt = tgt + torch.randn_like(tgt).mul(self.std)
        return src, tgt

    def __repr__(self):
        s = '{name}(p={p}, tfm_x={tfm_x}, tfm_y={tfm_y}, std={std})'
        d = dict(self.__dict__)
        return s.format(name=self.__class__.__name__, **d)


class RandomBlock:
    """ add random blocks of random intensity to a sample of images """

    def __init__(self, p, sz_range, thresh=None, int_range=None, tfm_x=True, tfm_y=False, is_3d=False):
        self.p, self.int, self.tfm_x, self.tfm_y = p, int_range, tfm_x, tfm_y
        self.sz = sz_range if all([isinstance(szr, (tuple, list)) for szr in sz_range]) else \
            (sz_range, sz_range, sz_range) if is_3d else \
            (sz_range, sz_range)
        self.thresh = thresh
        self.is_3d = is_3d

    def block2d(self, src, tgt):
        _, hmax, wmax = src.shape
        mask = np.where(src >= (src.mean() if self.thresh is None else self.thresh))
        c = np.random.randint(0, len(mask[1]))  # choose the set of idxs to use
        h, w = [m[c] for m in mask[1:]]  # pull out the chosen idxs (2D)
        sh, sw = random.randrange(*self.sz[0]), random.randrange(*self.sz[1])
        oh = 0 if sh % 2 == 0 else 1
        ow = 0 if sw % 2 == 0 else 1
        # clamp the block center so the block stays inside the image
        if h + (sh // 2) + oh >= hmax: h = hmax - (sh // 2) - oh
        if w + (sw // 2) + ow >= wmax: w = wmax - (sw // 2) - ow
        if h - (sh // 2) < 0: h = sh // 2
        if w - (sw // 2) < 0: w = sw // 2
        int_range = self.int if self.int is not None else (src.min(), src.max() + 1)
        if isinstance(src, torch.Tensor): src, tgt = src.clone(), tgt.clone()  # avoid mutating the input in place
        if random.random() < self.p:
            if self.tfm_x:
                src[:, h - sh // 2:h + sh // 2 + oh, w - sw // 2:w + sw // 2 + ow] = np.random.uniform(*int_range)
            if self.tfm_y:
                tgt[:, h - sh // 2:h + sh // 2 + oh, w - sw // 2:w + sw // 2 + ow] = np.random.uniform(*int_range)
        return src, tgt

    def block3d(self, src, tgt):
        _, hmax, wmax, dmax = src.shape
        mask = np.where(src >= (src.mean() if self.thresh is None else self.thresh))
        c = np.random.randint(0, len(mask[1]))  # choose the set of idxs to use
        h, w, d = [m[c] for m in mask[1:]]  # pull out the chosen idxs (3D)
        sh, sw, sd = random.randrange(*self.sz[0]), random.randrange(*self.sz[1]), random.randrange(*self.sz[2])
        oh = 0 if sh % 2 == 0 else 1
        ow = 0 if sw % 2 == 0 else 1
        od = 0 if sd % 2 == 0 else 1
        # clamp the block center so the block stays inside the volume
        if h + (sh // 2) + oh >= hmax: h = hmax - (sh // 2) - oh
        if w + (sw // 2) + ow >= wmax: w = wmax - (sw // 2) - ow
        if d + (sd // 2) + od >= dmax: d = dmax - (sd // 2) - od
        if h - (sh // 2) < 0: h = sh // 2
        if w - (sw // 2) < 0: w = sw // 2
        if d - (sd // 2) < 0: d = sd // 2
        int_range = self.int if self.int is not None else (src.min(), src.max() + 1)
        if isinstance(src, torch.Tensor): src, tgt = src.clone(), tgt.clone()
        if random.random() < self.p:
            if self.tfm_x:
                src[:, h - sh // 2:h + sh // 2 + oh, w - sw // 2:w + sw // 2 + ow,
                    d - sd // 2:d + sd // 2 + od] = np.random.uniform(*int_range)
            if self.tfm_y:
                tgt[:, h - sh // 2:h + sh // 2 + oh, w - sw // 2:w + sw // 2 + ow,
                    d - sd // 2:d + sd // 2 + od] = np.random.uniform(*int_range)
        return src, tgt

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        src, tgt = self.block2d(src, tgt) if not self.is_3d else self.block3d(src, tgt)
        return src, tgt

    def __repr__(self):
        s = '{name}(p={p}, sz={sz}, int_range={int}, thresh={thresh}, tfm_x={tfm_x}, tfm_y={tfm_y}, is_3d={is_3d})'
        d = dict(self.__dict__)
        return s.format(name=self.__class__.__name__, **d)
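
# Hedged usage sketch (shapes and ranges illustrative, not part of the API):
# with p=1 RandomBlock always paints one block, whose center is drawn from
# above-threshold voxels and then clamped so the block stays in bounds:
#
#     >>> import torch
#     >>> src = tgt = torch.ones(1, 64, 64)
#     >>> blk = RandomBlock(p=1.0, sz_range=(5, 10), int_range=(2, 3))
#     >>> out, _ = blk((src, tgt))
#     >>> bool((out != 1).any())  # a block of intensity in [2, 3) was written
#     True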

class AddChannel:
    """ add an empty first (channel) dimension to each image in the sample """

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        src, tgt = sample
        return src.unsqueeze(0), tgt.unsqueeze(0)


class FixIntensityRange:
    """ rescale the data in the sample to the range [0, scale] """

    def __init__(self, scale: float = 1):
        self.scale = scale

    def __call__(self, sample: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        x, y = sample
        x = self.scale * ((x - x.min()) / (x.max() - x.min()))
        y = self.scale * ((y - y.min()) / (y.max() - y.min()))
        return x, y


class Digitize:
    """ digitize (bin) the intensities of a sample of images """

    def __init__(self, tfm_x=False, tfm_y=True, int_range=(1, 100), step=1):
        self.tfm_x, self.tfm_y, self.range, self.step = tfm_x, tfm_y, int_range, step

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        if self.tfm_x: src = np.digitize(src, np.arange(self.range[0], self.range[1], self.step))
        if self.tfm_y: tgt = np.digitize(tgt, np.arange(self.range[0], self.range[1], self.step))
        return src, tgt


def normalize3d(tensor, mean, std, inplace=False):
    """
    normalize a 3d tensor

    Args:
        tensor (Tensor): tensor image of size (C, H, W, D) to be normalized
        mean (sequence): sequence of means for each channel
        std (sequence): sequence of standard deviations for each channel
        inplace (bool): normalize the tensor in place (skips the defensive clone)

    Returns:
        Tensor: normalized tensor image
    """
    if not inplace:
        tensor = tensor.clone()

    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
    tensor.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return tensor
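
# Hedged numeric sketch (values illustrative): normalize3d broadcasts one mean
# and one std per channel over a (C, H, W, D) volume, i.e.
# out[c] = (x[c] - mean[c]) / std[c]:
#
#     >>> import torch
#     >>> x = torch.cat([torch.full((1, 2, 2, 2), 4.), torch.full((1, 2, 2, 2), 9.)])
#     >>> normalize3d(x, mean=[4., 9.], std=[2., 3.]).unique()
#     tensor([0.])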

class Normalize:
    """
    Normalize the images in a sample with the given mean and standard deviation.
    If mean or std is None, the statistic is computed from each input tensor,
    and that tensor is normalized with its own statistics.

    Args:
        mean: mean with which to normalize; computed per tensor if None
        std: standard deviation with which to normalize; computed per tensor if None
        tfm_x (bool): transform x (the source image) or not
        tfm_y (bool): transform y (the target image) or not
        is_3d (bool): whether the tensor is 3D (C, H, W, D), in which case normalization is done per channel
    """

    def __init__(self, mean=None, std=None, tfm_x: bool = True, tfm_y: bool = False,
                 is_3d: bool = False):
        self.mean = mean
        self.std = std
        self.tfm_x = tfm_x
        self.tfm_y = tfm_y
        self.is_3d = is_3d

    def _tfm(self, tensor: torch.Tensor):
        if self.is_3d:
            norm = normalize3d
            mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device) \
                if self.mean is not None else tensor.mean(dim=(1, 2, 3))
            std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device) \
                if self.std is not None else tensor.std(dim=(1, 2, 3))
            # prevent division by zero
            std[std == 0.] = 1e-6
        else:
            norm = tv.transforms.functional.normalize
            mean = self.mean if self.mean is not None else tensor.mean().item()
            std = self.std if self.std is not None else tensor.std().item()
            # prevent division by zero
            if std == 0.:
                std = 1e-6
        return norm(tensor, mean, std)

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        if self.tfm_x: src = self._tfm(src)
        if self.tfm_y: tgt = self._tfm(tgt)
        return src, tgt

    def __repr__(self):
        s = '{name}(mean={mean}, std={std}, tfm_x={tfm_x}, tfm_y={tfm_y}, is_3d={is_3d})'
        d = dict(self.__dict__)
        return s.format(name=self.__class__.__name__, **d)


class MedianFilter:
    """ median filter (kernel size 3) the images in the sample """

    def __init__(self, tfm_x=True, tfm_y=False):
        try:
            from scipy.ndimage import median_filter
        except (ModuleNotFoundError, ImportError):
            raise NiftiDatasetError('scipy not installed, cannot use median filter')
        self.filter = median_filter
        self.tfm_x = tfm_x
        self.tfm_y = tfm_y

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]):
        src, tgt = sample
        if self.tfm_x: src = self.filter(src, 3)
        if self.tfm_y: tgt = self.filter(tgt, 3)
        return src, tgt


class TrimIntensity:
    """
    Clamp intensities that fall outside [min_val, max_val], then rescale the
    result to the interval [new_min, new_max].
    """

    def __init__(self, min_val: float, max_val: float,
                 new_min: float = -1.0, new_max: float = 1.0, tfm_x: bool = True, tfm_y: bool = False):
        if min_val >= max_val:
            raise ValueError('min_val must be less than max_val')
        if new_min >= new_max:
            raise ValueError('new_min must be less than new_max')
        self.min_val = min_val
        self.max_val = max_val
        self.new_min = new_min
        self.new_max = new_max
        self.tfm_x = tfm_x
        self.tfm_y = tfm_y

    def _tfm(self, x: torch.Tensor):
        x = (x - self.min_val) / (self.max_val - self.min_val)
        x[x > 1] = 1.
        x[x < 0] = 0.
        diff = self.new_max - self.new_min
        x *= diff
        x += self.new_min
        return x

    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        src, tgt = sample
        if self.tfm_x: src = self._tfm(src)
        if self.tfm_y: tgt = self._tfm(tgt)
        return src, tgt
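
# Hedged numeric sketch (values illustrative): TrimIntensity first clamps in
# normalized form, then rescales. With min_val=0, max_val=100, and the default
# [new_min, new_max] = [-1, 1], an input of 50 maps to 0.5 * 2 - 1 = 0 and an
# outlier of 250 is clamped to 1:
#
#     >>> import torch
#     >>> tfm = TrimIntensity(min_val=0., max_val=100.)
#     >>> src = torch.tensor([0., 50., 100., 250.])
#     >>> tfm((src, src.clone()))[0]
#     tensor([-1.,  0.,  1.,  1.])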

def get_transforms(p: Union[list, float], tfm_x: bool = True, tfm_y: bool = False, degrees: float = 0,
                   translate: Optional[float] = None, scale: Optional[float] = None,
                   vflip: bool = False, hflip: bool = False,
                   gamma: float = 0, gain: float = 0, noise_pwr: float = 0,
                   block: Optional[Tuple[int, int]] = None,
                   thresh: Optional[float] = None, is_3d: bool = False,
                   mean: Optional[Tuple[float]] = None, std: Optional[Tuple[float]] = None,
                   color: bool = False, segmentation: bool = False):
    """ get the desired transforms in a form that can be applied to NIfTI/TIFF datasets """
    if isinstance(p, float): p = [p] * 5
    tfms = []
    # translate and scale default to None, so guard before comparing against 0
    do_affine = p[0] > 0 and (degrees > 0 or (translate or 0) > 0 or (scale or 0) > 0)
    do_flip = p[1] > 0 and (vflip or hflip)
    if do_affine or do_flip:
        tfms.append(ToPILImage(color=color))
    if do_affine:
        tfms.append(RandomAffine(p[0], degrees, translate or 0, scale or 0, segmentation=segmentation))
    if do_flip:
        tfms.append(RandomFlip(p[1], vflip, hflip))
    tfms.append(ToTensor(color))
    if p[2] > 0 and (gamma > 0 or gain > 0):  # gamma and gain of 0 yield an identity transform
        tfms.append(RandomGamma(p[2], tfm_y, gamma, gain))
    if p[3] > 0 and block is not None:
        tfms.append(RandomBlock(p[3], block, thresh=thresh, tfm_x=tfm_x, tfm_y=tfm_y, is_3d=is_3d))
    if p[4] > 0 and noise_pwr > 0:
        tfms.append(RandomNoise(p[4], tfm_x, tfm_y, noise_pwr))
    if mean is not None and std is not None:
        tfms.append(Normalize(mean, std, tfm_x, tfm_y, is_3d))
    return tfms
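
# Hedged end-to-end sketch (parameter values illustrative): get_transforms
# returns a plain list of transforms, so it composes directly with
# torchvision's Compose and the datasets in niftidataset.dataset:
#
#     >>> import torchvision as tv
#     >>> tfms = get_transforms(p=0.5, degrees=10, hflip=True, noise_pwr=0.01)
#     >>> pipeline = tv.transforms.Compose(tfms)  # callable on (src, tgt) pairs
--------------------------------------------------------------------------------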