├── tests ├── __init__.py ├── test_sanity.py ├── DataGenerators.py ├── test_random_crop.py ├── test_resample_augmentations.py ├── test_augment_zoom.py ├── test_spatial_transformations.py ├── test_axis_mirroring.py ├── test_multithreaded_augmenter.py ├── test_color_augmentations.py └── test_DataLoader.py ├── batchgenerators ├── __init__.py ├── datasets │ ├── __init__.py │ └── cifar.py ├── examples │ ├── __init__.py │ ├── brats2017 │ │ ├── __init__.py │ │ ├── config.py │ │ ├── readme.md │ │ ├── brats2017_dataloader_2D.py │ │ └── brats2017_preprocessing.py │ ├── multithreaded_dataloading.py │ └── cifar10.py ├── transforms │ ├── __init__.py │ ├── abstract_transforms.py │ ├── resample_transforms.py │ ├── sample_normalization_transforms.py │ ├── channel_selection_transforms.py │ ├── crop_and_pad_transforms.py │ └── color_transforms.py ├── utilities │ ├── __init__.py │ ├── data_splitting.py │ ├── custom_types.py │ └── file_and_folder_operations.py ├── augmentations │ ├── __init__.py │ ├── normalizations.py │ ├── resample_augmentations.py │ ├── noise_augmentations.py │ ├── color_augmentations.py │ └── crop_and_pad_augmentations.py └── dataloading │ ├── __init__.py │ ├── dataset.py │ ├── single_threaded_augmenter.py │ ├── nondet_multi_threaded_augmenter.py │ └── data_loader.py ├── setup.cfg ├── .arcconfig ├── DKFZ_Logo.png ├── HIP_Logo.png ├── requirements.txt ├── .travis.yml ├── Makefile ├── setup.py ├── .gitignore ├── Readme.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /batchgenerators/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/dataloading/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /batchgenerators/examples/brats2017/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description_file = README.md -------------------------------------------------------------------------------- /.arcconfig: -------------------------------------------------------------------------------- 1 | { 2 | "phabricator.uri": "https://phabricator.mitk.org/" 3 | } 4 | -------------------------------------------------------------------------------- /DKFZ_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/batchgenerators/HEAD/DKFZ_Logo.png 
-------------------------------------------------------------------------------- /HIP_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/batchgenerators/HEAD/HIP_Logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow>=7.1.2 2 | threadpoolctl 3 | scikit-learn 4 | numpy>=1.10.2 5 | scipy 6 | scikit-image 7 | unittest2 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | matrix: 4 | include: 5 | - os: linux 6 | python: "3.7" 7 | dist: bionic 8 | # - os: windows 9 | # python: "3.6" 10 | 11 | install: 12 | - pip install . 13 | script: 14 | - pytest 15 | -------------------------------------------------------------------------------- /batchgenerators/examples/brats2017/config.py: -------------------------------------------------------------------------------- 1 | brats_preprocessed_folder = "/media/fabian/DeepLearningData/BraTS2017_preprocessed" 2 | brats_folder_with_downloaded_train_data = "/media/fabian/DeepLearningData/Brats17TrainingData" 3 | num_threads_for_brats_example = 8 -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | init: 2 | pip install numpy ; \ 3 | pip install -r requirements.txt 4 | 5 | test: 6 | python -m unittest discover 7 | 8 | install_develop: 9 | python setup.py develop 10 | 11 | install: 12 | python setup.py install 13 | 14 | documentation: 15 | sphinx-apidoc -e -f DeepLearningBatchGeneratorUtils -o doc/ 16 | -------------------------------------------------------------------------------- /batchgenerators/dataloading/dataset.py: 
# --------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod


class Dataset(metaclass=ABCMeta):
    """Abstract interface for an indexable dataset.

    Subclasses must implement ``__getitem__`` and ``__len__``. ``ABCMeta`` is declared as the
    metaclass (on the class itself, where it takes effect), so neither this base class nor a
    subclass that is missing one of the two methods can be instantiated.
    """

    def __init__(self):
        # Nothing to set up; kept so subclasses can continue to call ``super().__init__()``.
        pass

    @abstractmethod
    def __getitem__(self, item):
        '''
        needs to return a data_dict for the sample at the position item
        :param item: position of the requested sample
        :return: data_dict for that sample
        '''
        pass

    @abstractmethod
    def __len__(self):
        '''
        returns how many items the dataset has
        :return: number of items
        '''
        pass
# -------------------------------------------------------------------------------- /tests/test_sanity.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | import unittest 17 | 18 | 19 | class TestSanity(unittest.TestCase): 20 | 21 | def test_sanity(self): 22 | self.assertTrue(1 == 1, "Sanity test failed") 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='batchgenerators', 4 | version='0.25.1', 5 | description='Data augmentation toolkit', 6 | url='https://github.com/MIC-DKFZ/batchgenerators', 7 | author='Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, ' 8 | 'Helmholtz Imaging Platform', 9 | author_email='f.isensee@dkfz-heidelberg.de', 10 | license='Apache License Version 2.0, January 2004', 11 | packages=find_packages(exclude=["tests"]), 12 | install_requires=[ 13 | "pillow>=7.1.2", 14 | "numpy>=1.10.2", 15 | "scipy", 16 | "scikit-image", 17 | "scikit-learn", 18 | "future", 19 | "pandas", 20 | "unittest2", 21 | "threadpoolctl" 22 | ], 23 | keywords=['data augmentation', 'deep learning', 'image segmentation', 'image classification', 24 | 'medical image analysis', 'medical image segmentation'], 25 | ) 26 | -------------------------------------------------------------------------------- /batchgenerators/examples/brats2017/readme.md: -------------------------------------------------------------------------------- 1 | # BraTS2017/2018 example 2 | 3 | 4 | This folder contains a complete example of how to process BraTS2017/2018 data with batchgenerators. You need to 5 | adapt the scripts to match your system and download location. The adaptation should be straightforward. All you need to 6 | do is to change the paths and the number of threads in config.py, then execute `brats2017_preprocessing.py` for 7 | preprocessing the data. 
8 | 9 | Once preprocessed, have a look at `brats2017_dataloader_2D.py` and `brats2017_dataloader_3D.py` to see how to implement 10 | data loaders for 2D and 3D network training, respectively. These files contain everything you need, including 11 | data augmentation and multiprocessing. They are not meant to be run as-is, because they contain no network training. 12 | The idea is that you read them and execute their code step by step, so that you can 13 | get a feel for how batchgenerators works. Questions? -> f.isensee@dkfz.de 14 | 15 | Why are these not IPython Notebooks? I don't like IPython Notebooks. Simple =) 16 | 17 | **IMPORTANT:** these DataLoaders are not suited for test set prediction! You need to iterate over the preprocessed test 18 | data yourself. 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | .idea 92 | 93 | .DS_Store 94 | .pytest_cache/* -------------------------------------------------------------------------------- /batchgenerators/utilities/data_splitting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
from sklearn.model_selection import KFold
import numpy as np


def get_split_deterministic(all_keys, fold=0, num_splits=5, random_state=12345):
    """
    Splits a list of patient identifiers (or numbers) into num_splits folds and returns the split for fold fold.

    The keys are sorted before splitting, so for a fixed random_state the result depends only on the set of keys and
    not on the order in which they were passed in.

    :param all_keys: iterable of identifiers; must be sortable by np.sort
    :param fold: index of the fold to return, one of 0, ..., num_splits - 1
    :param num_splits: number of folds (KFold's n_splits)
    :param random_state: seed for KFold's shuffling
    :return: tuple (train_keys, test_keys), both numpy arrays
    :raises ValueError: if fold is not one of 0, ..., num_splits - 1
    """
    # np.sort already returns an ndarray, so it can be fancy-indexed directly below.
    all_keys_sorted = np.sort(list(all_keys))
    splits = KFold(n_splits=num_splits, shuffle=True, random_state=random_state)
    for i, (train_idx, test_idx) in enumerate(splits.split(all_keys_sorted)):
        if i == fold:
            return all_keys_sorted[train_idx], all_keys_sorted[test_idx]
    # Only reached when fold does not match any generated split.
    raise ValueError("fold must be in range(%d), got %r" % (num_splits, fold))
# -------------------------------------------------------------------------------- /batchgenerators/dataloading/single_threaded_augmenter.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class SingleThreadedAugmenter(object):
    """
    Use this for debugging custom transforms. It does not use a background thread and you can therefore easily debug
    into your augmentations. This should not be used for training. If you want a generator that uses (a) background
    process(es), use MultiThreadedAugmenter.
    Args:
        data_loader (generator or DataLoaderBase instance): Your data loader. Must be an iterator, i.e. usable with the
        builtin next() (which calls its __next__ method), and return a dict that complies with our data structure

        transform (Transform instance): Any of our transformations. If you want to use multiple transformations then
        use our Compose transform! Can be None (in that case no transform will be applied)
    """
    def __init__(self, data_loader, transform):
        self.data_loader = data_loader
        self.transform = transform

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self.data_loader)
        if self.transform is not None:
            # Transforms take the batch dict's entries as keyword arguments and return a dict.
            item = self.transform(**item)
        return item

    def next(self):
        # Python 2 style alias of __next__, kept for existing callers.
        return self.__next__()
# -------------------------------------------------------------------------------- /batchgenerators/utilities/custom_types.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Tuple, Any, Callable
import numpy as np


# A scalar parameter may be a fixed number, a (low, high) range to sample uniformly from, or a callable producing it.
ScalarType = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: ScalarType, *args):
    """
    Resolve a ScalarType specification to a concrete value.

    :param scalar_type:
        - int / float (including numpy integer and floating scalars): returned unchanged
        - list / tuple (low, high) with low <= high: low if low == high, otherwise np.random.uniform(low, high)
        - callable: called with *args, and its result is returned
    :param args: positional arguments forwarded to scalar_type when it is callable
    :return: the resolved value
    :raises AssertionError: if a list/tuple does not have length 2 or its first entry is larger than its second
    :raises RuntimeError: for any other type
    """
    # np.float32 / np.int64 etc. are not subclasses of the Python number types, so list them explicitly.
    if isinstance(scalar_type, (int, float, np.integer, np.floating)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args)
    else:
        # Interpolate the type into the message; passing it as a separate argument left '%s' unformatted.
        raise RuntimeError('Unknown type: %s. Expected: int, float, list, tuple, callable' % type(scalar_type))


if __name__ == '__main__':
    sample_scalar(0.5)
    sample_scalar((0, 1))
    sample_scalar(lambda: np.random.uniform(-1, 2))
    sample_scalar(lambda x, y: np.random.uniform(x, y), 0.5, 2)
# -------------------------------------------------------------------------------- /tests/DataGenerators.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase


class BasicDataLoader(SlimDataLoaderBase):
    """
    data is a tuple of images (b,c,x,y(,z)) and segmentations (b,c,x,y(,z))
    """

    def generate_train_batch(self):
        # Sample randomly (with replacement) from data
        idx = np.random.choice(self._data[0].shape[0], self.batch_size, True, None)
        # copy data to ensure that we are not modifying the original dataset with subsequent augmentation techniques!
        x = np.array(self._data[0][idx])
        y = np.array(self._data[1][idx])
        data_dict = {"data": x,
                     "seg": y}
        return data_dict


class DummyGenerator(SlimDataLoaderBase):
    """
    creates random data and seg of shape dataset_size and returns those.
    """
    def __init__(self, dataset_size, batch_size, fill_data='random', fill_seg='ones'):
        if fill_data == "random":
            data = np.random.random(dataset_size)
        elif fill_data == "ones":
            data = np.ones(dataset_size)
        else:
            raise NotImplementedError

        if fill_seg == "ones":
            seg = np.ones(dataset_size)
        else:
            raise NotImplementedError

        super(DummyGenerator, self).__init__((data, seg), batch_size, None)

    def generate_train_batch(self):
        idx = np.random.choice(self._data[0].shape[0], self.batch_size)

        data = self._data[0][idx]
        seg = self._data[1][idx]
        return {'data': data, 'seg': seg}


class OneDotDataLoader(SlimDataLoaderBase):
    def __init__(self, dataset_size, batch_size, coord_of_voxel):
        """
        creates both data and seg with only one voxel being = 1 and the rest zero. This will allow easy tracking of
        spatial transformations
        :param dataset_size: (b,c,x,y(,z))
        :param batch_size: number of samples drawn per batch
        :param coord_of_voxel: (x, y(, z))
        """
        super(OneDotDataLoader, self).__init__(None, batch_size, None)

        self.data = np.zeros(dataset_size)
        self.seg = np.zeros(dataset_size)
        # Select every sample and channel, then the spatial voxel. Note that ``arr[:, :][coord]`` would NOT do this:
        # ``arr[:, :]`` is just a view of the whole array, so coord would index the batch and channel axes.
        voxel_index = (slice(None), slice(None)) + tuple(coord_of_voxel)
        self.data[voxel_index] = 1
        self.seg[voxel_index] = 1

    def generate_train_batch(self):
        idx = np.random.choice(self.data.shape[0], self.batch_size)

        data = self.data[idx]
        seg = self.seg[idx]
        return {'data': data, 'seg': seg}
# -------------------------------------------------------------------------------- /tests/test_random_crop.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop


class TestRandomCrop(unittest.TestCase):

    def setUp(self):
        np.random.seed(1234)

    def _assert_crop(self, d, s, expected_shape):
        """Check the exact output shapes and that no zero padding ended up in seg."""
        # Compare complete shapes: a zip()-based element check stops at the shorter tuple and
        # cannot detect a missing or extra dimension.
        self.assertEqual(tuple(d.shape), expected_shape, "data has unexpected return shape")
        self.assertEqual(tuple(s.shape), expected_shape, "seg has unexpected return shape")
        self.assertEqual(np.sum(s == 0), 0, "Zeros encountered in seg meaning that we did padding which should not have"
                                            " happened here!")

    def test_random_crop_3D(self):
        data = np.random.random((32, 4, 64, 56, 48))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 32, 0)

        self._assert_crop(d, s, (32, 4, 32, 32, 32))

    def test_random_crop_2D(self):
        data = np.random.random((32, 4, 64, 56))
        seg = np.ones(data.shape)

        d, s = random_crop(data, seg, 32, 0)

        self._assert_crop(d, s, (32, 4, 32, 32))

    def test_random_crop_3D_from_List(self):
        data = [np.random.random((4, 64+i, 56+i, 48+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i, 48+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 0)

        # 3D samples (c, x, y, z) are stacked into a 5D batch (b, c, x, y, z).
        self._assert_crop(d, s, (32, 4, 32, 32, 32))

    def test_random_crop_2D_from_List(self):
        data = [np.random.random((4, 64+i, 56+i)) for i in range(32)]
        seg = [np.random.random((4, 64+i, 56+i)) for i in range(32)]

        d, s = random_crop(data, seg, 32, 0)

        self._assert_crop(d, s, (32, 4, 32, 32))


if __name__ == '__main__':
    unittest.main()
# -------------------------------------------------------------------------------- /batchgenerators/transforms/abstract_transforms.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from warnings import warn

import numpy as np


class AbstractTransform(object, metaclass=abc.ABCMeta):
    """Base class for all transforms.

    A transform is called with the entries of a data dict as keyword arguments and returns a data dict.
    ``abc.ABCMeta`` is declared as the metaclass (a Python 2 style ``__metaclass__`` class attribute has no
    effect in Python 3), so subclasses that do not implement ``__call__`` cannot be instantiated.
    """

    @abc.abstractmethod
    def __call__(self, **data_dict):
        raise NotImplementedError("Abstract, so implement")

    def __repr__(self):
        # Builds e.g. "MyTransform( p = 0.5, key = 'data' )" from the instance attributes.
        ret_str = str(type(self).__name__) + "( " + ", ".join(
            [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
        return ret_str


class RndTransform(AbstractTransform):
    """Applies a transformation with a specified probability

    Deprecated: instantiating it emits a DeprecationWarning; use the p_per_sample argument of the individual
    transforms instead.

    Args:
        transform: The transformation (or composed transformation)

        prob: The probability with which to apply it

        alternative_transform: Will be applied if transform is not called. If transform alters for example the
        spatial dimension of the data, you need to compensate that with calling a dummy transformation that alters the
        spatial dimension in a similar way. We included this functionality because of SpatialTransform which has the
        ability to do cropping. If we want to not apply the spatial transformation we will still need to crop and
        therefore set the alternative_transform to an instance of RandomCropTransform or CenterCropTransform
    """

    def __init__(self, transform, prob=0.5, alternative_transform=None):
        warn("This is deprecated. All applicable transforms now have a p_per_sample argument which allows "
             "batchgenerators to do or not do an augmentation on a per-sample basis instead of the entire batch",
             DeprecationWarning)
        self.alternative_transform = alternative_transform
        self.transform = transform
        self.prob = prob

    def __call__(self, **data_dict):
        # One draw per call, i.e. the decision applies to the whole batch.
        rnd_val = np.random.uniform()

        if rnd_val < self.prob:
            return self.transform(**data_dict)
        if self.alternative_transform is not None:
            return self.alternative_transform(**data_dict)
        return data_dict


class Compose(AbstractTransform):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose. They are applied in list order,
        each one receiving the data dict returned by the previous one.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, **data_dict):
        for t in self.transforms:
            data_dict = t(**data_dict)
        return data_dict

    def __repr__(self):
        return str(type(self).__name__) + " ( " + repr(self.transforms) + " )"
# -------------------------------------------------------------------------------- /tests/test_resample_augmentations.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy


class TestAugmentResample(unittest.TestCase):

    def setUp(self):
        np.random.seed(1234)
        self.data_3D = np.random.random((2, 64, 56, 48))
        self.data_2D = np.random.random((2, 64, 56))

        # Arrays in which every voxel holds a distinct value. np.arange(...).reshape(...) avoids the
        # np.reshape(..., newshape=...) keyword, which is deprecated as of NumPy 2.1.
        self.data_3D_unique = np.arange(2 * 64 * 56 * 48).reshape((2, 64, 56, 48))
        self.data_2D_unique = np.arange(2 * 64 * 56).reshape((2, 64, 56))

        self.d_3D = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5], per_channel=False)
        self.d_2D = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5], per_channel=False)

        self.d_3D_channel = augment_linear_downsampling_scipy(np.copy(self.data_3D), zoom_range=[0.5, 1.5],
                                                              per_channel=False, channels=[0])
        self.d_2D_channel = augment_linear_downsampling_scipy(np.copy(self.data_2D), zoom_range=[0.5, 1.5],
                                                              per_channel=False, channels=[0])

        self.zoom_factor = 0.5
        self.d_3D_upsample = augment_linear_downsampling_scipy(np.copy(self.data_3D_unique),
                                                               zoom_range=[self.zoom_factor, self.zoom_factor],
                                                               per_channel=False, order_downsample=0)
        self.d_2D_upsample = augment_linear_downsampling_scipy(np.copy(self.data_2D_unique),
                                                               zoom_range=[self.zoom_factor, self.zoom_factor],
                                                               per_channel=False, order_downsample=0)

    def test_augment_resample(self):
        self.assertEqual(self.data_3D.shape, self.d_3D.shape,
                         "shape of transformed data not the same as original one (3D)")
        self.assertEqual(self.data_2D.shape, self.d_2D.shape,
                         "shape of transformed data not the same as original one (2D)")

    def test_augment_resample_upsample(self):
        # Nearest-neighbour downsampling by zoom_factor keeps zoom_factor**3 of the distinct values.
        expected_unique = int(len(np.unique(self.data_3D_unique)) * pow(self.zoom_factor, 3))
        self.assertEqual(expected_unique, len(np.unique(self.d_3D_upsample)),
                         "number of unique values after resampling is not correct")

    def test_augment_resample_channel(self):
        np.testing.assert_array_equal(self.d_3D_channel[1], self.data_3D[1],
                                      "channel that should not be augmented is changed (3D)")
        np.testing.assert_array_equal(self.d_2D_channel[1], self.data_2D[1],
                                      "channel that should not be augmented is changed (2D)")

        self.assertFalse(np.all(self.d_3D_channel[0] == self.data_3D[0]),
                         "channel that should be augmented is not changed (3D)")
        self.assertFalse(np.all(self.d_2D_channel[0] == self.data_2D[0]),
                         "channel that should be augmented is not changed (2D)")


if __name__ == '__main__':
    unittest.main()
# -------------------------------------------------------------------------------- /batchgenerators/augmentations/normalizations.py: --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def _float_dtype(dtype):
    """Return ``dtype`` if it can hold fractional values, otherwise ``np.float64``.

    Normalized values are fractional. Storing them in an integer buffer would silently truncate them, and in-place
    true division on an integer array raises, so non-floating inputs are promoted.
    """
    return dtype if np.issubdtype(dtype, np.inexact) else np.float64


def range_normalization(data, rnge=(0, 1), per_channel=True, eps=1e-8):
    """Linearly rescale every sample (or every channel of every sample) into ``rnge``.

    Args:
        data: numpy array indexed as data[b] / data[b, c] (sample axis first, channel axis second).
        rnge: (low, high) target interval.
        per_channel: if True, min/max are computed per channel; otherwise per sample.
        eps: added to the observed value range to avoid division by zero for constant inputs.

    Returns:
        New array with the shape of ``data``. Floating dtypes are kept; other dtypes are promoted to float64.
    """
    data_normalized = np.zeros(data.shape, dtype=_float_dtype(data.dtype))
    for b in range(data.shape[0]):
        if per_channel:
            for c in range(data.shape[1]):
                data_normalized[b, c] = min_max_normalization(data[b, c], eps)
        else:
            data_normalized[b] = min_max_normalization(data[b], eps)

    data_normalized *= (rnge[1] - rnge[0])
    data_normalized += rnge[0]
    return data_normalized


def min_max_normalization(data, eps):
    """Map ``data`` to roughly [0, 1]: subtract the minimum and divide by (max - min + eps).

    Returns a new array. Floating dtypes are kept; integer input is promoted to float64 so that the in-place
    division below does not fail.
    """
    mn = data.min()
    mx = data.max()
    data_normalized = (data - mn).astype(_float_dtype(data.dtype), copy=False)
    old_range = mx - mn + eps
    data_normalized /= old_range

    return data_normalized


def zero_mean_unit_variance_normalization(data, per_channel=True, epsilon=1e-8):
    """Shift and scale every sample (or every channel of every sample) to zero mean and unit standard deviation.

    Args:
        data: numpy array indexed as data[b] / data[b, c].
        per_channel: if True, statistics are computed per channel; otherwise per sample.
        epsilon: added to the standard deviation to avoid division by zero.

    Returns:
        New array with the shape of ``data``. Floating dtypes are kept; other dtypes are promoted to float64
        (they were previously truncated back to integers).
    """
    data_normalized = np.zeros(data.shape, dtype=_float_dtype(data.dtype))
    for b in range(data.shape[0]):
        if per_channel:
            for c in range(data.shape[1]):
                mean = data[b, c].mean()
                std = data[b, c].std() + epsilon
                data_normalized[b, c] = (data[b, c] - mean) / std
        else:
            mean = data[b].mean()
            std = data[b].std() + epsilon
            data_normalized[b] = (data[b] - mean) / std
    return data_normalized


def _per_channel_values(value, n_channels):
    """Broadcast a scalar to ``n_channels`` entries, or check that a sequence has exactly one entry per channel."""
    if np.ndim(value) == 0:
        return [value] * n_channels
    assert len(value) == n_channels
    return value


def mean_std_normalization(data, mean, std, per_channel=True):
    """Normalize with a given mean and standard deviation: ``(data - mean) / std``.

    Args:
        data: numpy array indexed as data[b][c], or a non-empty list/tuple of numpy arrays (one per sample, each
            with the shape of the first one).
        mean, std: scalars, or (when per_channel is True) sequences with one entry per channel. With per_channel=True
            a scalar is broadcast to every channel. mean and std are handled independently, so a scalar mean can be
            combined with a per-channel std.
        per_channel: if True, mean[c] / std[c] are applied to channel c. If False, mean and std are used as given and
            broadcast by numpy against each sample.

    Returns:
        New numpy array of shape (n_samples, *sample_shape). Floating dtypes are kept; others become float64.

    Raises:
        TypeError: if data is neither a numpy array nor a list/tuple.
        AssertionError: for an empty/invalid list, or a per-channel sequence of the wrong length.
    """
    if isinstance(data, np.ndarray):
        data_shape = tuple(data.shape)
        input_dtype = data.dtype
    elif isinstance(data, (list, tuple)):
        assert len(data) > 0 and isinstance(data[0], np.ndarray)
        data_shape = tuple([len(data)] + list(data[0].shape))
        input_dtype = data[0].dtype
    else:
        raise TypeError("Data has to be either a numpy array or a list")

    # allocate only after the type dispatch; reading data.shape first made list input fail unconditionally
    data_normalized = np.zeros(data_shape, dtype=_float_dtype(input_dtype))

    if per_channel:
        mean = _per_channel_values(mean, data_shape[1])
        std = _per_channel_values(std, data_shape[1])

    for b in range(data_shape[0]):
        if per_channel:
            for c in range(data_shape[1]):
                data_normalized[b][c] = (data[b][c] - mean[c]) / std[c]
        else:
            data_normalized[b] = (data[b] - mean) / std
    return data_normalized


def _clip_to_percentiles(arr, percentile_lower, percentile_upper):
    """Clip ``arr`` in place to the given percentiles of its own values (both computed before any clipping)."""
    cut_off_lower = np.percentile(arr, percentile_lower)
    cut_off_upper = np.percentile(arr, percentile_upper)
    arr[arr < cut_off_lower] = cut_off_lower
    arr[arr > cut_off_upper] = cut_off_upper


def cut_off_outliers(data, percentile_lower=0.2, percentile_upper=99.8, per_channel=False):
    """Clip every sample (or every channel of every sample) to its own lower/upper percentiles.

    ``data`` is modified IN PLACE and also returned. With per_channel=True, ``data`` must be a numpy array
    (data.shape[1] is read). Otherwise a list of arrays also works.
    """
    for b in range(len(data)):
        if not per_channel:
            _clip_to_percentiles(data[b], percentile_lower, percentile_upper)
        else:
            for c in range(data.shape[1]):
                _clip_to_percentiles(data[b, c], percentile_lower, percentile_upper)
    return data

# --------------------------------------------------------------------------------
# /tests/test_augment_zoom.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the
# License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from batchgenerators.augmentations.spatial_transformations import augment_zoom


class TestAugmentZoom(unittest.TestCase):
    """Zooms a box of ones by an integer factor with nearest-neighbour interpolation and checks the result."""

    def setUp(self):
        np.random.seed(1234)
        self.zoom_factors = 2

        # (channels, x, y, z) volume containing a 20x20x20 box of ones
        self.data3D = np.zeros((2, 64, 56, 48))
        self.data3D[:, 21:41, 12:32, 13:33] = 1
        # separate copy, so an in-place change to one input cannot silently affect the other
        self.seg3D = self.data3D.copy()
        self.d3D, self.s3D = augment_zoom(self.data3D, self.seg3D, zoom_factors=self.zoom_factors, order=0,
                                          order_seg=0)

        # (channels, x, y) image containing a 20x20 square of ones
        self.data2D = np.zeros((2, 64, 56))
        self.data2D[:, 21:41, 12:32] = 1
        self.seg2D = self.data2D.copy()
        self.d2D, self.s2D = augment_zoom(self.data2D, self.seg2D, zoom_factors=self.zoom_factors, order=0,
                                          order_seg=0)

    def _check_dimensions(self, data, seg, zoomed_data, zoomed_seg):
        """Spatial axes must grow by zoom_factors; the channel axis must be unchanged."""
        np.testing.assert_array_equal(self.zoom_factors * np.array(data.shape[1:]), np.array(zoomed_data.shape[1:]),
                                      "image has unexpected return shape")
        self.assertEqual(data.shape[0], zoomed_data.shape[0], "color has unexpected return shape")
        self.assertEqual(seg.shape[0], zoomed_seg.shape[0], "seg color channel has unexpected return shape")
        np.testing.assert_array_equal(self.zoom_factors * np.array(seg.shape[1:]), np.array(zoomed_seg.shape[1:]),
                                      "seg has unexpected return shape")

    def _check_values(self, data, seg, zoomed_data, zoomed_seg, zoomed_box):
        """Check the number and position of ones after zooming; ``zoomed_box`` holds one slice per spatial axis."""
        scale = self.zoom_factors ** (data.ndim - 1)
        # ndarray.sum() instead of builtin sum(), which iterated over millions of numpy scalars in Python
        self.assertEqual(scale * data.sum(), zoomed_data.sum(), "image has unexpected values inside")
        self.assertEqual(scale * seg.sum(), zoomed_seg.sum(), "segmentation has unexpected values inside")

        box = (slice(None),) + zoomed_box
        self.assertTrue(np.all(zoomed_data[box]), "image data is not zoomed correctly")
        # the old check only required *some* value outside the box to be zero; require all of them to be
        expected = np.zeros(zoomed_data.shape)
        expected[box] = 1
        np.testing.assert_array_equal(zoomed_data, expected, "image has unexpected values outside")
        # nearest-neighbour resampling of a binary segmentation must not invent new labels
        self.assertTrue(np.all(np.isin(zoomed_seg, (0, 1))), "segmentation has unexpected values outside")

    def test_augment_zoom_3D_dimensions(self):
        self._check_dimensions(self.data3D, self.seg3D, self.d3D, self.s3D)

    def test_augment_zoom_3D_values(self):
        self._check_values(self.data3D, self.seg3D, self.d3D, self.s3D,
                           (slice(42, 82), slice(24, 64), slice(26, 66)))

    def test_augment_zoom_2D_dimensions(self):
        self._check_dimensions(self.data2D, self.seg2D, self.d2D, self.s2D)

    def test_augment_zoom_2D_values(self):
        self._check_values(self.data2D, self.seg2D, self.d2D, self.s2D, (slice(42, 82), slice(24, 64)))


if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------------------
# /batchgenerators/augmentations/resample_augmentations.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from builtins import range
import numpy as np
import random
from skimage.transform import resize
from batchgenerators.augmentations.utils import uniform


def _sample_zoom(zoom_range, dim):
    """Draw one zoom factor, or (for a sequence of intervals) an array with one factor per spatial axis."""
    if isinstance(zoom_range[0], (tuple, list, np.ndarray)):
        assert len(zoom_range) == dim
        return np.array([uniform(i[0], i[1]) for i in zoom_range])
    return uniform(zoom_range[0], zoom_range[1])


def _target_shape(shp, zoom, ignore_axes):
    """Spatial shape after zooming ``shp`` by ``zoom``; axes listed in ``ignore_axes`` keep their original size."""
    target_shape = np.round(shp * zoom).astype(int)
    if ignore_axes is not None:
        for i in ignore_axes:
            target_shape[i] = shp[i]
    return target_shape


def augment_linear_downsampling_scipy(data_sample, zoom_range=(0.5, 1), per_channel=True, p_per_channel=1,
                                      channels=None, order_downsample=1, order_upsample=0, ignore_axes=None):
    '''
    Downsamples channels by a random factor and upsamples them back to the original resolution, which simulates a
    low-resolution acquisition while keeping the array shape.

    Info:
        * Uses skimage.transform.resize for both resampling steps (mode='edge', no anti-aliasing).
        * Unless zoom_range specifies one interval per axis, all spatial axes share the same zoom factor.

    Args:
        data_sample: one sample laid out as (channels, *spatial). It is modified in place and returned.

        zoom_range: a scalar (fixed zoom), a (low, high) pair, or a sequence of (low, high) pairs with one pair per
            spatial axis. Zoom factors are drawn uniformly from these intervals (zoom < 1 = downsampling!).

        per_channel (bool): whether to draw a new zoom factor for each channel or keep one for all channels

        p_per_channel: probability for downsampling/upsampling a channel

        channels (list, tuple): if None then all channels can be augmented. If list then only the channel indices can
            be augmented (but may not always be depending on p_per_channel)

        order_downsample: interpolation order of the downsampling step (default 1 = linear)

        order_upsample: interpolation order of the upsampling step (default 0 = nearest neighbour)

        ignore_axes: tuple/list of spatial axis indices (0 = first axis after the channel axis) that keep their
            original resolution

    Returns:
        data_sample
    '''
    if not isinstance(zoom_range, (list, tuple, np.ndarray)):
        # a scalar means "always use this zoom"; it used to be wrapped as [z] and indexing [1] raised IndexError
        zoom_range = [zoom_range, zoom_range]

    shp = np.array(data_sample.shape[1:])
    dim = len(shp)

    if not per_channel:
        # one zoom (and target shape) shared by every augmented channel
        zoom = _sample_zoom(zoom_range, dim)
        target_shape = _target_shape(shp, zoom, ignore_axes)

    if channels is None:
        channels = list(range(data_sample.shape[0]))

    for c in channels:
        if np.random.uniform() < p_per_channel:
            if per_channel:
                zoom = _sample_zoom(zoom_range, dim)
                target_shape = _target_shape(shp, zoom, ignore_axes)

            downsampled = resize(data_sample[c].astype(float), target_shape, order=order_downsample, mode='edge',
                                 anti_aliasing=False)
            data_sample[c] = resize(downsampled, shp, order=order_upsample, mode='edge',
                                    anti_aliasing=False)

    return data_sample

# --------------------------------------------------------------------------------
# /batchgenerators/examples/multithreaded_dataloading.py:
# --------------------------------------------------------------------------------
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
import numpy as np


class DummyDL(SlimDataLoaderBase):
    """Toy loader over the numbers 0..99 that is meant to visit every item once per epoch across all workers.

    Each worker's copy starts at its own ``thread_id`` and advances by ``number_of_threads_in_multithreaded``.
    This assumes every copy receives a distinct thread_id in [0, number_of_threads) — TODO confirm in
    SlimDataLoaderBase / MultiThreadedAugmenter.
    """

    def __init__(self, num_threads_in_mt=8):
        super(DummyDL, self).__init__(None, None, num_threads_in_mt)
        self._data = list(range(100))
        self.current_position = 0
        self.was_initialized = False

    def reset(self):
        self.current_position = self.thread_id
        self.was_initialized = True

    def _get_item(self, idx):
        """Return the item at position ``idx`` of the current epoch's order."""
        return self._data[idx]

    def generate_train_batch(self):
        if not self.was_initialized:
            self.reset()
        idx = self.current_position
        if idx < len(self._data):
            self.current_position = idx + self.number_of_threads_in_multithreaded
            return self._get_item(idx)
        # epoch finished: rewind first so that the next pass over the augmenter starts a fresh epoch
        self.reset()
        raise StopIteration


class DummyDLWithShuffle(DummyDL):
    """Like DummyDL, but the item order is reshuffled on every reset.

    The shuffle is seeded with the restart counter, so worker copies that have restarted equally often agree on
    the permutation and still split the data without overlap.
    """

    def __init__(self, num_threads_in_mt=8):
        super(DummyDLWithShuffle, self).__init__(num_threads_in_mt)
        self.num_restarted = 0
        self.data_order = np.arange(len(self._data))

    def reset(self):
        super(DummyDLWithShuffle, self).reset()
        rs = np.random.RandomState(self.num_restarted)
        rs.shuffle(self.data_order)
        self.num_restarted = self.num_restarted + 1

    def _get_item(self, idx):
        return self._data[self.data_order[idx]]


if __name__ == "__main__":
    # Why is it so hard to iterate only once over my entire training dataset when MultiThreadedAugmenter is used?
    # This is because MultiThreadedAugmenter will spawn num_threads workers and each worker will hold a copy of the
    # entire pipeline, including the DataLoader. Therefore, if your DataLoader is configured to run over the training
    # data once, but you have 8 threads then what you will be getting from the MultiThreadedAugmenter is an iteration
    # over eight times your training dataset.

    # HELP I want to iterate over all my training data once per epoch.
    # Say no more. We got your back. Here is a simple example of how you can do that.
    #
    # We create a dummy dataloader that has the numbers of 0 to 99 in its _data variable. In the
    # MultiThreadedAugmenter, each DataLoader will know what thread ID it has. We use that information to iterate over
    # the training data. Since there are 3 threads, each individual dataloader must return every third item (and start
    # in a different position).

    dl = DummyDL(num_threads_in_mt=3)
    mt = MultiThreadedAugmenter(dl, None, 3, 1, None)

    for i in mt:
        print(i)

    # You can run the mt as often as you want because the DataLoader will reset itself before raising StopIteration
    for i in mt:
        print(i)

    for i in mt:
        print(i)

    # But wait. Isn't it suboptimal to iterate over training data always in the same order? Correct. Try this:

    dl = DummyDLWithShuffle(num_threads_in_mt=3)
    mt = MultiThreadedAugmenter(dl, None, 3, 1, None)

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100  # assert makes sure we got what we wanted

    # Once again you can run that as often as you want

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100

    batches = []
    for i in mt:
        batches.append(i)
    print(batches)
    assert len(np.unique(batches)) == 100 and len(batches) == 100

# --------------------------------------------------------------------------------
# /batchgenerators/augmentations/noise_augmentations.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Tuple

import numpy as np
from batchgenerators.augmentations.utils import get_range_val, mask_random_squares
from builtins import range
from scipy.ndimage import gaussian_filter


def augment_rician_noise(data_sample, noise_variance=(0, 0.1)):
    """Add Rician noise to ``data_sample`` and return the result as a new array.

    Two independent Gaussian noise fields n1, n2 are drawn. The output is sqrt((x + n1) ** 2 + n2 ** 2) * sign(x),
    so the sign of every input entry is reapplied and zero-valued entries stay zero.

    NOTE: the value drawn from ``noise_variance`` is passed to ``np.random.normal`` as ``scale``. It therefore acts
    as a standard deviation, not a variance. The parameter name is kept for backward compatibility.
    """
    variance = random.uniform(noise_variance[0], noise_variance[1])
    data_sample = np.sqrt(
        (data_sample + np.random.normal(0.0, variance, size=data_sample.shape)) ** 2 +
        np.random.normal(0.0, variance, size=data_sample.shape) ** 2) * np.sign(data_sample)
    return data_sample


def _draw_from_interval(interval):
    """Return interval[0] for a degenerate interval, otherwise a ``random.uniform`` draw from it."""
    if interval[0] == interval[1]:
        return interval[0]
    return random.uniform(interval[0], interval[1])


def augment_gaussian_noise(data_sample: np.ndarray, noise_variance: Tuple[float, float] = (0, 0.1),
                           p_per_channel: float = 1, per_channel: bool = False) -> np.ndarray:
    """Add zero-mean Gaussian noise to channels of ``data_sample`` (in place) and return it.

    Args:
        data_sample: array whose first axis is iterated as channels.
        noise_variance: (low, high) interval for the noise level. As in augment_rician_noise, the drawn value is used
            as the standard deviation (``scale``) of ``np.random.normal``.
        p_per_channel: probability that a given channel receives noise.
        per_channel: if True, draw a separate noise level for every noised channel; otherwise draw one level up
            front and share it.
    """
    shared_level = None if per_channel else _draw_from_interval(noise_variance)
    for c in range(data_sample.shape[0]):
        if np.random.uniform() < p_per_channel:
            noise_level = _draw_from_interval(noise_variance) if per_channel else shared_level
            # bug fixed: https://github.com/MIC-DKFZ/batchgenerators/issues/86
            data_sample[c] = data_sample[c] + np.random.normal(0.0, noise_level, size=data_sample[c].shape)
    return data_sample


def _draw_sigma(sigma_range, n_spatial_axes, different_sigma_per_axis, p_isotropic):
    """Draw a Gaussian blur sigma.

    Returns a single (isotropic) value unless ``different_sigma_per_axis`` is set. In that case an isotropic value is
    still returned with probability ``p_isotropic``; otherwise a list with one sigma per spatial axis is returned.
    """
    if not different_sigma_per_axis or np.random.uniform() < p_isotropic:
        return get_range_val(sigma_range)
    return [get_range_val(sigma_range) for _ in range(n_spatial_axes)]


def augment_gaussian_blur(data_sample: np.ndarray, sigma_range: Tuple[float, float], per_channel: bool = True,
                          p_per_channel: float = 1, different_sigma_per_axis: bool = False,
                          p_isotropic: float = 0) -> np.ndarray:
    """Blur channels of ``data_sample`` (in place) with ``scipy.ndimage.gaussian_filter`` and return it.

    Args:
        data_sample: array laid out as (channels, *spatial).
        sigma_range: range handed to get_range_val to obtain each sigma.
        per_channel: if True, draw a new sigma for every blurred channel; otherwise draw one up front and share it.
        p_per_channel: probability that a channel is blurred. Uses a strict ``<`` like augment_gaussian_noise, so
            p_per_channel=0 reliably disables blurring (``<=`` could still fire when the draw is exactly 0.0).
        different_sigma_per_axis: allow a different sigma per spatial axis.
        p_isotropic: with different_sigma_per_axis, probability of still using one isotropic sigma.
    """
    n_spatial_axes = len(data_sample.shape[1:])
    sigma = None if per_channel else _draw_sigma(sigma_range, n_spatial_axes, different_sigma_per_axis, p_isotropic)
    for c in range(data_sample.shape[0]):
        if np.random.uniform() < p_per_channel:
            if per_channel:
                sigma = _draw_sigma(sigma_range, n_spatial_axes, different_sigma_per_axis, p_isotropic)
            data_sample[c] = gaussian_filter(data_sample[c], sigma, order=0)
    return data_sample


def augment_blank_square_noise(data_sample, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False,
                               square_pos=None):
    """Blank out squares of ``data_sample`` via ``mask_random_squares`` and return the result.

    ``square_size`` and ``n_squares`` are each resolved once through get_range_val (presumably a random draw when
    a range is given — confirm in batchgenerators.augmentations.utils). ``noise_val``, ``channel_wise_n_val`` and
    ``square_pos`` are forwarded unchanged; their semantics are defined by mask_random_squares.
    """
    rnd_square_size = get_range_val(square_size)
    rnd_n_squares = get_range_val(n_squares)

    data_sample = mask_random_squares(data_sample, square_size=rnd_square_size, n_squares=rnd_n_squares,
                                      n_val=noise_val, channel_wise_n_val=channel_wise_n_val,
                                      square_pos=square_pos)
    return data_sample

# --------------------------------------------------------------------------------
# /batchgenerators/transforms/resample_transforms.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from warnings import warn

from batchgenerators.transforms.abstract_transforms import AbstractTransform
from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy
import numpy as np


class SimulateLowResolutionTransform(AbstractTransform):
    """Downsamples each sample by a random factor and upsamples it to the original resolution again, delegating the
    per-sample work to ``augment_linear_downsampling_scipy``.

    Args:
        zoom_range: can be either tuple/list/np.ndarray or tuple of tuple. If tuple/list/np.ndarray, then the zoom
            factor will be sampled from zoom_range[0], zoom_range[1] (zoom < 1 = downsampling!). If tuple of tuple then
            each inner tuple will give a sampling interval for each axis (allows for different range of zoom values
            for each axis)

        per_channel (bool): whether to draw a new zoom_factor for each channel or keep one for all channels

        p_per_channel: probability that a given channel of a selected sample is resampled

        channels (list, tuple): if None then all channels can be augmented. If list then only the channel indices can
            be augmented (but may not always be depending on p_per_channel)

        order_downsample: interpolation order of the downsampling step (default 1)

        order_upsample: interpolation order of the upsampling step (default 0)

        data_key: key of the batch array in data_dict

        p_per_sample: probability that a sample of the batch is augmented at all

        ignore_axes: spatial axis indices (0 = first axis after the channel axis) that keep their original resolution
    """

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        self.order_upsample = order_upsample
        self.order_downsample = order_downsample
        self.channels = channels
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.p_per_sample = p_per_sample
        self.data_key = data_key
        self.zoom_range = zoom_range
        self.ignore_axes = ignore_axes

    def __call__(self, **data_dict):
        data = data_dict[self.data_key]
        for b in range(len(data)):
            if np.random.uniform() < self.p_per_sample:
                data[b] = augment_linear_downsampling_scipy(data[b],
                                                            zoom_range=self.zoom_range,
                                                            per_channel=self.per_channel,
                                                            p_per_channel=self.p_per_channel,
                                                            channels=self.channels,
                                                            order_downsample=self.order_downsample,
                                                            order_upsample=self.order_upsample,
                                                            ignore_axes=self.ignore_axes)
        return data_dict


class ResampleTransform(SimulateLowResolutionTransform):
    """Deprecated alias of SimulateLowResolutionTransform, kept so that existing code keeps working."""

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        # stacklevel=2 attributes the warning to the code that instantiates the deprecated class
        warn("This class is deprecated. It was renamed to SimulateLowResolutionTransform. Please change your code",
             DeprecationWarning, stacklevel=2)
        # ignore_axes is now forwarded too; previously the alias could not use it at all
        super(ResampleTransform, self).__init__(zoom_range, per_channel, p_per_channel,
                                                channels, order_downsample, order_upsample, data_key, p_per_sample,
                                                ignore_axes)

# --------------------------------------------------------------------------------
# /batchgenerators/transforms/sample_normalization_transforms.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | from batchgenerators.augmentations.normalizations import cut_off_outliers, mean_std_normalization, range_normalization, \ 17 | zero_mean_unit_variance_normalization 18 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 19 | 20 | 21 | class RangeTransform(AbstractTransform): 22 | '''Rescales data into the specified range 23 | 24 | Args: 25 | rnge (tuple of float): The range to which the data is scaled 26 | 27 | per_channel (bool): determines whether the min and max values used for the rescaling are computed over the whole 28 | sample or separately for each channel 29 | 30 | ''' 31 | 32 | def __init__(self, rnge=(0, 1), per_channel=True, data_key="data", label_key="seg"): 33 | self.data_key = data_key 34 | self.label_key = label_key 35 | self.per_channel = per_channel 36 | self.rnge = rnge 37 | 38 | def __call__(self, **data_dict): 39 | data_dict[self.data_key] = range_normalization(data_dict[self.data_key], self.rnge, 40 | per_channel=self.per_channel) 41 | return data_dict 42 | 43 | 44 | class CutOffOutliersTransform(AbstractTransform): 45 | """ Removes outliers from data 46 | 47 | Args: 48 | percentile_lower (float between 0 and 100): Lower cutoff percentile 49 | 50 | percentile_upper (float between 0 and 100): Upper cutoff percentile 51 | 52 | per_channel (bool): determines whether percentiles are computed for each color channel separately 53 | """ 54 | 55 | def __init__(self, percentile_lower=0.2, percentile_upper=99.8, per_channel=False, data_key="data", 56 | label_key="seg"): 57 | self.data_key = data_key 58 | self.label_key = label_key 59 | self.per_channel = per_channel 60 | self.percentile_upper = percentile_upper 61 | self.percentile_lower = percentile_lower 62 | 63 | def __call__(self, **data_dict): 64 | data_dict[self.data_key] = cut_off_outliers(data_dict[self.data_key], self.percentile_lower, 65 | self.percentile_upper, 66 | per_channel=self.per_channel) 67 | return data_dict 68 | 69 | 70 | class 
ZeroMeanUnitVarianceTransform(AbstractTransform): 71 | """ Zero mean unit variance transform 72 | 73 | Args: 74 | per_channel (bool): determines whether mean and std are computed for and applied to each color channel 75 | separately 76 | 77 | epsilon (float): prevent nan if std is zero, keep at 1e-7 78 | """ 79 | 80 | def __init__(self, per_channel=True, epsilon=1e-7, data_key="data", label_key="seg"): 81 | self.data_key = data_key 82 | self.label_key = label_key 83 | self.epsilon = epsilon 84 | self.per_channel = per_channel 85 | 86 | def __call__(self, **data_dict): 87 | data_dict[self.data_key] = zero_mean_unit_variance_normalization(data_dict[self.data_key], self.per_channel, 88 | self.epsilon) 89 | return data_dict 90 | 91 | 92 | class MeanStdNormalizationTransform(AbstractTransform): 93 | """ Zero mean unit variance transform 94 | 95 | Args: 96 | per_channel (bool): determines whether mean and std are computed for and applied to each color channel 97 | separately 98 | 99 | epsilon (float): prevent nan if std is zero, keep at 1e-7 100 | """ 101 | 102 | def __init__(self, mean, std, per_channel=True, data_key="data", label_key="seg"): 103 | self.data_key = data_key 104 | self.label_key = label_key 105 | self.std = std 106 | self.mean = mean 107 | self.per_channel = per_channel 108 | 109 | def __call__(self, **data_dict): 110 | data_dict[self.data_key] = mean_std_normalization(data_dict[self.data_key], self.mean, self.std, 111 | self.per_channel) 112 | return data_dict 113 | -------------------------------------------------------------------------------- /batchgenerators/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tarfile 4 | from urllib.request import urlretrieve 5 | 6 | import numpy as np 7 | from batchgenerators.dataloading.data_loader import DataLoader 8 | from batchgenerators.dataloading.dataset import Dataset 9 | from 
batchgenerators.utilities.file_and_folder_operations import join 10 | 11 | 12 | def unpickle(file): 13 | ''' 14 | taken from http://www.cs.toronto.edu/~kriz/cifar.html 15 | :param file: 16 | :return: 17 | ''' 18 | import pickle 19 | 20 | with open(file, 'rb') as fo: 21 | dc = pickle.load(fo, encoding='bytes') 22 | return dc 23 | 24 | 25 | def maybe_download_and_prepare_cifar(target_dir, cifar=10): 26 | ''' 27 | Checks if cifar is already present in target_dir and downloads it if not. 28 | CIFAR comes in 5 batches that need to be unpickled. What a mess. 29 | We stack all 5 batches together to one single npy array. No idea why they are being so complicated 30 | :param target_dir: 31 | :return: 32 | ''' 33 | if not os.path.isfile(os.path.join(target_dir, 'cifar%d_test_data.npz' % cifar)) or not \ 34 | os.path.isfile(os.path.join(target_dir, 'cifar%d_training_data.npz' % cifar)): 35 | print('downloading CIFAR%d...' % cifar) 36 | urlretrieve('http://www.cs.toronto.edu/~kriz/cifar-%d-python.tar.gz' % cifar, join(target_dir, 'cifar-%d-python.tar.gz' % cifar)) 37 | 38 | tar = tarfile.open(os.path.join(target_dir, 'cifar-%d-python.tar.gz' % cifar), "r:gz") 39 | tar.extractall(path=target_dir) 40 | tar.close() 41 | 42 | data = [] 43 | labels = [] 44 | filenames = [] 45 | 46 | for batch in range(1, 6): 47 | loaded = unpickle(os.path.join(target_dir, 'cifar-%d-batches-py' % cifar, 'data_batch_%d' % batch)) 48 | data.append(loaded[b'data'].reshape((loaded[b'data'].shape[0], 3, 32, 32)).astype(np.uint8)) 49 | labels += [int(i) for i in loaded[b'labels']] 50 | filenames += [str(i) for i in loaded[b'filenames']] 51 | 52 | data = np.vstack(data) 53 | labels = np.array(labels) 54 | filenames = np.array(filenames) 55 | 56 | np.savez_compressed(os.path.join(target_dir, 'cifar%d_training_data.npz' % cifar), data=data, labels=labels, 57 | filenames=filenames) 58 | 59 | test = unpickle(os.path.join(target_dir, 'cifar-%d-batches-py' % cifar, 'test_batch')) 60 | data = 
test[b'data'].reshape((test[b'data'].shape[0], 3, 32, 32)).astype(np.uint8) 61 | labels = [int(i) for i in test[b'labels']] 62 | filenames = [i for i in test[b'filenames']] 63 | 64 | np.savez_compressed(os.path.join(target_dir, 'cifar%d_test_data.npz' % cifar), data=data, labels=labels, 65 | filenames=filenames) 66 | 67 | # clean up 68 | shutil.rmtree(os.path.join(target_dir, 'cifar-%d-batches-py' % cifar)) 69 | os.remove(os.path.join(target_dir, 'cifar-%d-python.tar.gz' % cifar)) 70 | 71 | 72 | class CifarDataset(Dataset): 73 | def __init__(self, dataset_directory, train=True, transform=None, cifar=10): 74 | super(CifarDataset, self).__init__() 75 | self.transform = transform 76 | maybe_download_and_prepare_cifar(dataset_directory) 77 | 78 | self.train = train 79 | 80 | # load appropriate data 81 | if train: 82 | fname = os.path.join(dataset_directory, 'cifar%d_training_data.npz' % cifar) 83 | else: 84 | fname = os.path.join(dataset_directory, 'cifar%d_test_data.npz' % cifar) 85 | 86 | dataset = np.load(fname) 87 | 88 | self.data = dataset['data'] 89 | self.labels = dataset['labels'] 90 | self.filenames = dataset['filenames'] 91 | 92 | def __getitem__(self, item): 93 | data_dict = {'data': self.data[item:item+1].astype(np.float32), 'labels': self.labels[item], 'filenames': self.filenames[item]} 94 | if self.transform is not None: 95 | data_dict = self.transform(**data_dict) 96 | return data_dict 97 | 98 | def __len__(self): 99 | return len(self.data) 100 | 101 | 102 | class HighPerformanceCIFARLoader(DataLoader): 103 | def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, infinite=False, 104 | return_incomplete=False): 105 | super(HighPerformanceCIFARLoader, self).__init__(data, batch_size, num_threads_in_multithreaded, 106 | seed_for_shuffle, infinite=infinite, 107 | return_incomplete=return_incomplete) 108 | self.indices = np.arange(len(data[0])) 109 | 110 | def generate_train_batch(self): 111 | indices = self.get_indices() 112 | 
113 | data = self._data[0][indices] 114 | labels = self._data[1][indices] 115 | filenames = self._data[2][indices] 116 | 117 | return {'data': data.astype(np.float32), 'labels': labels, 'filenames': filenames} 118 | 119 | -------------------------------------------------------------------------------- /batchgenerators/utilities/file_and_folder_operations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import pickle 18 | import json 19 | from typing import List, Union, Optional 20 | 21 | 22 | def subdirs(folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True) -> List[str]: 23 | """ 24 | Returns a list of subdirectories in a given folder, optionally filtering by prefix and suffix, 25 | and optionally sorting the results. Uses os.scandir for efficient directory traversal. 26 | 27 | Parameters: 28 | - folder: Path to the folder to list subdirectories from. 29 | - join: Whether to return full paths to subdirectories (if True) or just directory names (if False). 30 | - prefix: Only include subdirectories that start with this prefix (if provided). 
31 | - suffix: Only include subdirectories that end with this suffix (if provided). 32 | - sort: Whether to sort the list of subdirectories alphabetically. 33 | 34 | Returns: 35 | - List of subdirectory paths (or names) meeting the specified criteria. 36 | """ 37 | subdirectories = [] 38 | with os.scandir(folder) as entries: 39 | for entry in entries: 40 | if entry.is_dir() and \ 41 | (prefix is None or entry.name.startswith(prefix)) and \ 42 | (suffix is None or entry.name.endswith(suffix)): 43 | dir_path = entry.path if join else entry.name 44 | subdirectories.append(dir_path) 45 | 46 | if sort: 47 | subdirectories.sort() 48 | 49 | return subdirectories 50 | 51 | 52 | def subfiles(folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True) -> List[str]: 53 | """ 54 | Returns a list of files in a given folder, optionally filtering by prefix and suffix, 55 | and optionally sorting the results. Uses os.scandir for efficient directory traversal, 56 | making it suitable for network drives. 57 | 58 | Parameters: 59 | - folder: Path to the folder to list files from. 60 | - join: Whether to return full file paths (if True) or just file names (if False). 61 | - prefix: Only include files that start with this prefix (if provided). 62 | - suffix: Only include files that end with this suffix (if provided). 63 | - sort: Whether to sort the list of files alphabetically. 64 | 65 | Returns: 66 | - List of file paths (or names) meeting the specified criteria. 
67 | """ 68 | files = [] 69 | with os.scandir(folder) as entries: 70 | for entry in entries: 71 | if entry.is_file() and \ 72 | (prefix is None or entry.name.startswith(prefix)) and \ 73 | (suffix is None or entry.name.endswith(suffix)): 74 | file_path = entry.path if join else entry.name 75 | files.append(file_path) 76 | 77 | if sort: 78 | files.sort() 79 | 80 | return files 81 | 82 | 83 | def nifti_files(folder: str, join: bool = True, sort: bool = True) -> List[str]: 84 | return subfiles(folder, join=join, sort=sort, suffix='.nii.gz') 85 | 86 | 87 | def maybe_mkdir_p(directory: str) -> None: 88 | os.makedirs(directory, exist_ok=True) 89 | 90 | 91 | def load_pickle(file: str, mode: str = 'rb'): 92 | with open(file, mode) as f: 93 | a = pickle.load(f) 94 | return a 95 | 96 | 97 | def write_pickle(obj, file: str, mode: str = 'wb') -> None: 98 | with open(file, mode) as f: 99 | pickle.dump(obj, f) 100 | 101 | 102 | def load_json(file: str): 103 | with open(file, 'r') as f: 104 | a = json.load(f) 105 | return a 106 | 107 | 108 | def save_json(obj, file: str, indent: int = 4, sort_keys: bool = True) -> None: 109 | with open(file, 'w') as f: 110 | json.dump(obj, f, sort_keys=sort_keys, indent=indent) 111 | 112 | 113 | def pardir(path: str): 114 | return os.path.join(path, os.pardir) 115 | 116 | 117 | def split_path(path: str) -> List[str]: 118 | """ 119 | splits at each separator. 
This is different from os.path.split which only splits at last separator 120 | """ 121 | return path.split(os.sep) 122 | 123 | 124 | # I'm tired of typing these out 125 | join = os.path.join 126 | isdir = os.path.isdir 127 | isfile = os.path.isfile 128 | listdir = os.listdir 129 | makedirs = maybe_mkdir_p 130 | os_split_path = os.path.split 131 | 132 | # I am tired of confusing those 133 | subfolders = subdirs 134 | save_pickle = write_pickle 135 | write_json = save_json 136 | -------------------------------------------------------------------------------- /batchgenerators/examples/cifar10.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | import os 5 | from batchgenerators.dataloading.data_loader import DataLoaderFromDataset 6 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter 7 | from batchgenerators.datasets.cifar import HighPerformanceCIFARLoader, CifarDataset 8 | from batchgenerators.transforms.abstract_transforms import Compose 9 | from batchgenerators.transforms.spatial_transforms import SpatialTransform 10 | from batchgenerators.transforms.utility_transforms import NumpyToTensor 11 | from torch._six import int_classes, string_classes, container_abcs 12 | from torch.utils.data.dataloader import numpy_type_map 13 | 14 | _use_shared_memory = False 15 | 16 | 17 | def default_collate(batch): 18 | r"""Puts each data field into a tensor with outer dimension batch size""" 19 | 20 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 21 | elem_type = type(batch[0]) 22 | if isinstance(batch[0], torch.Tensor): 23 | out = None 24 | if _use_shared_memory: 25 | # If we're in a background process, concatenate directly into a 26 | # shared memory tensor to avoid an extra copy 27 | numel = sum([x.numel() for x in batch]) 28 | storage = batch[0].storage()._new_shared(numel) 29 | out = batch[0].new(storage) 30 | return 
torch.stack(batch, 0, out=out) 31 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 32 | and elem_type.__name__ != 'string_': 33 | elem = batch[0] 34 | if elem_type.__name__ == 'ndarray': 35 | # array of string classes and object 36 | if re.search('[SaUO]', elem.dtype.str) is not None: 37 | raise TypeError(error_msg.format(elem.dtype)) 38 | 39 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 40 | if elem.shape == (): # scalars 41 | py_type = float if elem.dtype.name.startswith('float') else int 42 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 43 | elif isinstance(batch[0], int_classes): 44 | return torch.LongTensor(batch) 45 | elif isinstance(batch[0], float): 46 | return torch.DoubleTensor(batch) 47 | elif isinstance(batch[0], string_classes): 48 | return batch 49 | elif isinstance(batch[0], container_abcs.Mapping): 50 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 51 | elif isinstance(batch[0], container_abcs.Sequence): 52 | transposed = zip(*batch) 53 | return [default_collate(samples) for samples in transposed] 54 | 55 | raise TypeError((error_msg.format(type(batch[0])))) 56 | 57 | 58 | if __name__ == '__main__': 59 | ### current implementation of betchgenerators stuff for this script does not use _use_shared_memory! 
60 | 61 | from time import time 62 | batch_size = 50 63 | num_workers = 8 64 | pin_memory = False 65 | num_epochs = 3 66 | dataset_dir = '/media/fabian/data/data/cifar10' 67 | numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None) 68 | fname = os.path.join(dataset_dir, 'cifar10_training_data.npz') 69 | dataset = np.load(fname) 70 | cifar_dataset_as_arrays = (dataset['data'], dataset['labels'], dataset['filenames']) 71 | print('batch_size', batch_size) 72 | print('num_workers', num_workers) 73 | print('pin_memory', pin_memory) 74 | print('num_epochs', num_epochs) 75 | 76 | tr_transforms = [SpatialTransform((32, 32))] * 1 # SpatialTransform is computationally expensive and we need some 77 | # load on CPU so we just stack 5 of them on top of each other 78 | tr_transforms.append(numpy_to_tensor) 79 | tr_transforms = Compose(tr_transforms) 80 | 81 | cifar_dataset = CifarDataset(dataset_dir, train=True, transform=tr_transforms) 82 | 83 | dl = DataLoaderFromDataset(cifar_dataset, batch_size, num_workers, 1) 84 | mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, pin_memory) 85 | 86 | batches = 0 87 | for _ in mt: 88 | batches += 1 89 | assert len(_['data'].shape) == 4 90 | 91 | assert batches == len(cifar_dataset) / batch_size # this assertion only holds if len(datset) is divisible by 92 | # batch size 93 | 94 | start = time() 95 | for _ in range(num_epochs): 96 | batches = 0 97 | for _ in mt: 98 | batches += 1 99 | stop = time() 100 | print('batchgenerators took %03.4f seconds' % (stop - start)) 101 | 102 | # The best I can do: 103 | 104 | dl = HighPerformanceCIFARLoader(cifar_dataset_as_arrays, batch_size, num_workers, 1) # this circumvents the 105 | # default_collate function, just to see if that is slowing things down 106 | mt = MultiThreadedAugmenter(dl, tr_transforms, num_workers, 1, None, pin_memory) 107 | 108 | batches = 0 109 | for _ in mt: 110 | batches += 1 111 | assert len(_['data'].shape) == 4 112 | 113 | assert batches == 
len(cifar_dataset_as_arrays[0]) / batch_size # this assertion only holds if len(datset) is 114 | # divisible by batch size 115 | 116 | start = time() 117 | for _ in range(num_epochs): 118 | batches = 0 119 | for _ in mt: 120 | batches += 1 121 | stop = time() 122 | print('high performance batchgenerators %03.4f seconds' % (stop - start)) 123 | 124 | 125 | from torch.utils.data import DataLoader as TorchDataLoader 126 | 127 | trainloader = TorchDataLoader(cifar_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, 128 | pin_memory=pin_memory, collate_fn=default_collate) 129 | 130 | batches = 0 131 | for _ in iter(trainloader): 132 | batches += 1 133 | assert len(_['data'].shape) == 4 134 | 135 | start = time() 136 | for _ in range(num_epochs): 137 | batches = 0 138 | for _ in trainloader: 139 | batches += 1 140 | stop = time() 141 | print('pytorch took %03.4f seconds' % (stop - start)) 142 | -------------------------------------------------------------------------------- /tests/test_spatial_transformations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
import unittest
import numpy as np

from batchgenerators.augmentations.spatial_transformations import augment_rot90, augment_resize, augment_transpose_axes


class AugmentTransposeAxes(unittest.TestCase):
    """Tests for augment_transpose_axes."""

    def setUp(self):
        np.random.seed(123)
        self.data_3D = np.random.random((2, 4, 5, 6))
        self.seg_3D = np.random.random(self.data_3D.shape)

    def test_transpose_axes(self):
        # Over many calls the output is expected to equal the input with array axes 1 and 2 swapped
        # about half of the time.
        n_iter = 1000
        tmp = 0
        for i in range(n_iter):
            data_out, seg_out = augment_transpose_axes(self.data_3D, self.seg_3D, axes=(1, 0))

            if np.array_equal(data_out, np.swapaxes(self.data_3D, 1, 2)):
                tmp += 1
        self.assertAlmostEqual(tmp, n_iter/2., delta=10)


class AugmentResize(unittest.TestCase):
    """Tests that augment_resize produces the requested spatial shape and roughly preserves the mean."""

    def setUp(self):
        np.random.seed(123)
        self.data_3D = np.random.random((2, 12, 14, 31))
        self.seg_3D = np.random.random(self.data_3D.shape)

    def test_resize(self):
        # scalar target_size: every spatial axis (all axes after axis 0) is expected to become 15
        data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=15)

        mean_resized = float(np.mean(data_resized))
        mean_original = float(np.mean(self.data_3D))

        self.assertAlmostEqual(mean_original, mean_resized, places=2)

        self.assertTrue(all((data_resized.shape[i] == 15 and seg_resized.shape[i] == 15) for i in
                            range(1, len(data_resized.shape))))

    def test_resize2(self):
        # per-axis target_size tuple
        data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=(7, 5, 6))

        mean_resized = float(np.mean(data_resized))
        mean_original = float(np.mean(self.data_3D))

        self.assertAlmostEqual(mean_original, mean_resized, places=2)

        self.assertTrue(all([i == j for i, j in zip(data_resized.shape[1:], (7, 5, 6))]))
        self.assertTrue(all([i == j for i, j in zip(seg_resized.shape[1:], (7, 5, 6))]))


class AugmentRot90(unittest.TestCase):
    """Tests for augment_rot90: rotation correctness and randomness of rotation count and plane."""

    def setUp(self):
        np.random.seed(123)
        self.data_3D = np.random.random((2, 4, 5, 6))
        self.seg_3D = np.random.random(self.data_3D.shape)
        self.num_rot = [1]

    def test_rotation_checkerboard(self):
        # A 2x2 checkerboard is unchanged by 4 quarter turns and inverted by 1, so summing many
        # outputs must give exactly two distinct values, each near n_iter / 2.
        data_2d_checkerboard = np.zeros((1, 2, 2))
        data_2d_checkerboard[0, 0, 0] = 1
        data_2d_checkerboard[0, 1, 1] = 1

        data_rotated_list = []
        n_iter = 1000
        for i in range(n_iter):
            d_r, _ = augment_rot90(np.copy(data_2d_checkerboard), None, num_rot=[4, 1], axes=[0, 1])
            data_rotated_list.append(d_r)

        data_rotated_np = np.array(data_rotated_list)
        sum_data_list = np.sum(data_rotated_np, axis=0)
        a = np.unique(sum_data_list)
        self.assertAlmostEqual(a[0], n_iter/2, delta=20)
        self.assertTrue(len(a) == 2)

    def test_rotation(self):
        # one quarter turn in the plane of spatial axes 0 and 1 (array axes 1 and 2)
        data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                  axes=[0, 1])

        for i in range(self.data_3D.shape[1]):
            self.assertTrue(np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1)))
            self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))

    def test_randomness_rotation_axis(self):
        # With three candidate axes the (0, 1) plane is expected roughly a third of the time
        # (presumably augment_rot90 picks the plane uniformly among axis pairs).
        # Use a tolerance: places=2 would require exactly 33 hits, which only holds for one seed.
        n_iter = 100
        tmp = 0
        for j in range(n_iter):
            data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                      axes=[0, 1, 2])
            if np.array_equal(self.data_3D[:, 0, :, :], np.flip(data_rotated[:, :, 0, :], axis=1)):
                tmp += 1
        self.assertAlmostEqual(tmp, n_iter / 3., delta=15)

    def test_rotation_list(self):
        # num_rot=[1, 3]: the result must be either a quarter turn or its inverse
        num_rot = [1, 3]
        data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                                  axes=[0, 1])
        for i in range(self.data_3D.shape[1]):
            # check for normal and inverse rotations
            normal_rotated = np.array_equal(self.data_3D[:, i, :, :], data_rotated[:, :, -i-1, :])
            inverse_rotated = np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1))
            self.assertTrue(normal_rotated or inverse_rotated)
            self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], seg_rotated[:, :, -i - 1, :]) or
                            np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))

    def test_randomness_rotation_number(self):
        # the two entries of num_rot are expected to be chosen about equally often
        tmp = 0
        num_rot = [1, 3]
        n_iter = 1000
        for j in range(n_iter):
            data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                                      axes=[0, 1])
            normal_rotated = np.array_equal(self.data_3D[:, 0, :, :], data_rotated[:, :, - 1, :])
            if normal_rotated:
                tmp += 1
        self.assertAlmostEqual(tmp, n_iter / 2., delta=20)


if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------------------
# /batchgenerators/transforms/channel_selection_transforms.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from warnings import warn
from batchgenerators.transforms.abstract_transforms import AbstractTransform


def _warn_missing_key(transform, key):
    """Warn that ``key`` is missing from data_dict and that ``transform`` returns it unmodified.

    Uses the real class name of ``transform`` and the key it actually looked up, so the message
    points at the transform that was applied.
    """
    warn("You used %s but there is no '%s' key in your data_dict, returning data_dict unmodified"
         % (type(transform).__name__, key), Warning)


class DataChannelSelectionTransform(AbstractTransform):
    """Selects color channels from the batch and discards the others.

    Args:
        channels (list of int): List of channels to be kept.
        data_key (str): Key of the array in data_dict to select from.

    """

    def __init__(self, channels, data_key="data"):
        self.data_key = data_key
        self.channels = channels

    def __call__(self, **data_dict):
        # indexes axis 1, i.e. assumes a (batch, channel, ...) layout
        data_dict[self.data_key] = data_dict[self.data_key][:, self.channels]
        return data_dict


class SegChannelSelectionTransform(AbstractTransform):
    """Segmentations may have more than one channel. This transform selects segmentation channels

    Args:
        channels (list of int): List of channels to be kept.
        keep_discarded_seg (bool): If True, the channels not kept are stored in data_dict['discarded_seg'].
        label_key (str): Key of the segmentation in data_dict.

    """

    def __init__(self, channels, keep_discarded_seg=False, label_key="seg"):
        self.label_key = label_key
        self.channels = channels
        self.keep_discarded = keep_discarded_seg

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            _warn_missing_key(self, self.label_key)
        else:
            if self.keep_discarded:
                discarded_seg_idx = [i for i in range(len(seg[0])) if i not in self.channels]
                data_dict['discarded_seg'] = seg[:, discarded_seg_idx]
            data_dict[self.label_key] = seg[:, self.channels]
        return data_dict


class SegChannelMergeTransform(AbstractTransform):
    """Merge selected channels of a onehot segmentation. Will merge into lowest index.

    Args:
        channels (list of int): List of channels to be merged.
        keep_discarded_seg (bool): If True, the merged-away channels (all but the lowest index) are
            stored in data_dict['discarded_seg'] before merging.
        label_key (str): Key of the segmentation in data_dict.
        fill_value: Value written into the lowest channel wherever a merged channel is nonzero.

    Note: the lowest selected channel of the input segmentation array is modified in place.

    """

    def __init__(self, channels, keep_discarded_seg=False, label_key="seg", fill_value=1):
        self.label_key = label_key
        # de-duplicate: a repeated channel would make all_channels.remove() raise ValueError in __call__
        self.channels = sorted(set(channels))
        self.keep_discarded = keep_discarded_seg
        self.fill_value = fill_value

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            _warn_missing_key(self, self.label_key)
        else:
            if self.keep_discarded:
                # fancy indexing copies, so this snapshot is taken before the merge below
                data_dict['discarded_seg'] = seg[:, self.channels[1:]]
            all_channels = list(range(seg.shape[1]))
            for i in self.channels[1:]:
                # seg[:, c] with an int c is a view, so this writes fill_value into seg itself
                seg[:, self.channels[0]][seg[:, i] != 0] = self.fill_value
                all_channels.remove(i)
            data_dict[self.label_key] = seg[:, all_channels]
        return data_dict


class SegChannelRandomSwapTransform(AbstractTransform):
    """Randomly swap two segmentation channels.

    Args:
        axis1 (int): First axis for swap
        axis2 (int): Second axis for swap
        swap_probability (float): Probability for swap
        label_key (str): Key of the segmentation in data_dict.

    """

    def __init__(self, axis1, axis2, swap_probability=0.5, label_key="seg"):
        self.axis1 = axis1
        self.axis2 = axis2
        self.swap_probability = swap_probability
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            _warn_missing_key(self, self.label_key)
        else:
            random_number = np.random.rand()
            if random_number < self.swap_probability:
                # the right-hand side is a copy (fancy indexing), so the in-place swap is safe
                seg[:, [self.axis1, self.axis2]] = seg[:, [self.axis2, self.axis1]]
            data_dict[self.label_key] = seg
        return data_dict


class SegChannelRandomDuplicateTransform(AbstractTransform):
    """Creates an additional seg channel full of zeros and randomly swaps it with the base channel.

    Args:
        axis (int): Axis to be duplicated
        swap_probability (float): Probability for swap
        label_key (str): Key of the segmentation in data_dict.

    """

    def __init__(self, axis, swap_probability=0.5, label_key="seg"):
        self.axis = axis
        self.swap_probability = swap_probability
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            _warn_missing_key(self, self.label_key)
        else:
            # append one all-zero channel (same dtype) at the end of axis 1
            seg_shape = list(seg.shape)
            seg_shape[1] = 1
            seg = np.concatenate([seg, np.zeros(seg_shape, dtype=seg.dtype)], 1)
            random_number = np.random.rand()
            if random_number < self.swap_probability:
                seg[:, [self.axis, -1]] = seg[:, [-1, self.axis]]
            data_dict[self.label_key] = seg
        return data_dict


class SegLabelSelectionBinarizeTransform(AbstractTransform):
    """Will create a binary segmentation, with the selected labels in the foreground.

    Labels in ``label`` become 1; every other nonzero label becomes 0. The segmentation array is
    modified in place.

    Args:
        label (int, list of int): Foreground label(s)
        label_key (str): Key of the segmentation in data_dict.

    """

    def __init__(self, label, label_key="seg"):
        self.label_key = label_key
        # also accept numpy integer scalars; sorted() on e.g. np.int64 would raise TypeError
        if isinstance(label, (int, np.integer)):
            self.label = [label]
        else:
            self.label = sorted(label)

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)

        if seg is None:
            _warn_missing_key(self, self.label_key)
        else:
            # clear unselected labels first so a pre-existing label 1 is not mistaken for foreground
            discard_labels = set(np.unique(seg)) - set(self.label) - set([0])
            for label in discard_labels:
                seg[seg == label] = 0
            for label in self.label:
                seg[seg == label] = 1
            data_dict[self.label_key] = seg
        return data_dict

# --------------------------------------------------------------------------------
# /batchgenerators/augmentations/color_augmentations.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from builtins import range
from typing import Tuple, Union, Callable

import numpy as np
from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter


def _sample_contrast_factor(contrast_range):
    """Draw one contrast factor from a callable or from a (low, high) range.

    For a range: with probability 0.5 (only possible when low < 1) the factor is uniform in [low, 1),
    otherwise uniform in [max(low, 1), high).
    """
    if callable(contrast_range):
        return contrast_range()
    if np.random.random() < 0.5 and contrast_range[0] < 1:
        return np.random.uniform(contrast_range[0], 1)
    return np.random.uniform(max(contrast_range[0], 1), contrast_range[1])


def _apply_contrast_to_channel(data_sample, c, factor, preserve_range):
    """Scale channel ``c`` of ``data_sample`` in place around its mean by ``factor``."""
    mn = data_sample[c].mean()
    if preserve_range:
        minm = data_sample[c].min()
        maxm = data_sample[c].max()

    data_sample[c] = (data_sample[c] - mn) * factor + mn

    if preserve_range:
        # clip back to the channel's original [min, max]
        data_sample[c][data_sample[c] < minm] = minm
        data_sample[c][data_sample[c] > maxm] = maxm


def _sample_gamma(gamma_range):
    """Draw gamma: with probability 0.5 (only if low < 1) from [low, 1), otherwise from [max(low, 1), high)."""
    if np.random.random() < 0.5 and gamma_range[0] < 1:
        return np.random.uniform(gamma_range[0], 1)
    return np.random.uniform(max(gamma_range[0], 1), gamma_range[1])


def augment_contrast(data_sample: np.ndarray,
                     contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
                     preserve_range: bool = True,
                     per_channel: bool = True,
                     p_per_channel: float = 1) -> np.ndarray:
    """Randomly scale the contrast of each channel (axis 0) around its mean, in place.

    :param data_sample: array whose first axis is iterated as channels
    :param contrast_range: (low, high) range for the contrast factor, or a zero-argument callable returning it
    :param preserve_range: clip each processed channel back to its original [min, max]
    :param per_channel: draw a new factor for every processed channel instead of one shared factor
    :param p_per_channel: probability that a given channel is processed
    :return: data_sample (modified in place)
    """
    if not per_channel:
        factor = _sample_contrast_factor(contrast_range)
    for c in range(data_sample.shape[0]):
        if np.random.uniform() < p_per_channel:
            if per_channel:
                factor = _sample_contrast_factor(contrast_range)
            _apply_contrast_to_channel(data_sample, c, factor, preserve_range)
    return data_sample


def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channel: bool = True,
                                p_per_channel: float = 1.):
    """Add a normally distributed offset N(mu, sigma) to channels of data_sample, in place.

    data_sample must have shape (c, x, y(, z))
    :param data_sample: array whose first axis is iterated as channels
    :param mu: mean of the offset distribution
    :param sigma: standard deviation of the offset distribution
    :param per_channel: draw a separate offset per processed channel instead of one shared offset
    :param p_per_channel: probability that a given channel is processed
    :return: data_sample (modified in place)
    """
    if not per_channel:
        rnd_nb = np.random.normal(mu, sigma)
        for c in range(data_sample.shape[0]):
            if np.random.uniform() <= p_per_channel:
                data_sample[c] += rnd_nb
    else:
        for c in range(data_sample.shape[0]):
            if np.random.uniform() <= p_per_channel:
                rnd_nb = np.random.normal(mu, sigma)
                data_sample[c] += rnd_nb
    return data_sample


def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True):
    """Multiply data_sample in place by a factor drawn uniformly from multiplier_range.

    With per_channel, every channel (axis 0) gets its own factor.
    """
    # NOTE: this draw also happens when per_channel is True, where its value is unused; removing it
    # would shift the random stream for callers relying on a fixed seed.
    multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
    if not per_channel:
        data_sample *= multiplier
    else:
        for c in range(data_sample.shape[0]):
            multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
            data_sample[c] *= multiplier
    return data_sample


def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False,
                  retain_stats: Union[bool, Callable[[], bool]] = False):
    """Apply a random gamma curve to intensities rescaled to ~[0, 1], then map back to the original range.

    :param data_sample: array; with per_channel its first axis is iterated as channels
    :param gamma_range: (low, high) range gamma is sampled from
    :param invert_image: negate intensities before and after the gamma transform
    :param epsilon: added to the intensity range to avoid division by zero
    :param per_channel: sample gamma (and retain_stats) separately for each channel
    :param retain_stats: bool or zero-argument callable; if true, restore the pre-transform mean and std
    :return: the transformed array. The per-channel path without invert_image writes into
        data_sample in place; otherwise a new array is returned -- always use the return value.
    """
    if invert_image:
        data_sample = - data_sample

    if not per_channel:
        retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
        if retain_stats_here:
            mn = data_sample.mean()
            sd = data_sample.std()
        gamma = _sample_gamma(gamma_range)
        minm = data_sample.min()
        rnge = data_sample.max() - minm
        data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm
        if retain_stats_here:
            data_sample = data_sample - data_sample.mean()
            data_sample = data_sample / (data_sample.std() + 1e-8) * sd
            data_sample = data_sample + mn
    else:
        for c in range(data_sample.shape[0]):
            retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
            if retain_stats_here:
                mn = data_sample[c].mean()
                sd = data_sample[c].std()
            gamma = _sample_gamma(gamma_range)
            minm = data_sample[c].min()
            rnge = data_sample[c].max() - minm
            # NOTE(review): unlike the shared-gamma branch, this rescales by (rnge + epsilon) rather
            # than rnge (makes gamma == 1 an exact identity); confirm which is intended.
            data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm
            if retain_stats_here:
                data_sample[c] = data_sample[c] - data_sample[c].mean()
                data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
                data_sample[c] = data_sample[c] + mn
    if invert_image:
        data_sample = - data_sample
    return data_sample


def augment_illumination(data, white_rgb):
    """Per sample, re-light the image with a randomly chosen white point from white_rgb (in place).

    Each sample is first passed through general_cc_var_num_channels (presumably a color-constancy
    normalization -- see batchgenerators.augmentations.utils); then channel c is scaled by
    white_rgb[idx][c] * sqrt(3). Assumes each white_rgb entry has at least as many values as
    channels -- TODO confirm.
    """
    idx = np.random.choice(len(white_rgb), data.shape[0])
    for sample in range(data.shape[0]):
        _, img = general_cc_var_num_channels(data[sample], 0, 5, 0, None, 1., 7, False)
        rgb = np.array(white_rgb[idx[sample]]) * np.sqrt(3)
        for c in range(data[sample].shape[0]):
            data[sample, c] = img[c] * rgb[c]
    return data


def augment_PCA_shift(data, U, s, sigma=0.2):
    """Apply illumination_jitter to every sample, then min-max normalize each sample to [0, 1] (in place)."""
    for sample in range(data.shape[0]):
        data[sample] = illumination_jitter(data[sample], U, s, sigma)
        data[sample] -= data[sample].min()
        sample_max = data[sample].max()
        # a constant sample is all zeros here; dividing by 0 would fill it with NaN
        if sample_max != 0:
            data[sample] /= sample_max
    return data

# --------------------------------------------------------------------------------
# /tests/test_axis_mirroring.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab,
Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import unittest 17 | import unittest2 18 | import numpy as np 19 | from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter 20 | from skimage import data 21 | 22 | from tests.DataGenerators import BasicDataLoader 23 | from batchgenerators.transforms.spatial_transforms import MirrorTransform 24 | 25 | 26 | class TestMirrorAxis(unittest2.TestCase): 27 | def setUp(self): 28 | self.seed = 1234 29 | 30 | self.batch_size = 10 31 | self.num_batches = 1000 32 | 33 | np.random.seed(self.seed) 34 | 35 | ### 2D initialiazations 36 | 37 | cam = data.camera() 38 | self.cam = cam[np.newaxis, np.newaxis, :, :] 39 | 40 | self.cam_left = self.cam[:, :, :, ::-1] 41 | self.cam_updown = self.cam[:, :, ::-1, :] 42 | self.cam_updown_left = self.cam[:, :, ::-1, ::-1] 43 | 44 | self.x_2D = self.cam 45 | self.y_2D = self.cam 46 | 47 | ### 3D initialiazations 48 | 49 | self.cam_3D = np.random.rand(20, 20, 20)[np.newaxis, np.newaxis, :, :, :] 50 | 51 | self.cam_3D_left = self.cam_3D[:, :, :, ::-1, :] 52 | self.cam_3D_updown = self.cam_3D[:, :, ::-1, :, :] 53 | self.cam_3D_updown_left = self.cam_3D[:, :, ::-1, ::-1, :] 54 | 55 | self.cam_3D_left_z = self.cam_3D_left[:, :, :, :, ::-1] 56 | self.cam_3D_updown_z = self.cam_3D_updown[:, :, :, :, ::-1] 57 | self.cam_3D_updown_left_z = self.cam_3D_updown_left[:, :, :, :, ::-1] 58 | self.cam_3D_z 
= self.cam_3D[:, :, :, :, ::-1] 59 | 60 | self.x_3D = self.cam_3D 61 | self.y_3D = self.cam_3D 62 | 63 | 64 | def test_random_distributions_2D(self): 65 | ### test whether all 4 possible mirrorings occur in approximately equal frquencies in 2D 66 | 67 | batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None) 68 | batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1))) 69 | 70 | counts = np.zeros(shape=(4,)) 71 | 72 | for b in range(self.num_batches): 73 | batch = next(batch_gen) 74 | 75 | for ix in range(self.batch_size): 76 | if (batch['data'][ix, :, :, :] == self.cam_left).all(): 77 | counts[0] = counts[0] + 1 78 | 79 | elif (batch['data'][ix, :, :, :] == self.cam_updown).all(): 80 | counts[1] = counts[1] + 1 81 | 82 | elif (batch['data'][ix, :, :, :] == self.cam_updown_left).all(): 83 | counts[2] = counts[2] + 1 84 | 85 | elif (batch['data'][ix, :, :, :] == self.cam).all(): 86 | counts[3] = counts[3] + 1 87 | 88 | self.assertTrue([1 if (2200 < c < 2800) else 0 for c in counts] == [1]*4, "2D Images were not mirrored along " 89 | "all axes with equal probability. 
" 90 | "This may also indicate that " 91 | "mirroring is not working") 92 | 93 | 94 | def test_segmentations_2D(self): 95 | ### test whether segmentations are mirrored coherently with images 96 | 97 | batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None) 98 | batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1))) 99 | 100 | equivalent = True 101 | 102 | for b in range(self.num_batches): 103 | batch = next(batch_gen) 104 | for ix in range(self.batch_size): 105 | if (batch['data'][ix] != batch['seg'][ix]).all(): 106 | equivalent = False 107 | 108 | self.assertTrue(equivalent, "2D images and seg were not mirrored in the same way (they should though because " 109 | "seg needs to match the corresponding data") 110 | 111 | 112 | def test_random_distributions_3D(self): 113 | ### test whether all 8 possible mirrorings occur in approximately equal frquencies in 3D case 114 | 115 | batch_gen = BasicDataLoader((self.x_3D, self.y_3D), self.batch_size, number_of_threads_in_multithreaded=None) 116 | batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1, 2))) 117 | 118 | counts = np.zeros(shape=(8,)) 119 | 120 | for b in range(self.num_batches): 121 | batch = next(batch_gen) 122 | for ix in range(self.batch_size): 123 | if (batch['data'][ix, :, :, :, :] == self.cam_3D_left).all(): 124 | counts[0] = counts[0] + 1 125 | 126 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D_updown).all(): 127 | counts[1] = counts[1] + 1 128 | 129 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D_updown_left).all(): 130 | counts[2] = counts[2] + 1 131 | 132 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D).all(): 133 | counts[3] = counts[3] + 1 134 | 135 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D_left_z).all(): 136 | counts[4] = counts[1] + 1 137 | 138 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D_updown_z).all(): 139 | counts[5] = counts[1] + 1 140 | 141 | elif (batch['data'][ix, :, :, 
:, :] == self.cam_3D_updown_left_z).all(): 142 | counts[6] = counts[2] + 1 143 | 144 | elif (batch['data'][ix, :, :, :, :] == self.cam_3D_z).all(): 145 | counts[7] = counts[3] + 1 146 | 147 | self.assertTrue([1 if (1000 < c < 1400) else 0 for c in counts] == [1]*8, "3D Images were not mirrored along " 148 | "all axes with equal probability. " 149 | "This may also indicate that " 150 | "mirroring is not working") 151 | 152 | 153 | def test_segmentations_3D(self): 154 | ### test whether segmentations are rotated coherently with images 155 | 156 | batch_gen = BasicDataLoader((self.x_3D, self.y_3D), self.batch_size, number_of_threads_in_multithreaded=None) 157 | batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1, 2))) 158 | 159 | equivalent = True 160 | 161 | for b in range(self.num_batches): 162 | batch = next(batch_gen) 163 | for ix in range(self.batch_size): 164 | if (batch['data'][ix] != batch['seg'][ix]).all(): 165 | equivalent = False 166 | 167 | self.assertTrue(equivalent, "3D images and seg were not mirrored in the same way (they should though because " 168 | "seg needs to match the corresponding data") 169 | 170 | 171 | if __name__ == '__main__': 172 | unittest.main() 173 | 174 | -------------------------------------------------------------------------------- /batchgenerators/augmentations/crop_and_pad_augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from builtins import range 17 | import numpy as np 18 | from batchgenerators.augmentations.utils import pad_nd_image 19 | 20 | 21 | def center_crop(data, crop_size, seg=None): 22 | return crop(data, seg, crop_size, 0, 'center') 23 | 24 | 25 | def get_lbs_for_random_crop(crop_size, data_shape, margins): 26 | """ 27 | 28 | :param crop_size: 29 | :param data_shape: (b,c,x,y(,z)) must be the whole thing! 30 | :param margins: 31 | :return: 32 | """ 33 | lbs = [] 34 | for i in range(len(data_shape) - 2): 35 | if data_shape[i+2] - crop_size[i] - margins[i] > margins[i]: 36 | lbs.append(np.random.randint(margins[i], data_shape[i+2] - crop_size[i] - margins[i])) 37 | else: 38 | lbs.append((data_shape[i+2] - crop_size[i]) // 2) 39 | return lbs 40 | 41 | 42 | def get_lbs_for_center_crop(crop_size, data_shape): 43 | """ 44 | :param crop_size: 45 | :param data_shape: (b,c,x,y(,z)) must be the whole thing! 46 | :return: 47 | """ 48 | lbs = [] 49 | for i in range(len(data_shape) - 2): 50 | lbs.append((data_shape[i + 2] - crop_size[i]) // 2) 51 | return lbs 52 | 53 | 54 | def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center", 55 | pad_mode='constant', pad_kwargs={'constant_values': 0}, 56 | pad_mode_seg='constant', pad_kwargs_seg={'constant_values': 0}): 57 | """ 58 | crops data and seg (seg may be None) to crop_size. Whether this will be achieved via center or random crop is 59 | determined by crop_type. 
Margin will be respected only for random_crop and will prevent the crops form being closer 60 | than margin to the respective image border. crop_size can be larger than data_shape - margin -> data/seg will be 61 | padded with zeros in that case. margins can be negative -> results in padding of data/seg followed by cropping with 62 | margin=0 for the appropriate axes 63 | 64 | :param data: b, c, x, y(, z) 65 | :param seg: 66 | :param crop_size: 67 | :param margins: distance from each border, can be int or list/tuple of ints (one element for each dimension). 68 | Can be negative (data/seg will be padded if needed) 69 | :param crop_type: random or center 70 | :return: 71 | """ 72 | if not isinstance(data, (list, tuple, np.ndarray)): 73 | raise TypeError("data has to be either a numpy array or a list") 74 | 75 | data_shape = tuple([len(data)] + list(data[0].shape)) 76 | data_dtype = data[0].dtype 77 | dim = len(data_shape) - 2 78 | 79 | if seg is not None: 80 | seg_shape = tuple([len(seg)] + list(seg[0].shape)) 81 | seg_dtype = seg[0].dtype 82 | 83 | if not isinstance(seg, (list, tuple, np.ndarray)): 84 | raise TypeError("data has to be either a numpy array or a list") 85 | 86 | assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \ 87 | "dimensions. 
Data: %s, seg: %s" % \ 88 | (str(data_shape), str(seg_shape)) 89 | 90 | if type(crop_size) not in (tuple, list, np.ndarray): 91 | crop_size = [crop_size] * dim 92 | else: 93 | assert len(crop_size) == len( 94 | data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ 95 | "data (2d/3d)" 96 | 97 | if not isinstance(margins, (np.ndarray, tuple, list)): 98 | margins = [margins] * dim 99 | 100 | data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype) 101 | if seg is not None: 102 | seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype) 103 | else: 104 | seg_return = None 105 | 106 | for b in range(data_shape[0]): 107 | data_shape_here = [data_shape[0]] + list(data[b].shape) 108 | if seg is not None: 109 | seg_shape_here = [seg_shape[0]] + list(seg[b].shape) 110 | 111 | if crop_type == "center": 112 | lbs = get_lbs_for_center_crop(crop_size, data_shape_here) 113 | elif crop_type == "random": 114 | lbs = get_lbs_for_random_crop(crop_size, data_shape_here, margins) 115 | else: 116 | raise NotImplementedError("crop_type must be either center or random") 117 | 118 | need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])), 119 | abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))] 120 | for d in range(dim)] 121 | 122 | # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed 123 | ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)] 124 | lbs = [max(0, lbs[d]) for d in range(dim)] 125 | 126 | slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] 127 | data_cropped = data[b][tuple(slicer_data)] 128 | 129 | if seg_return is not None: 130 | slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] 131 | seg_cropped = seg[b][tuple(slicer_seg)] 132 | 133 | if any([i > 0 for j in need_to_pad for i in j]): 134 | data_return[b] = 
np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs) 135 | if seg_return is not None: 136 | seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg) 137 | else: 138 | data_return[b] = data_cropped 139 | if seg_return is not None: 140 | seg_return[b] = seg_cropped 141 | 142 | return data_return, seg_return 143 | 144 | 145 | def random_crop(data, seg=None, crop_size=128, margins=[0, 0, 0]): 146 | return crop(data, seg, crop_size, margins, 'random') 147 | 148 | 149 | def pad_nd_image_and_seg(data, seg, new_shape=None, must_be_divisible_by=None, pad_mode_data='constant', 150 | np_pad_kwargs_data=None, pad_mode_seg='constant', np_pad_kwargs_seg=None): 151 | """ 152 | Pads data and seg to new_shape. new_shape is thereby understood as min_shape (if data/seg is already larger then 153 | new_shape the shape stays the same for the dimensions this applies) 154 | :param data: 155 | :param seg: 156 | :param new_shape: if none then only must_be_divisible_by is applied 157 | :param must_be_divisible_by: UNet like architectures sometimes require the input to be divisibly by some number. This 158 | will modify new_shape if new_shape is not divisibly by this (by increasing it accordingly). 
159 | must_be_divisible_by should be a list of int (one for each spatial dimension) and this list must have the same 160 | length as new_shape 161 | :param pad_mode_data: see np.pad 162 | :param np_pad_kwargs_data:see np.pad 163 | :param pad_mode_seg:see np.pad 164 | :param np_pad_kwargs_seg:see np.pad 165 | :return: 166 | """ 167 | sample_data = pad_nd_image(data, new_shape, mode=pad_mode_data, kwargs=np_pad_kwargs_data, 168 | return_slicer=False, shape_must_be_divisible_by=must_be_divisible_by) 169 | if seg is not None: 170 | sample_seg = pad_nd_image(seg, new_shape, mode=pad_mode_seg, kwargs=np_pad_kwargs_seg, 171 | return_slicer=False, shape_must_be_divisible_by=must_be_divisible_by) 172 | else: 173 | sample_seg = None 174 | return sample_data, sample_seg 175 | -------------------------------------------------------------------------------- /tests/test_multithreaded_augmenter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import unittest 16 | from time import sleep 17 | 18 | import numpy as np 19 | from batchgenerators.dataloading.data_loader import SlimDataLoaderBase 20 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter 21 | from batchgenerators.examples.multithreaded_dataloading import DummyDL, DummyDLWithShuffle 22 | from batchgenerators.transforms.abstract_transforms import Compose 23 | from batchgenerators.transforms.spatial_transforms import MirrorTransform, TransposeAxesTransform 24 | from batchgenerators.transforms.utility_transforms import NumpyToTensor 25 | from skimage.data import camera, checkerboard, astronaut, binary_blobs, coins 26 | from skimage.transform import resize 27 | from copy import deepcopy 28 | 29 | 30 | class DummyDL2DImage(SlimDataLoaderBase): 31 | def __init__(self, batch_size, num_threads=8): 32 | data = [] 33 | target_shape = (224, 224) 34 | 35 | c = camera() 36 | c = resize(c.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True, mode='reflect').astype(np.float32) 37 | data.append(c[None]) 38 | 39 | c = checkerboard() 40 | c = resize(c.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True, mode='reflect').astype(np.float32) 41 | data.append(c[None]) 42 | 43 | c = astronaut().mean(-1) 44 | c = resize(c.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True, mode='reflect').astype(np.float32) 45 | data.append(c[None]) 46 | 47 | c = binary_blobs() 48 | c = resize(c.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True, mode='reflect').astype(np.float32) 49 | data.append(c[None]) 50 | 51 | c = coins() 52 | c = resize(c.astype(np.float64), target_shape, 1, anti_aliasing=False, clip=True, mode='reflect').astype(np.float32) 53 | data.append(c[None]) 54 | data = np.stack(data) 55 | super(DummyDL2DImage, self).__init__(data, batch_size, num_threads) 56 | 57 | def generate_train_batch(self): 58 | idx = np.random.choice(len(self._data), self.batch_size) 59 | 
res = [] 60 | for i in idx: 61 | res.append(self._data[i:i+1]) 62 | res = np.vstack(res) 63 | return {'data': res} 64 | 65 | 66 | class TestMultiThreadedAugmenter(unittest.TestCase): 67 | """ 68 | This test is inspired by the multithreaded example I did a while back 69 | """ 70 | def setUp(self): 71 | np.random.seed(1234) 72 | self.num_threads = 4 73 | self.dl = DummyDL(self.num_threads) 74 | self.dl_with_shuffle = DummyDLWithShuffle(self.num_threads) 75 | self.dl_images = DummyDL2DImage(4, self.num_threads) 76 | 77 | def test_no_crash(self): 78 | """ 79 | This one should just not crash, that's all 80 | :return: 81 | """ 82 | dl = self.dl_images 83 | mt_dl = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False) 84 | 85 | for _ in range(20): 86 | _ = mt_dl.next() 87 | 88 | def test_DummyDL(self): 89 | """ 90 | DummyDL must return numbers from 0 to 99 in ascending order 91 | :return: 92 | """ 93 | dl = DummyDL(1) 94 | res = [] 95 | for i in dl: 96 | res.append(i) 97 | 98 | assert len(res) == 100 99 | res_copy = deepcopy(res) 100 | res.sort() 101 | assert all((i == j for i, j in zip(res, res_copy))) 102 | assert all((i == j for i, j in zip(res, np.arange(0, 100)))) 103 | 104 | def test_order(self): 105 | """ 106 | Coordinating workers in a multiprocessing envrionment is difficult. We want DummyDL in a multithreaded 107 | environment to still give us the numbers from 0 to 99 in ascending order 108 | :return: 109 | """ 110 | dl = self.dl 111 | mt = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False) 112 | 113 | res = [] 114 | for i in mt: 115 | res.append(i) 116 | 117 | assert len(res) == 100 118 | res_copy = deepcopy(res) 119 | res.sort() 120 | assert all((i == j for i, j in zip(res, res_copy))) 121 | assert all((i == j for i, j in zip(res, np.arange(0, 100)))) 122 | 123 | def test_restart_and_order(self): 124 | """ 125 | Coordinating workers in a multiprocessing envrionment is difficult. 
We want DummyDL in a multithreaded 126 | environment to still give us the numbers from 0 to 99 in ascending order. 127 | 128 | We want the MultiThreadedAugmenter to restart and return the same result in each run 129 | :return: 130 | """ 131 | dl = self.dl 132 | mt = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False) 133 | 134 | res = [] 135 | for i in mt: 136 | res.append(i) 137 | 138 | assert len(res) == 100 139 | res_copy = deepcopy(res) 140 | res.sort() 141 | assert all((i == j for i, j in zip(res, res_copy))) 142 | assert all((i == j for i, j in zip(res, np.arange(0, 100)))) 143 | 144 | res = [] 145 | for i in mt: 146 | res.append(i) 147 | 148 | assert len(res) == 100 149 | res_copy = deepcopy(res) 150 | res.sort() 151 | assert all((i == j for i, j in zip(res, res_copy))) 152 | assert all((i == j for i, j in zip(res, np.arange(0, 100)))) 153 | 154 | res = [] 155 | for i in mt: 156 | res.append(i) 157 | 158 | assert len(res) == 100 159 | res_copy = deepcopy(res) 160 | res.sort() 161 | assert all((i == j for i, j in zip(res, res_copy))) 162 | assert all((i == j for i, j in zip(res, np.arange(0, 100)))) 163 | 164 | def test_image_pipeline_and_pin_memory(self): 165 | ''' 166 | This just should not crash 167 | :return: 168 | ''' 169 | try: 170 | import torch 171 | except ImportError: 172 | '''dont test if torch is not installed''' 173 | return 174 | 175 | 176 | tr_transforms = [] 177 | tr_transforms.append(MirrorTransform()) 178 | tr_transforms.append(TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5)) 179 | tr_transforms.append(NumpyToTensor(keys='data', cast_to='float')) 180 | 181 | composed = Compose(tr_transforms) 182 | 183 | dl = self.dl_images 184 | mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, True) 185 | 186 | for _ in range(50): 187 | res = mt.next() 188 | 189 | assert isinstance(res['data'], torch.Tensor) 190 | assert res['data'].is_pinned() 191 | 192 | # let mt finish caching, otherwise it's going to print an 
error (which is not a problem and will not prevent 193 | # the success of the test but it does not look pretty) 194 | sleep(2) 195 | 196 | def test_image_pipeline(self): 197 | ''' 198 | This just should not crash 199 | :return: 200 | ''' 201 | 202 | tr_transforms = [] 203 | tr_transforms.append(MirrorTransform()) 204 | tr_transforms.append(TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5)) 205 | 206 | composed = Compose(tr_transforms) 207 | 208 | dl = self.dl_images 209 | mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, False) 210 | 211 | for _ in range(50): 212 | res = mt.next() 213 | 214 | # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent 215 | # the success of the test but it does not look pretty) 216 | sleep(2) 217 | 218 | 219 | if __name__ == '__main__': 220 | from multiprocessing import freeze_support 221 | freeze_support() 222 | unittest.main() 223 | -------------------------------------------------------------------------------- /batchgenerators/examples/brats2017/brats2017_dataloader_2D.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | from batchgenerators.augmentations.crop_and_pad_augmentations import crop 5 | from batchgenerators.augmentations.utils import pad_nd_image 6 | from batchgenerators.dataloading.data_loader import DataLoader 7 | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter 8 | from batchgenerators.examples.brats2017.brats2017_dataloader_3D import get_list_of_patients, BraTS2017DataLoader3D, \ 9 | get_train_transform 10 | from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, num_threads_for_brats_example 11 | from batchgenerators.utilities.data_splitting import get_split_deterministic 12 | 13 | 14 | class BraTS2017DataLoader2D(DataLoader): 15 | def __init__(self, data, batch_size, patch_size, 
num_threads_in_multithreaded, seed_for_shuffle=1234, return_incomplete=False, 16 | shuffle=True): 17 | """ 18 | data must be a list of patients as returned by get_list_of_patients (and split by get_split_deterministic) 19 | 20 | patch_size is the spatial size the retured batch will have 21 | 22 | """ 23 | super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle, return_incomplete, shuffle, 24 | True) 25 | self.patch_size = patch_size 26 | self.num_modalities = 4 27 | self.indices = list(range(len(data))) 28 | 29 | @staticmethod 30 | def load_patient(patient): 31 | return BraTS2017DataLoader3D.load_patient(patient) 32 | 33 | def generate_train_batch(self): 34 | # DataLoader has its own methods for selecting what patients to use next, see its Documentation 35 | idx = self.get_indices() 36 | patients_for_batch = [self._data[i] for i in idx] 37 | 38 | # initialize empty array for data and seg 39 | data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32) 40 | seg = np.zeros((self.batch_size, 1, *self.patch_size), dtype=np.float32) 41 | 42 | metadata = [] 43 | patient_names = [] 44 | 45 | # iterate over patients_for_batch and include them in the batch 46 | for i, j in enumerate(patients_for_batch): 47 | patient_data, patient_metadata = self.load_patient(j) 48 | 49 | # patient data is a memmap. If we extract just one slice then just this one slice will be read from the 50 | # disk, so no worries! 
51 | slice_idx = np.random.choice(patient_data.shape[1]) 52 | patient_data = patient_data[:, slice_idx] 53 | 54 | # this will only pad patient_data if its shape is smaller than self.patch_size 55 | patient_data = pad_nd_image(patient_data, self.patch_size) 56 | 57 | # now random crop to self.patch_size 58 | # crop expects the data to be (b, c, x, y, z) but patient_data is (c, x, y, z) so we need to add one 59 | # dummy dimension in order for it to work (@Todo, could be improved) 60 | patient_data, patient_seg = crop(patient_data[:-1][None], patient_data[-1:][None], self.patch_size, crop_type="random") 61 | 62 | data[i] = patient_data[0] 63 | seg[i] = patient_seg[0] 64 | 65 | metadata.append(patient_metadata) 66 | patient_names.append(j) 67 | 68 | return {'data': data, 'seg':seg, 'metadata':metadata, 'names':patient_names} 69 | 70 | 71 | if __name__ == "__main__": 72 | patients = get_list_of_patients(brats_preprocessed_folder) 73 | 74 | train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345) 75 | 76 | patch_size = (160, 160) 77 | batch_size = 48 78 | 79 | # I recommend you don't use 'iteration oder all training data' as epoch because in patch based training this is 80 | # really not super well defined. If you leave all arguments as default then each batch sill contain randomly 81 | # selected patients. Since we don't care about epochs here we can set num_threads_in_multithreaded to anything. 82 | dataloader = BraTS2017DataLoader2D(train, batch_size, patch_size, 1) 83 | 84 | batch = next(dataloader) 85 | try: 86 | from batchviewer import view_batch 87 | # batch viewer can show up to 4d tensors. We can show only one sample, but that should be sufficient here 88 | view_batch(np.concatenate((batch['data'][0], batch['seg'][0]), 0)[:, None]) 89 | except ImportError: 90 | view_batch = None 91 | print("you can visualize batches with batchviewer. It's a nice and handy tool. 
You can get it here: " 92 | "https://github.com/FabianIsensee/BatchViewer") 93 | 94 | # now we have some DataLoader. Let's go an get some augmentations 95 | 96 | # first let's collect all shapes, you will see why later 97 | shapes = [BraTS2017DataLoader2D.load_patient(i)[0].shape[2:] for i in patients] 98 | max_shape = np.max(shapes, 0) 99 | max_shape = np.max((max_shape, patch_size), 0) 100 | 101 | # we create a new instance of DataLoader. This one will return batches of shape max_shape. Cropping/padding is 102 | # now done by SpatialTransform. If we do it this way we avoid border artifacts (the entire brain of all cases will 103 | # be in the batch and SpatialTransform will use zeros which is exactly what we have outside the brain) 104 | # this is viable here but not viable if you work with different data. If you work for example with CT scans that 105 | # can be up to 500x500x500 voxels large then you should do this differently. There, instead of using max_shape you 106 | # should estimate what shape you need to extract so that subsequent SpatialTransform does not introduce border 107 | # artifacts 108 | dataloader_train = BraTS2017DataLoader2D(train, batch_size, max_shape, 1) 109 | 110 | # during training I like to run a validation from time to time to see where I am standing. This is not a correct 111 | # validation because just like training this is patch-based but it's good enough. We don't do augmentation for the 112 | # validation, so patch_size is used as shape target here 113 | dataloader_validation = BraTS2017DataLoader2D(val, batch_size, patch_size, 1) 114 | 115 | tr_transforms = get_train_transform(patch_size) 116 | 117 | # finally we can create multithreaded transforms that we can actually use for training 118 | # we don't pin memory here because this is pytorch specific. 
119 | tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example, 120 | num_cached_per_queue=3, 121 | seeds=None, pin_memory=False) 122 | # we need less processes for vlaidation because we dont apply transformations 123 | val_gen = MultiThreadedAugmenter(dataloader_validation, None, 124 | num_processes=max(1, num_threads_for_brats_example // 2), num_cached_per_queue=1, 125 | seeds=None, 126 | pin_memory=False) 127 | 128 | # lets start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training 129 | # batches while other things run in the main thread 130 | tr_gen.restart() 131 | val_gen.restart() 132 | 133 | # now if this was a network training you would run epochs like this (remember tr_gen and val_gen generate 134 | # inifinite examples! Don't do "for batch in tr_gen:"!!!): 135 | num_batches_per_epoch = 10 136 | num_validation_batches_per_epoch = 3 137 | num_epochs = 5 138 | # let's run this to get a time on how long it takes 139 | time_per_epoch = [] 140 | start = time() 141 | for epoch in range(num_epochs): 142 | start_epoch = time() 143 | for b in range(num_batches_per_epoch): 144 | batch = next(tr_gen) 145 | # do network training here with this batch 146 | 147 | for b in range(num_validation_batches_per_epoch): 148 | batch = next(val_gen) 149 | # run validation here 150 | end_epoch = time() 151 | time_per_epoch.append(end_epoch - start_epoch) 152 | end = time() 153 | total_time = end - start 154 | print("Running %d epochs took a total of %.2f seconds with time per epoch being %s" % 155 | (num_epochs, total_time, str(time_per_epoch))) 156 | 157 | # if you notice that you have CPU usage issues, reduce the probability with which the spatial transformations are 158 | # applied in get_train_transform (down to 0.1 for example). 
SpatialTransform is the most expensive transform 159 | 160 | # if you wish to visualize some augmented examples, install batchviewer and uncomment this 161 | if view_batch is not None: 162 | for _ in range(4): 163 | batch = next(tr_gen) 164 | view_batch(np.concatenate((batch['data'][0], batch['seg'][0]), 0)[:, None]) 165 | else: 166 | print("Cannot visualize batches, install batchviewer first") 167 | -------------------------------------------------------------------------------- /batchgenerators/examples/brats2017/brats2017_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, \ 3 | brats_folder_with_downloaded_train_data, num_threads_for_brats_example 4 | from batchgenerators.utilities.file_and_folder_operations import * 5 | 6 | try: 7 | import SimpleITK as sitk 8 | except ImportError: 9 | print("You need to have SimpleITK installed to run this example!") 10 | raise ImportError("SimpleITK not found") 11 | 12 | from multiprocessing import Pool 13 | 14 | 15 | def get_list_of_files(base_dir): 16 | """ 17 | returns a list of lists containing the filenames. The outer list contains all training examples. 
Each entry in the 18 | outer list is again a list pointing to the files of that training example in the following order: 19 | T1, T1c, T2, FLAIR, segmentation 20 | :param base_dir: 21 | :return: 22 | """ 23 | list_of_lists = [] 24 | for glioma_type in ['HGG', 'LGG']: 25 | current_directory = join(base_dir, glioma_type) 26 | patients = subfolders(current_directory, join=False) 27 | for p in patients: 28 | patient_directory = join(current_directory, p) 29 | t1_file = join(patient_directory, p + "_t1.nii.gz") 30 | t1c_file = join(patient_directory, p + "_t1ce.nii.gz") 31 | t2_file = join(patient_directory, p + "_t2.nii.gz") 32 | flair_file = join(patient_directory, p + "_flair.nii.gz") 33 | seg_file = join(patient_directory, p + "_seg.nii.gz") 34 | this_case = [t1_file, t1c_file, t2_file, flair_file, seg_file] 35 | assert all((isfile(i) for i in this_case)), "some file is missing for patient %s; make sure the following " \ 36 | "files are there: %s" % (p, str(this_case)) 37 | list_of_lists.append(this_case) 38 | print("Found %d patients" % len(list_of_lists)) 39 | return list_of_lists 40 | 41 | 42 | def load_and_preprocess(case, patient_name, output_folder): 43 | """ 44 | loads, preprocesses and saves a case 45 | This is what happens here: 46 | 1) load all images and stack them to a 4d array 47 | 2) crop to nonzero region, this removes unnecessary zero-valued regions and reduces computation time 48 | 3) normalize each modality (not the segmentation) with its mean and standard deviation computed within the brain mask 49 | 4) save 4d tensor as numpy array.
Also save metadata required to create niftis again (required for export 50 | of predictions) 51 | 52 | :param case: list of image file paths in the order T1, T1c, T2, FLAIR, segmentation (one entry of the list returned by get_list_of_files); the last file is treated as the segmentation 53 | :param patient_name: base name of the output files; <patient_name>.npy and <patient_name>.pkl are written to output_folder 54 | :return: None (results are written to disk) 55 | """ 56 | # load SimpleITK Images 57 | imgs_sitk = [sitk.ReadImage(i) for i in case] 58 | 59 | # get pixel arrays from SimpleITK images 60 | imgs_npy = [sitk.GetArrayFromImage(i) for i in imgs_sitk] 61 | 62 | # get some metadata 63 | spacing = imgs_sitk[0].GetSpacing() 64 | # the spacing returned by SimpleITK is in inverse order relative to the numpy array we receive. If we wanted to 65 | # resample the data and if the spacing was not isotropic (in BraTS all cases have already been resampled to 1x1x1mm 66 | # by the organizers) then we need to pay attention here. Therefore we bring the spacing into the correct order so 67 | # that spacing[0] actually corresponds to the spacing of the first axis of the numpy array 68 | spacing = np.array(spacing)[::-1] 69 | 70 | direction = imgs_sitk[0].GetDirection() 71 | origin = imgs_sitk[0].GetOrigin() 72 | 73 | original_shape = imgs_npy[0].shape 74 | 75 | # now stack the images into one 4d array, cast to float because we will get rounding problems if we don't 76 | imgs_npy = np.concatenate([i[None] for i in imgs_npy]).astype(np.float32) 77 | 78 | # now find the nonzero region (bounding box over all images, including the segmentation) and crop to that 79 | nonzero = [np.array(np.where(i != 0)) for i in imgs_npy] 80 | nonzero = [[np.min(i, 1), np.max(i, 1)] for i in nonzero] 81 | nonzero = np.array([np.min([i[0] for i in nonzero], 0), np.max([i[1] for i in nonzero], 0)]).T 82 | # nonzero now has shape 3, 2.
It contains the (min, max) coordinate of nonzero voxels for each axis 83 | 84 | # now crop to nonzero 85 | imgs_npy = imgs_npy[:, 86 | nonzero[0, 0] : nonzero[0, 1] + 1, 87 | nonzero[1, 0]: nonzero[1, 1] + 1, 88 | nonzero[2, 0]: nonzero[2, 1] + 1, 89 | ] 90 | 91 | # now we create a brain mask that we use for normalization: the union of the nonzero voxels of all modalities (the segmentation, imgs_npy[-1], is excluded) 92 | nonzero_masks = [i != 0 for i in imgs_npy[:-1]] 93 | brain_mask = np.zeros(imgs_npy.shape[1:], dtype=bool) 94 | for i in range(len(nonzero_masks)): 95 | brain_mask = brain_mask | nonzero_masks[i] 96 | 97 | # now normalize each modality with its mean and standard deviation (computed within the brain mask) 98 | for i in range(len(imgs_npy) - 1): 99 | mean = imgs_npy[i][brain_mask].mean() 100 | std = imgs_npy[i][brain_mask].std() 101 | imgs_npy[i] = (imgs_npy[i] - mean) / (std + 1e-8) 102 | imgs_npy[i][brain_mask == 0] = 0 103 | 104 | # the segmentation of brats has the values 0, 1, 2 and 4. This is pretty inconvenient to say the least. 105 | # We move everything that is 4 to 3 106 | imgs_npy[-1][imgs_npy[-1] == 4] = 3 107 | 108 | # now save as npy (the segmentation is stored as the last channel of the 4d array) 109 | np.save(join(output_folder, patient_name + ".npy"), imgs_npy) 110 | 111 | metadata = { 112 | 'spacing': spacing, 113 | 'direction': direction, 114 | 'origin': origin, 115 | 'original_shape': original_shape, 116 | 'nonzero_region': nonzero 117 | } 118 | 119 | save_pickle(metadata, join(output_folder, patient_name + ".pkl")) 120 | 121 | 122 | def save_segmentation_as_nifti(segmentation, metadata, output_file): 123 | original_shape = metadata['original_shape'] 124 | seg_original_shape = np.zeros(original_shape, dtype=np.uint8) 125 | nonzero = metadata['nonzero_region'] 126 | seg_original_shape[nonzero[0, 0] : nonzero[0, 1] + 1, 127 | nonzero[1, 0]: nonzero[1, 1] + 1, 128 | nonzero[2, 0]: nonzero[2, 1] + 1] = segmentation 129 | sitk_image = sitk.GetImageFromArray(seg_original_shape) 130 | sitk_image.SetDirection(metadata['direction']) 131 | sitk_image.SetOrigin(metadata['origin']) 132 | # remember
to revert spacing back to sitk order again 133 | sitk_image.SetSpacing(tuple(metadata['spacing'][[2, 1, 0]])) 134 | sitk.WriteImage(sitk_image, output_file) 135 | 136 | 137 | if __name__ == "__main__": 138 | # This is the same preprocessing I used for our contributions to the BraTS 2017 and 2018 challenges. 139 | # Preprocessing is described in the documentation of load_and_preprocess 140 | 141 | # The training data is identical between BraTS 2017 and 2018. You can request access here: 142 | # https://ipp.cbica.upenn.edu/#BraTS18_registration 143 | 144 | # brats_folder_with_downloaded_train_data points to where the extracted downloaded training data is 145 | 146 | # preprocessed data is saved as npy. This may seem odd if you are familiar with medical images, but trust me it's 147 | # the best way to do this for deep learning. It does not make much of a difference for BraTS, but if you are 148 | # dealing with larger images this is crucial for your pipelines to not get stuck in CPU bottleneck. What we can do 149 | # with numpy arrays is we can load them via np.load(file, mmap_mode="r") and then read just parts of it on the fly 150 | # during training. This is super important if your patch size is smaller than the size of the entire patient (for 151 | # example if you work with large CT data or if you need 2D slices). 152 | # For this to work properly the output_folder (or wherever the data is stored during training) must be on an SSD! 153 | # HDDs are usually too slow and you also wouldn't want to do this over a network share (there are exceptions but 154 | # take this as a rule of thumb) 155 | 156 | # Why is this not an IPython Notebook you may ask? Because I HATE IPython Notebooks.
Simple :-) 157 | 158 | list_of_lists = get_list_of_files(brats_folder_with_downloaded_train_data) 159 | 160 | maybe_mkdir_p(brats_preprocessed_folder) 161 | # the patient name is the name of the folder containing the case's files. NOTE(review): split("/") assumes '/' as path separator, which may not hold on Windows 162 | patient_names = [i[0].split("/")[-2] for i in list_of_lists] 163 | 164 | p = Pool(processes=num_threads_for_brats_example) 165 | p.starmap(load_and_preprocess, zip(list_of_lists, patient_names, [brats_preprocessed_folder] * len(list_of_lists))) 166 | p.close() 167 | p.join() 168 | 169 | # remember that we cropped the data before preprocessing. If we predict the test cases, we want to run the same 170 | # preprocessing for them. We need to then put the segmentation back into its original position (due to cropping). 171 | # Here is how you can do that: 172 | 173 | # let's use Brats17_2013_0_1 for this example 174 | img = np.load(join(brats_preprocessed_folder, "Brats17_2013_0_1.npy")) 175 | metadata = load_pickle(join(brats_preprocessed_folder, "Brats17_2013_0_1.pkl")) 176 | # remember that we changed the segmentation labels from 0, 1, 2, 4 to 0, 1, 2, 3. We need to change that back to 177 | # get the correct format 178 | img[-1][img[-1] == 3] = 4 179 | save_segmentation_as_nifti(img[-1], metadata, join(brats_preprocessed_folder, "delete_me.nii.gz")) 180 | -------------------------------------------------------------------------------- /tests/test_color_augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import unittest 17 | import numpy as np 18 | from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive,\ 19 | augment_brightness_multiplicative, augment_gamma 20 | 21 | 22 | class TestAugmentContrast(unittest.TestCase): 23 | 24 | def setUp(self): 25 | np.random.seed(1234) 26 | self.data_3D = np.random.random((2, 64, 56, 48)) 27 | self.data_2D = np.random.random((2, 64, 56)) 28 | self.factor = (0.75, 1.25) 29 | 30 | self.d_3D = augment_contrast(self.data_3D, contrast_range=self.factor, preserve_range=False, per_channel=False) 31 | self.d_2D = augment_contrast(self.data_2D, contrast_range=self.factor, preserve_range=False, per_channel=False) 32 | 33 | def test_augment_contrast_3D(self): 34 | 35 | mean = np.mean(self.data_3D) 36 | 37 | idx0 = np.where(self.data_3D < mean) # where the data is lower than mean value 38 | idx1 = np.where(self.data_3D > mean) # where the data is greater than mean value 39 | 40 | contrast_lower_limit_0 = self.factor[1] * (self.data_3D[idx0] - mean) + mean 41 | contrast_lower_limit_1 = self.factor[0] * (self.data_3D[idx1] - mean) + mean 42 | contrast_upper_limit_0 = self.factor[0] * (self.data_3D[idx0] - mean) + mean 43 | contrast_upper_limit_1 = self.factor[1] * (self.data_3D[idx1] - mean) + mean 44 | 45 | # augmented values lower than mean should be lower than lower limit and greater than upper limit 46 | self.assertTrue(np.all(np.logical_and(self.d_3D[idx0] >= contrast_lower_limit_0, 47 | self.d_3D[idx0] <= contrast_upper_limit_0)), 48 | 
"Augmented contrast below mean value not within range") 49 | # augmented values greater than mean should be lower than upper limit and greater than lower limit 50 | self.assertTrue(np.all(np.logical_and(self.d_3D[idx1] >= contrast_lower_limit_1, 51 | self.d_3D[idx1] <= contrast_upper_limit_1)), 52 | "Augmented contrast above mean not within range") 53 | 54 | def test_augment_contrast_2D(self): 55 | 56 | mean = np.mean(self.data_2D) 57 | 58 | idx0 = np.where(self.data_2D < mean) # where the data is lower than mean value 59 | idx1 = np.where(self.data_2D > mean) # where the data is greater than mean value 60 | 61 | contrast_lower_limit_0 = self.factor[1] * (self.data_2D[idx0] - mean) + mean 62 | contrast_lower_limit_1 = self.factor[0] * (self.data_2D[idx1] - mean) + mean 63 | contrast_upper_limit_0 = self.factor[0] * (self.data_2D[idx0] - mean) + mean 64 | contrast_upper_limit_1 = self.factor[1] * (self.data_2D[idx1] - mean) + mean 65 | 66 | # augmented values lower than mean should be lower than lower limit and greater than upper limit 67 | self.assertTrue(np.all(np.logical_and(self.d_2D[idx0] >= contrast_lower_limit_0, 68 | self.d_2D[idx0] <= contrast_upper_limit_0)), 69 | "Augmented contrast below mean value not within range") 70 | # augmented values greater than mean should be lower than upper limit and greater than lower limit 71 | self.assertTrue(np.all(np.logical_and(self.d_2D[idx1] >= contrast_lower_limit_1, 72 | self.d_2D[idx1] <= contrast_upper_limit_1)), 73 | "Augmented contrast above mean not within range") 74 | 75 | 76 | class TestAugmentBrightness(unittest.TestCase): 77 | 78 | def setUp(self): 79 | np.random.seed(1234) 80 | self.data_input_3D = np.random.random((2, 64, 56, 48)) 81 | self.data_input_2D = np.random.random((2, 64, 56)) 82 | self.factor = (0.75, 1.25) 83 | self.multiplier_range = [2,4] 84 | 85 | self.d_3D_per_channel = augment_brightness_additive(np.copy(self.data_input_3D), mu=100, sigma=10, 86 | per_channel=True) 87 | self.d_3D = 
augment_brightness_additive(np.copy(self.data_input_3D), mu=100, sigma=10, per_channel=False) 88 | 89 | self.d_2D_per_channel = augment_brightness_additive(np.copy(self.data_input_2D), mu=100, sigma=10, 90 | per_channel=True) 91 | self.d_2D = augment_brightness_additive(np.copy(self.data_input_2D), mu=100, sigma=10, per_channel=False) 92 | 93 | self.d_3D_per_channel_mult = augment_brightness_multiplicative(np.copy(self.data_input_3D), 94 | multiplier_range=self.multiplier_range, 95 | per_channel=True) 96 | self.d_3D_mult = augment_brightness_multiplicative(np.copy(self.data_input_3D), 97 | multiplier_range=self.multiplier_range, per_channel=False) 98 | 99 | self.d_2D_per_channel_mult = augment_brightness_multiplicative(np.copy(self.data_input_2D), 100 | multiplier_range=self.multiplier_range, 101 | per_channel=True) 102 | self.d_2D_mult = augment_brightness_multiplicative(np.copy(self.data_input_2D), 103 | multiplier_range=self.multiplier_range, per_channel=False) 104 | 105 | def test_augment_brightness_additive_3D(self): 106 | add_factor = self.d_3D-self.data_input_3D 107 | self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, 108 | "Added brightness factor is not equal for all channels") 109 | 110 | add_factor = self.d_3D_per_channel - self.data_input_3D 111 | self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == self.data_input_3D.shape[0], 112 | "Added brightness factor is not different for each channels") 113 | 114 | def test_augment_brightness_additive_2D(self): 115 | add_factor = self.d_2D-self.data_input_2D 116 | self.assertTrue(len(np.unique(add_factor.round(decimals=8)))==1, 117 | "Added brightness factor is not equal for all channels") 118 | 119 | add_factor = self.d_2D_per_channel - self.data_input_2D 120 | self.assertTrue(len(np.unique(add_factor.round(decimals=8))) == self.data_input_2D.shape[0], 121 | "Added brightness factor is not different for each channels") 122 | 123 | def test_augment_brightness_multiplicative_3D(self): 124 
| mult_factor = self.d_3D_mult/self.data_input_3D 125 | self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, 126 | "Multiplied brightness factor is not equal for all channels") 127 | 128 | mult_factor = self.d_3D_per_channel_mult/self.data_input_3D 129 | self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_3D.shape[0], 130 | "Multiplied brightness factor is not different for each channels") 131 | 132 | def test_augment_brightness_multiplicative_2D(self): 133 | mult_factor = self.d_2D_mult/self.data_input_2D 134 | self.assertTrue(len(np.unique(mult_factor.round(decimals=6)))==1, 135 | "Multiplied brightness factor is not equal for all channels") 136 | 137 | mult_factor = self.d_2D_per_channel_mult/self.data_input_2D 138 | self.assertTrue(len(np.unique(mult_factor.round(decimals=6))) == self.data_input_2D.shape[0], 139 | "Multiplied brightness factor is not different for each channels") 140 | 141 | 142 | class TestAugmentGamma(unittest.TestCase): 143 | 144 | def setUp(self): 145 | np.random.seed(1234) 146 | self.data_input_3D = np.random.random((2, 64, 56, 48)) 147 | self.data_input_2D = np.random.random((2, 64, 56)) 148 | 149 | self.d_3D = augment_gamma(np.copy(self.data_input_2D), gamma_range=(0.2, 1.2), per_channel=False) 150 | 151 | def test_augment_gamma_3D(self): 152 | self.assertTrue(self.d_3D.min().round(decimals=3) == self.data_input_3D.min().round(decimals=3) and 153 | self.d_3D.max().round(decimals=3) == self.data_input_3D.max().round(decimals=3), 154 | "Input range does not equal output range") 155 | 156 | 157 | if __name__ == '__main__': 158 | unittest.main() 159 | -------------------------------------------------------------------------------- /batchgenerators/transforms/crop_and_pad_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, 
Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop, pad_nd_image_and_seg, random_crop 17 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 18 | import numpy as np 19 | 20 | 21 | class CenterCropTransform(AbstractTransform): 22 | """ Crops data and seg (if available) in the center 23 | 24 | Args: 25 | output_size (int or tuple of int): Output patch size 26 | 27 | """ 28 | 29 | def __init__(self, crop_size, data_key="data", label_key="seg"): 30 | self.data_key = data_key 31 | self.label_key = label_key 32 | self.crop_size = crop_size 33 | 34 | def __call__(self, **data_dict): 35 | data = data_dict.get(self.data_key) 36 | seg = data_dict.get(self.label_key) 37 | data, seg = center_crop(data, self.crop_size, seg) 38 | 39 | data_dict[self.data_key] = data 40 | if seg is not None: 41 | data_dict[self.label_key] = seg 42 | 43 | return data_dict 44 | 45 | 46 | class CenterCropSegTransform(AbstractTransform): 47 | """ Crops seg in the center (required if you are using unpadded convolutions in a segmentation network). 
48 | Leaves data as it is 49 | 50 | Args: 51 | output_size (int or tuple of int): Output patch size 52 | 53 | """ 54 | 55 | def __init__(self, output_size, data_key="data", label_key="seg"): 56 | self.data_key = data_key 57 | self.label_key = label_key 58 | self.output_size = output_size 59 | 60 | def __call__(self, **data_dict): 61 | seg = data_dict.get(self.label_key) 62 | 63 | if seg is not None: 64 | data_dict[self.label_key] = center_crop(seg, self.output_size, None)[0] 65 | else: 66 | from warnings import warn 67 | warn("You shall not pass data_dict without seg: Used CenterCropSegTransform, but there is no seg", Warning) 68 | return data_dict 69 | 70 | 71 | class RandomCropTransform(AbstractTransform): 72 | """ Randomly crops data and seg (if available) 73 | 74 | Args: 75 | crop_size (int or tuple of int): Output patch size 76 | 77 | margins (tuple of int): how much distance should the patch border have to the image broder (bilaterally)? 78 | 79 | """ 80 | 81 | def __init__(self, crop_size=128, margins=(0, 0, 0), data_key="data", label_key="seg"): 82 | self.data_key = data_key 83 | self.label_key = label_key 84 | self.margins = margins 85 | self.crop_size = crop_size 86 | 87 | def __call__(self, **data_dict): 88 | data = data_dict.get(self.data_key) 89 | seg = data_dict.get(self.label_key) 90 | 91 | data, seg = random_crop(data, seg, self.crop_size, self.margins) 92 | 93 | data_dict[self.data_key] = data 94 | if seg is not None: 95 | data_dict[self.label_key] = seg 96 | 97 | return data_dict 98 | 99 | 100 | class PadTransform(AbstractTransform): 101 | def __init__(self, new_size, pad_mode_data='constant', pad_mode_seg='constant', 102 | np_pad_kwargs_data=None, np_pad_kwargs_seg=None, 103 | data_key="data", label_key="seg"): 104 | """ 105 | Pads data and seg to new_size. Only supports numpy arrays for data and seg. 
106 | 107 | :param new_size: (x, y(, z)) 108 | :param pad_value_data: 109 | :param pad_value_seg: 110 | :param data_key: 111 | :param label_key: 112 | """ 113 | self.data_key = data_key 114 | self.label_key = label_key 115 | self.new_size = new_size 116 | self.pad_mode_data = pad_mode_data 117 | self.pad_mode_seg = pad_mode_seg 118 | if np_pad_kwargs_data is None: 119 | np_pad_kwargs_data = {} 120 | if np_pad_kwargs_seg is None: 121 | np_pad_kwargs_seg = {} 122 | self.np_pad_kwargs_data = np_pad_kwargs_data 123 | self.np_pad_kwargs_seg = np_pad_kwargs_seg 124 | 125 | assert isinstance(self.new_size, (tuple, list, np.ndarray)), "new_size must be tuple, list or np.ndarray" 126 | 127 | def __call__(self, **data_dict): 128 | data = data_dict.get(self.data_key) 129 | seg = data_dict.get(self.label_key) 130 | 131 | assert len(self.new_size) + 2 == len(data.shape), "new size must be a tuple/list/np.ndarray with shape " \ 132 | "(x, y(, z))" 133 | data, seg = pad_nd_image_and_seg(data, seg, self.new_size, None, 134 | np_pad_kwargs_data=self.np_pad_kwargs_data, 135 | np_pad_kwargs_seg=self.np_pad_kwargs_seg, 136 | pad_mode_data=self.pad_mode_data, 137 | pad_mode_seg=self.pad_mode_seg) 138 | 139 | data_dict[self.data_key] = data 140 | if seg is not None: 141 | data_dict[self.label_key] = seg 142 | 143 | return data_dict 144 | 145 | 146 | class RandomShiftTransform(AbstractTransform): 147 | def __init__(self, shift_mu, shift_sigma, p_per_sample=1, p_per_channel=0.5, border_value=0, apply_to_keys=('data',)): 148 | """ 149 | randomly shifts the data by some amount. Equivalent to pad -> random crop but with (probably) less 150 | computational requirements 151 | 152 | shift_mu gives the mean value of the shift, 0 is recommended 153 | shift_sigma gives the standard deviation of the shift 154 | 155 | shift will ne drawn from a Gaussian distribution with mean shift_mu and variance shift_sigma 156 | 157 | shift_mu and shift_sigma can either be float values OR tuples of float values. 
If they are tuples they will 158 | be interpreted as separate mean and std for each dimension 159 | 160 | TODO separate per channel or not? 161 | 162 | :param shift_mu: 163 | :param shift_sigma: 164 | :param p_per_sample: 165 | :param p_per_channel: 166 | :param apply_to_keys: 167 | """ 168 | self.apply_to_keys = apply_to_keys 169 | self.p_per_channel = p_per_channel 170 | self.p_per_sample = p_per_sample 171 | self.shift_sigma = shift_sigma 172 | self.shift_mu = shift_mu 173 | self.border_value = border_value 174 | 175 | def __call__(self, **data_dict): 176 | for k in self.apply_to_keys: 177 | workon = data_dict[k] 178 | for b in range(workon.shape[0]): 179 | if np.random.uniform(0, 1) < self.p_per_sample: 180 | for c in range(workon.shape[1]): 181 | if np.random.uniform(0, 1) < self.p_per_channel: 182 | shift_here = [] 183 | for d in range(len(workon.shape) - 2): 184 | shift_here.append(int(np.round(np.random.normal( 185 | self.shift_mu[d] if isinstance(self.shift_mu, (list, tuple)) else self.shift_mu, 186 | self.shift_sigma[d] if isinstance(self.shift_sigma, 187 | (list, tuple)) else self.shift_sigma, 188 | size=1)))) 189 | data_copy = np.ones_like(workon[b, c]) * self.border_value 190 | lb_x = max(shift_here[0], 0) 191 | ub_x = max(0, min(workon.shape[2], workon.shape[2] + shift_here[0])) 192 | lb_y = max(shift_here[1], 0) 193 | ub_y = max(0, min(workon.shape[3], workon.shape[3] + shift_here[1])) 194 | 195 | t_lb_x = max(-shift_here[0], 0) 196 | t_ub_x = max(0, min(workon.shape[2], workon.shape[2] - shift_here[0])) 197 | t_lb_y = max(-shift_here[1], 0) 198 | t_ub_y = max(0, min(workon.shape[3], workon.shape[3] - shift_here[1])) 199 | 200 | if len(shift_here) == 2: 201 | data_copy[t_lb_x:t_ub_x, t_lb_y:t_ub_y] = workon[b, c, lb_x:ub_x, lb_y:ub_y] 202 | elif len(shift_here) == 3: 203 | lb_z = max(shift_here[2], 0) 204 | ub_z = max(0, min(workon.shape[4], workon.shape[4] + shift_here[2])) 205 | t_lb_z = max(-shift_here[2], 0) 206 | t_ub_z = max(0, 
min(workon.shape[2], workon.shape[4] - shift_here[2])) 207 | data_copy[t_lb_x:t_ub_x, t_lb_y:t_ub_y, t_lb_z:t_ub_z] = workon[b, c, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] 208 | data_dict[k][b, c] = data_copy 209 | return data_dict -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # batchgenerators by MIC@DKFZ 2 | 3 | Copyright German Cancer Research Center (DKFZ) and contributors. 4 | Please make sure that your usage of this code is in compliance with its 5 | [`license`](https://github.com/MIC-DKFZ/batchgenerators/blob/master/LICENSE). 6 | 7 | batchgenerators is a python package for data augmentation. It is developed jointly between the Division of 8 | Medical Image Computing at the German Cancer Research Center (DKFZ) and the Applied Computer 9 | Vision Lab of the Helmholtz Imaging Platform. 10 | 11 | It is not (yet) perfect, but we feel it is good enough to be shared with the community. If you encounter bug, feel free 12 | to contact us or open a github issue. 13 | 14 | If you use it please cite the following work: 15 | ``` 16 | Isensee Fabian, Jäger Paul, Wasserthal Jakob, Zimmerer David, Petersen Jens, Kohl Simon, 17 | Schock Justus, Klein Andre, Roß Tobias, Wirkert Sebastian, Neher Peter, Dinkelacker Stefan, 18 | Köhler Gregor, Maier-Hein Klaus (2020). batchgenerators - a python framework for data 19 | augmentation. doi:10.5281/zenodo.3632567 20 | ``` 21 | 22 | batchgenerators also contains the following application-specific augmentations: 23 | * **Anatomy-informed Data Augmentation** 24 | Proposed at [MICCAI 2023](https://arxiv.org/abs/2309.03652) for simulation of soft-tissue deformations. Implementation details can be found [here](https://github.com/MIC-DKFZ/anatomy_informed_DA). 
25 | * **Misalignment Data Augmentation** 26 | Proposed in [Nature Scientific Reports 2023](https://www.nature.com/articles/s41598-023-46747-z) 27 | for enhancing a model's adaptability to diverse misalignments\ 28 | between multi-modal (multi-channel) images and thereby ensuring robust performance. Implementation details can be found [here](https://github.com/MIC-DKFZ/misalignment_DA). 29 | 30 | If you use these augmentations please cite them too. 31 | 32 | [![Build Status](https://travis-ci.com/MIC-DKFZ/batchgenerators.svg?branch=master)](https://travis-ci.com/github/MIC-DKFZ/batchgenerators) 33 | 34 | ## Supported Augmentations 35 | We support a variety of augmentations, all of which are compatible with **2D and 3D input data**! (This is something 36 | that was missing in most other frameworks). 37 | 38 | * **Spatial Augmentations** 39 | * mirroring 40 | * channel translation (to simulate registration errors) 41 | * elastic deformations 42 | * rotations 43 | * scaling 44 | * resampling 45 | * multi-channel misalignments 46 | * **Color Augmentations** 47 | * brightness (additive, multiplicative) 48 | * contrast 49 | * gamma (like gamma correction in photo editing) 50 | * **Noise Augmentations** 51 | * Gaussian Noise 52 | * Rician Noise 53 | * ...will be expanded in future commits 54 | * **Cropping** 55 | * random crop 56 | * center crop 57 | * padding 58 | * **Anatomy-informed Augmentation** 59 | 60 | Note: Stack transforms by using batchgenerators.transforms.abstract_transforms.Compose. Finish it up by plugging the 61 | composed transform into our **multithreader**: batchgenerators.dataloading.multi_threaded_augmenter.MultiThreadedAugmenter 62 | 63 | 64 | ## How to use it 65 | 66 | The working principle is simple: Derive from the DataLoaderBase class, reimplement the generate_train_batch member function and 67 | use it to stack your augmentations!
68 | For a simple example, see `batchgenerators/examples/example_ipynb.ipynb` 69 | 70 | A heavily commented example for using SlimDataLoaderBase and MultiThreadedAugmenter is available at: 71 | `batchgenerators/examples/multithreaded_with_batches.ipynb`. 72 | It gives an idea of the interplay between the SlimDataLoaderBase and the MultiThreadedAugmenter. 73 | The example uses the MultiThreadedAugmenter for loading and augmentation on multiple processes, while 74 | covering the entire dataset only once per epoch (basically sampling without replacement). 75 | 76 | We also now have an extensive example for BraTS2017/2018 with both 2D and 3D DataLoader and augmentations: 77 | `batchgenerators/examples/brats2017/` 78 | 79 | There are also CIFAR10/100 datasets and DataLoader available at `batchgenerators/datasets/cifar.py` 80 | 81 | ## Data Structure 82 | 83 | The data structure that is used internally (and with which you have to comply when implementing generate_train_batch) 84 | is kept simple as well: It is just a regular python dictionary! We did this to allow maximum flexibility in the kind of 85 | data that is passed along through the pipeline. The dictionary must have a 'data' key:value pair. It can optionally 86 | handle a 'seg' key:value pair to hold a segmentation. If a 'seg' key:value pair is present all spatial transformations 87 | will also be applied to the segmentation! Apart from 'data' and 'seg' you are free to do whatever you want (your image 88 | classification/regression target for example). All key:value pairs other than 'data' and 'seg' will be passed through the 89 | pipeline unmodified. 90 | 91 | 'data' value must have shape (b, c, x, y) for 2D or shape (b, c, x, y, z) for 3D! 92 | 'seg' value must have shape (b, c, x, y) for 2D or shape (b, c, x, y, z) for 3D! Color channel may be used here to 93 | allow for several segmentation maps.
If you have only one segmentation, make sure to have shape (b, 1, x, y (, z)) 94 | 95 | ## How to install locally 96 | 97 | Install batchgenerators 98 | ``` 99 | pip install --upgrade batchgenerators 100 | ``` 101 | 102 | Import as follows 103 | ``` 104 | from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform 105 | ``` 106 | 107 | ## Windows Support is very experimental! 108 | Batchgenerators makes heavy use of python multiprocessing and python multiprocessing on windows is different from linux. 109 | To prevent the workers from freezing in windows, you have to guard your code with `if __name__ == '__main__'` and use multiprocessing's [`freeze_support`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support). The executed script may then look like this: 110 | 111 | ``` 112 | # some imports and functions here 113 | 114 | def main(): 115 | # do some stuff 116 | 117 | if __name__ == '__main__': 118 | from multiprocessing import freeze_support 119 | freeze_support() 120 | main() 121 | ``` 122 | 123 | This is not required on Linux. 124 | 125 | 126 | ## Release Notes 127 | (only highlights, not an exhaustive list) 128 | - 0.23.2: 129 | - Misalignment data augmentation added 130 | - 0.23.1: 131 | - Anatomy-informed data augmentation added 132 | - 0.23: 133 | - fixed the import mess. `__init__.py` files are now empty. This is a breaking change for some users! 134 | Please adapt your imports :-) 135 | - local_transforms are now a thing, check them out! 136 | - resize_segmentation now uses 'edge' mode and no longer takes a cval argument. Resizing segmentations with constant 137 | border values (previous default) can cause problems and should not be done. 
138 | - 0.20.0: 139 | - fixed an issue with MultiThreadedAugmenter not terminating properly after KeyboardInterrupt; Fixed an error 140 | with the number and order of samples being returned when pin_memory=True; Improved performance by always hiding 141 | process-process communication bottleneck through threading 142 | - 0.19.5: 143 | - fixed OMP_NUM_THREADS issue by using threadpoolctl package; dropped python 2 support (threadpoolctl is not 144 | available for python 2) 145 | - 0.19: 146 | - There is now a complete example for BraTS2017/8 available for both 2D and 3D. Use this if you would like to get 147 | some insights on how I (Fabian) do my experiments 148 | - Windows is now supported! Thanks @justusschock for your support! 149 | - new, simple parametrization of elastic deformation. Use SpatialTransform_2! 150 | - CIFAR10/100 DataLoader are now available for your convenience 151 | - a bug in MultiThreadedAugmenter that could interfere with reproducibility is now fixed 152 | 153 | - 0.18: 154 | - all augmentations (there are some exceptions though) are implemented on a per-sample basis. This should make it 155 | easier to use the augmentations outside of the Transforms of batchgenerators 156 | - applicable Transforms now have a keyword p_per_sample with which the user can specify a probability with which this 157 | transform is applied to a sample. Before, this was handled by RndTransform and applied to the whole batch (so 158 | either all samples were augmented or none). Now this decision is made on a per-sample basis and increases 159 | variability by a lot. 160 | - following the previous point, RndTransform is now deprecated 161 | - AlternativeMultiThreadedAugmenter is now deprecated as well (no need to have this anymore) 162 | - pytorch users can now transform numpy arrays to pytorch tensors within batchgenerators (NumpyToTensor). For some 163 | reason, inter-process communication is faster with tensors (~factor 4), so this is recommended! 
164 | - if numpy arrays were converted to pytorch tensors, MultithreadedAugmenter now allows to pin the memory as well 165 | (pin_memory=True). This will happen in a background thread (inspired by pytorch DataLoader). pinned memory can be 166 | copied to the GPU much faster. My (Fabian) classification experiment with Resnet50 got a speed boost of 12% from just 167 | that. 168 | 169 | 170 | ------------------------- 171 | 172 | 173 | 174 | 175 | 176 | batchgenerators is developed by the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) of the 177 | German Cancer Research Center (DKFZ) and the Applied Computer Vision Lab (ACVL) of the 178 | [Helmholtz Imaging Platform](https://helmholtz-imaging.de). 179 | -------------------------------------------------------------------------------- /batchgenerators/transforms/color_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 

from typing import Union, Tuple, Callable

import numpy as np

from batchgenerators.augmentations.color_augmentations import augment_contrast, augment_brightness_additive, \
    augment_brightness_multiplicative, augment_gamma, augment_illumination, augment_PCA_shift
from batchgenerators.transforms.abstract_transforms import AbstractTransform


def _augment_samples(batch, p_per_sample, augment_fn):
    """
    Apply ``augment_fn`` independently to the samples of ``batch`` (indexed along its first axis).

    For every sample, one value is drawn with np.random.uniform(). If that value is below ``p_per_sample``, the
    sample is replaced in place by ``augment_fn(sample)``; otherwise it is left untouched.
    """
    for sample_idx in range(len(batch)):
        if np.random.uniform() < p_per_sample:
            batch[sample_idx] = augment_fn(batch[sample_idx])


class ContrastAugmentationTransform(AbstractTransform):
    def __init__(self,
                 contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
                 preserve_range: bool = True,
                 per_channel: bool = True,
                 data_key: str = "data",
                 p_per_sample: float = 1,
                 p_per_channel: float = 1):
        """
        Randomly changes the contrast of the samples in a batch.

        :param contrast_range: either a (float, float) range from which a random contrast factor is sampled, or a
            callable with signature contrast_range() -> float. If one bound is below 1 and the other above 1, half of
            the sampled factors will be < 1 and the other half > 1 (within the specified interval).
        :param preserve_range: if True, the intensities after augmentation are clipped to the min and max values the
            data had before augmentation
        :param per_channel: if True, a separate contrast factor is used per channel; otherwise all channels share one
        :param data_key: key of data_dict under which the batch is stored
        :param p_per_sample: probability with which each sample is augmented
        :param p_per_channel: forwarded unchanged to augment_contrast
        """
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        def contrast(sample):
            return augment_contrast(sample,
                                    contrast_range=self.contrast_range,
                                    preserve_range=self.preserve_range,
                                    per_channel=self.per_channel,
                                    p_per_channel=self.p_per_channel)

        _augment_samples(data_dict[self.data_key], self.p_per_sample, contrast)
        return data_dict


class NormalizeTransform(AbstractTransform):
    def __init__(self, means, stds, data_key='data'):
        """
        Normalizes the batch channel by channel, in place: channel c (axis 1) becomes (x - means[c]) / stds[c].

        :param means: per-channel values to subtract, indexed by channel
        :param stds: per-channel values to divide by, indexed by channel
        :param data_key: key of data_dict under which the batch is stored
        """
        self.means = means
        self.stds = stds
        self.data_key = data_key

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        for channel in range(batch.shape[1]):
            batch[:, channel] -= self.means[channel]
            batch[:, channel] /= self.stds[channel]
        return data_dict


class BrightnessTransform(AbstractTransform):
    def __init__(self, mu, sigma, per_channel=True, data_key="data", p_per_sample=1, p_per_channel=1):
        """
        Additively changes the brightness of the samples. The added value is drawn from a Gaussian with mean ``mu``
        and standard deviation ``sigma``.

        :param mu: mean of the Gaussian the additive brightness is sampled from
        :param sigma: standard deviation of the Gaussian the additive brightness is sampled from
        :param per_channel: if True, a separate brightness offset is used per channel; otherwise all channels share one
        :param data_key: key of data_dict under which the batch is stored
        :param p_per_sample: probability with which each sample is augmented
        :param p_per_channel: forwarded unchanged to augment_brightness_additive
        """
        self.mu = mu
        self.sigma = sigma
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        batch = data_dict[self.data_key]
        num_samples = batch.shape[0]

        for sample_idx in range(num_samples):
            if np.random.uniform() < self.p_per_sample:
                batch[sample_idx] = augment_brightness_additive(batch[sample_idx], self.mu, self.sigma,
                                                                self.per_channel, p_per_channel=self.p_per_channel)

        data_dict[self.data_key] = batch
        return data_dict


class BrightnessMultiplicativeTransform(AbstractTransform):
    def __init__(self, multiplier_range=(0.5, 2), per_channel=True, data_key="data", p_per_sample=1):
        """
        Multiplicatively changes the brightness of the samples. The multiplier is drawn uniformly from
        ``multiplier_range``.

        :param multiplier_range: range the brightness multiplier is sampled from
        :param per_channel: if True, a separate multiplier is used per channel; otherwise all channels share one
        :param data_key: key of data_dict under which the batch is stored
        :param p_per_sample: probability with which each sample is augmented
        """
        self.multiplier_range = multiplier_range
        self.per_channel = per_channel
        self.data_key = data_key
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        def brighten(sample):
            return augment_brightness_multiplicative(sample, self.multiplier_range, self.per_channel)

        _augment_samples(data_dict[self.data_key], self.p_per_sample, brighten)
        return data_dict


class GammaTransform(AbstractTransform):
    def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data",
                 retain_stats: Union[bool, Callable[[], bool]] = False, p_per_sample=1):
        """
        Augments the samples by changing their 'gamma' (the same as gamma correction in photos or on computer
        monitors).

        :param gamma_range: (float, float) range to sample gamma from. If one bound is below 1 and the other above 1,
            half of the samples get gamma < 1 and the other half gamma > 1 (within the specified interval).
        :param invert_image: whether to invert the image before applying the gamma augmentation
        :param per_channel: forwarded unchanged to augment_gamma
        :param data_key: key of data_dict under which the batch is stored
        :param retain_stats: gamma augmentation alters the mean and std of the data. If True, the result is rescaled
            to the mean and standard deviation it had before augmentation. May also be a callable with signature
            retain_stats() -> bool.
        :param p_per_sample: probability with which each sample is augmented
        """
        self.gamma_range = gamma_range
        self.invert_image = invert_image
        self.per_channel = per_channel
        self.data_key = data_key
        self.retain_stats = retain_stats
        self.p_per_sample = p_per_sample

    def __call__(self, **data_dict):
        def apply_gamma(sample):
            return augment_gamma(sample, self.gamma_range, self.invert_image,
                                 per_channel=self.per_channel,
                                 retain_stats=self.retain_stats)

        _augment_samples(data_dict[self.data_key], self.p_per_sample, apply_gamma)
        return data_dict


class IlluminationTransform(AbstractTransform):
    """Experimental illumination augmentation. Do not use this for now."""

    def __init__(self, white_rgb, data_key="data"):
        self.white_rgb = white_rgb
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = augment_illumination(data_dict[key], self.white_rgb)
        return data_dict


class FancyColorTransform(AbstractTransform):
    """Experimental PCA-based color shift (augment_PCA_shift). Do not use this for now."""

    def __init__(self, U, s, sigma=0.2, data_key="data"):
        self.U = U
        self.s = s
        self.sigma = sigma
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = augment_PCA_shift(data_dict[key], self.U, self.s, self.sigma)
        return data_dict


class ClipValueRange(AbstractTransform):
    def __init__(self, min=None, max=None, data_key="data"):
        """
        Clips the values of the batch to [min, max] using np.clip.

        :param min: lower bound passed to np.clip
        :param max: upper bound passed to np.clip
        :param data_key: key of data_dict under which the batch is stored
        """
        self.min = min
        self.max = max
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        data_dict[key] = np.clip(data_dict[key], self.min, self.max)
        return data_dict
# ---- /batchgenerators/dataloading/nondet_multi_threaded_augmenter.py ----
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import traceback
from copy import deepcopy
from typing import List, Union, Callable
import threading
from builtins import range
from multiprocessing import Process
from multiprocessing import Queue
from queue import Queue as thrQueue
import numpy as np
import logging
from multiprocessing import Event
from time import sleep, time

from batchgenerators.dataloading.data_loader import DataLoader
from threadpoolctl import threadpool_limits

try:
    import torch
except ImportError:
    torch = None


def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
             abort_event: Event, wait_time: float = 0.02):
    """
    Background worker loop: draw a batch from ``data_loader``, apply ``transform`` (if not None) and put the result
    into the shared ``queue``. This repeats until ``abort_event`` is set.

    On KeyboardInterrupt or any other exception, the worker sets ``abort_event`` itself so that the results loop and
    the main process notice the failure.

    :param queue: queue shared by all workers
    :param data_loader: iterator producing batches; must provide set_thread_id()
    :param transform: callable invoked as transform(**batch), or None
    :param thread_id: index of this worker, forwarded to data_loader.set_thread_id
    :param seed: seed for np.random in this worker
    :param abort_event: shared event used to signal shutdown / failure
    :param wait_time: seconds to sleep while the queue is full
    """
    # restrict native thread pools to a single thread inside each worker
    with threadpool_limits(1, None):
        np.random.seed(seed)
        data_loader.set_thread_id(thread_id)
        item = None

        try:
            while True:
                if abort_event.is_set():
                    return

                # only produce a new batch once the previous one has been delivered
                if item is None:
                    item = next(data_loader)
                    if transform is not None:
                        item = transform(**item)

                if abort_event.is_set():
                    return

                if not queue.full():
                    queue.put(item)
                    item = None
                else:
                    sleep(wait_time)

        except KeyboardInterrupt:
            abort_event.set()
            return

        except Exception as e:
            print("Exception in background worker %d:\n" % thread_id, e)
            traceback.print_exc()
            abort_event.set()
            return


def pin_memory_of_all_eligible_items_in_dict(result_dict):
    """Replace (in place) every torch.Tensor value of ``result_dict`` by ``value.pin_memory()``; return the dict."""
    for key, value in result_dict.items():
        if isinstance(value, torch.Tensor):
            result_dict[key] = value.pin_memory()
    return result_dict


def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
                 pin_memory: bool, worker_list: List[Process],
                 gpu: Union[int, None] = None, wait_time: float = 0.02):
    """
    Thread target run in the main process. It moves batches from the worker queue ``in_queue`` to the in-process
    ``out_queue``, pinning torch tensors on the way when pin_memory is requested and CUDA is available.

    Returns once ``abort_event`` is set. If any worker process has died, it sets ``abort_event`` and raises
    RuntimeError.

    :param gpu: CUDA device selected in this thread before pinning (ignored when not pinning)
    :param wait_time: seconds to sleep when there is nothing to move or ``out_queue`` is full
    """
    do_pin_memory = torch is not None and pin_memory and gpu is not None and torch.cuda.is_available()

    if do_pin_memory:
        print('using pin_memory on device', gpu)
        torch.cuda.set_device(gpu)

    item = None

    while True:
        try:
            if abort_event.is_set():
                return

            # check if all workers are still alive
            if not all(worker.is_alive() for worker in worker_list):
                abort_event.set()
                raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
                                   "print statements above for the actual error message")

            if item is None:
                if not in_queue.empty():
                    item = in_queue.get()
                    if do_pin_memory:
                        item = pin_memory_of_all_eligible_items_in_dict(item)
                else:
                    sleep(wait_time)
                    continue

            # we only arrive here if item is not None. Now put item in to the out_queue
            if not out_queue.full():
                out_queue.put(item)
                item = None
            else:
                sleep(wait_time)
                continue

        except Exception:
            abort_event.set()
            raise


class NonDetMultiThreadedAugmenter(object):
    """
    Non-deterministic but potentially faster than MultiThreadedAugmenter, and uses less RAM. Also less complicated.

    All background workers communicate through a single queue. This creates a race between them, so the order of
    batches is nondeterministic. The advantage is that we never have to wait for one particular worker X to finish
    its work. This approach also needs less RAM: instead of caching some number of batches per worker, a global
    pool of cached batches is shared among all workers.

    THIS MTA ONLY WORKS WITH DATALOADERS THAT RETURN INFINITE RANDOM SAMPLES! If you are using DataLoader, make sure
    to set infinite=True.

    Seeding this is not recommended :-)
    """

    def __init__(self, data_loader, transform, num_processes, num_cached=2, seeds=None, pin_memory=False,
                 wait_time=0.02, results_loop_fn: Callable = results_loop):
        """
        :param data_loader: iterator producing batches; DataLoader instances must have infinite=True
        :param transform: callable applied in the workers as transform(**batch), or None
        :param num_processes: number of background worker processes
        :param num_cached: capacity of the worker queue and of the results queue
        :param seeds: one seed per worker, or None
        :param pin_memory: pin torch tensors in the results thread (only effective when CUDA is available)
        :param wait_time: polling interval in seconds
        :param results_loop_fn: thread target that moves batches to the results queue; called with the same
            arguments as results_loop
        """
        self.pin_memory = pin_memory
        self.transform = transform
        self.num_cached = num_cached

        if isinstance(data_loader, DataLoader):
            assert data_loader.infinite, "Only use DataLoader instances that have infinite=True"
        self.generator = data_loader
        self.num_processes = num_processes

        self._queue = None
        self._processes = []
        # previously this was hard-wired to results_loop, silently ignoring the argument
        self.results_loop_fn = results_loop_fn
        self.results_loop_thread = None
        self.results_loop_queue = None
        self.abort_event = None
        self.initialized = False

        self.wait_time = wait_time

        if seeds is not None:
            assert len(seeds) == num_processes
        else:
            seeds = [None] * num_processes
        self.seeds = seeds

    def __iter__(self):
        return self

    def next(self):
        return self.__next__()

    def __get_next_item(self):
        """Poll the results queue until a batch is available; clean up and raise RuntimeError if aborted."""
        item = None

        while item is None:
            if self.abort_event.is_set():
                # the results loop thread checks for dead workers and sets the abort event if necessary
                self._finish()
                raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
                                   "print statements above for the actual error message")

            if not self.results_loop_queue.empty():
                item = self.results_loop_queue.get()
                self.results_loop_queue.task_done()
            else:
                sleep(self.wait_time)

        return item

    def __next__(self):
        if not self.initialized:
            self._start()

        return self.__get_next_item()

    def _start(self):
        """Spawn the worker processes and the results thread (no-op apart from a debug log if already running)."""
        if not self.initialized:
            self._finish()

            self._queue = Queue(self.num_cached)
            self.results_loop_queue = thrQueue(self.num_cached)
            self.abort_event = Event()

            logging.debug("starting workers")
            if isinstance(self.generator, DataLoader):
                self.generator.was_initialized = False

            # limit torch / native thread pools to one thread while the workers are spawned; the original torch
            # setting is restored afterwards, even if spawning fails
            torch_nthreads = None
            if torch is not None:
                torch_nthreads = torch.get_num_threads()
                torch.set_num_threads(1)
            try:
                with threadpool_limits(limits=1, user_api=None):
                    for i in range(self.num_processes):
                        self._processes.append(Process(target=producer, args=(
                            self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event,
                            self.wait_time
                        )))
                        self._processes[-1].daemon = True
                    for process in self._processes:
                        process.start()
            finally:
                if torch_nthreads is not None:
                    torch.set_num_threads(torch_nthreads)

            if torch is not None and torch.cuda.is_available():
                gpu = torch.cuda.current_device()
            else:
                gpu = None

            # in_queue: Queue, out_queue: thrQueue, abort_event: Event, pin_memory: bool, worker_list: List[Process],
            # gpu: Union[int, None] = None, wait_time: float = 0.02
            self.results_loop_thread = threading.Thread(target=self.results_loop_fn, args=(
                self._queue, self.results_loop_queue, self.abort_event, self.pin_memory, self._processes, gpu,
                self.wait_time)
            )
            self.results_loop_thread.daemon = True
            self.results_loop_thread.start()

            self.initialized = True
        else:
            logging.debug("MultiThreadedGenerator Warning: start() has been called but workers are already running")

    def _finish(self):
        """Signal shutdown, terminate live workers and drop all queues/threads so that _start() can run again."""
        # getattr: __del__ may run on a partially constructed instance (e.g. when the assertion in __init__ fails)
        if getattr(self, "initialized", False):
            self.abort_event.set()
            sleep(self.wait_time)
            for process in self._processes:
                if process.is_alive():
                    process.terminate()

            self._queue = None
            self.results_loop_queue = None
            self.results_loop_thread = None
            self.abort_event = None
            self._processes = []
            self.initialized = False

    def restart(self):
        self._finish()
        self._start()

    def __del__(self):
        logging.debug("MultiThreadedGenerator: destructor was called")
        self._finish()


if __name__ == '__main__':
    from tests.test_DataLoader import DummyDataLoader
    dl = DummyDataLoader(deepcopy(list(range(1234))), 2, 3, None,
                         return_incomplete=False, shuffle=True,
                         infinite=True)

    mt = NonDetMultiThreadedAugmenter(dl, None, 3, 2, None, False, 0.02)
    mt._start()

    st = time()
    for i in range(1000):
        print(i)
        b = next(mt)
    end = time()
    print(end - st)

    mt._finish()

# ---- /tests/test_DataLoader.py ----
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import unittest
from contextlib import contextmanager
from copy import deepcopy

import numpy as np

from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter


class DummyDataLoader(DataLoader):
    """
    Minimal DataLoader for tests. ``data`` doubles as the index list, so every batch is simply what
    ``get_indices()`` returns.
    """

    def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, return_incomplete=False,
                 shuffle=True, infinite=False):
        super(DummyDataLoader, self).__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
                                              return_incomplete, shuffle, infinite)
        self.indices = data

    def generate_train_batch(self):
        return self.get_indices()


def _collect(batches):
    """Concatenate all batches yielded by ``batches`` (one epoch) into a single flat list."""
    collected = []
    for batch in batches:
        collected += batch
    return collected


@contextmanager
def _running(mt):
    """Yield ``mt`` and always shut its worker processes down afterwards, even if the test body fails."""
    try:
        yield mt
    finally:
        mt._finish()


def _expected_counts(num_items, batch_size, return_incomplete):
    """Return (number of batches, number of items) that one epoch over ``num_items`` items should yield."""
    if return_incomplete:
        return -(-num_items // batch_size), num_items
    num_batches = num_items // batch_size
    return num_batches, num_batches * batch_size


class TestDataLoader(unittest.TestCase):
    def test_return_all_indices_single_threaded_shuffle_False(self):
        data = list(range(123))

        for batch_size in (1, 3, 75, 12, 23):
            dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=True, shuffle=False,
                                 infinite=False)
            # several epochs: the loader must reset correctly after each StopIteration
            for _ in range(3):
                idx = _collect(dl)
                self.assertEqual(len(idx), len(data))
                self.assertEqual(idx, data)

    def test_return_all_indices_single_threaded_shuffle_True(self):
        data = list(range(123))
        np.random.seed(1234)

        for batch_size in (1, 3, 75, 12, 23):
            dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=True, shuffle=True,
                                 infinite=False)
            for _ in range(3):
                idx = _collect(dl)
                self.assertEqual(len(idx), len(data))
                # the order must be shuffled, but every item must appear exactly once
                self.assertNotEqual(idx, data)
                self.assertEqual(sorted(idx), data)

    def test_infinite_single_threaded(self):
        data = list(range(123))

        dl = DummyDataLoader(deepcopy(data), 12, 1, 1, return_incomplete=True, shuffle=True, infinite=False)
        # a finite loader must eventually raise StopIteration
        with self.assertRaises(StopIteration):
            for _ in range(1000):
                next(dl)

        dl = DummyDataLoader(deepcopy(data), 12, 1, 1, return_incomplete=True, shuffle=True, infinite=True)
        # an infinite loader must never raise StopIteration
        for _ in range(1000):
            next(dl)

    def test_return_incomplete_single_threaded(self):
        data = list(range(123))
        batch_size = 12

        dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=False, shuffle=False, infinite=False)
        # the trailing partial batch (3 items) must be dropped
        batches = list(dl)
        for batch in batches:
            self.assertEqual(len(batch), batch_size)
        self.assertEqual(len(batches), 10)
        self.assertEqual(sum(len(batch) for batch in batches), 120)

        dl = DummyDataLoader(deepcopy(data), batch_size, 1, 1, return_incomplete=True, shuffle=False, infinite=False)
        # the trailing partial batch must be returned as an 11th batch
        batches = list(dl)
        self.assertEqual(len(batches), 11)
        self.assertEqual(sum(len(batch) for batch in batches), 123)

    def test_return_all_indices_multi_threaded_shuffle_False(self):
        data = list(range(123))
        num_workers = 3

        for batch_size in (1, 3, 75, 12, 23):
            dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=True, shuffle=False,
                                 infinite=False)
            with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
                for _ in range(3):
                    idx = _collect(mt)
                    self.assertEqual(len(idx), len(data))
                    self.assertEqual(idx, data)

    def test_return_all_indices_multi_threaded_shuffle_True(self):
        data = list(range(123))
        num_workers = 3

        for batch_size in (1, 3, 75, 12, 23):
            dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=True, shuffle=True,
                                 infinite=False)
            with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
                for _ in range(3):
                    idx = _collect(mt)
                    self.assertEqual(len(idx), len(data))
                    self.assertNotEqual(idx, data)
                    self.assertEqual(sorted(idx), data)

    def test_infinite_multi_threaded(self):
        data = list(range(123))
        num_workers = 3

        dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1, return_incomplete=True, shuffle=True, infinite=False)
        with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
            # a finite loader must eventually raise StopIteration
            with self.assertRaises(StopIteration):
                for _ in range(1000):
                    next(mt)

        dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1, return_incomplete=True, shuffle=True, infinite=True)
        with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
            # an infinite loader must never raise StopIteration
            for _ in range(1000):
                next(mt)

    def test_return_incomplete_multi_threaded(self):
        data = list(range(123))
        batch_size = 12
        num_workers = 3

        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=False, shuffle=False,
                             infinite=False)
        with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
            batches = list(mt)
        for batch in batches:
            self.assertEqual(len(batch), batch_size)
        all_return = [item for batch in batches for item in batch]
        self.assertEqual(len(batches), 10)
        self.assertEqual(len(all_return), 120)
        # no item may be delivered twice across workers
        self.assertEqual(len(np.unique(all_return)), len(all_return))

        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1, return_incomplete=True, shuffle=False,
                             infinite=False)
        with _running(MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)) as mt:
            batches = list(mt)
        all_return = [item for batch in batches for item in batch]
        self.assertEqual(len(batches), 11)
        self.assertEqual(len(all_return), 123)
        self.assertEqual(len(np.unique(all_return)), len(data))

    def test_thoroughly(self):
        data_list = [list(range(123)),
                     list(range(1243)),
                     list(range(1)),
                     list(range(7)),
                     ]
        worker_list = (1, 4, 7)
        batch_size_list = (1, 3, 32)
        seed_list = [318, None]
        epochs = 3

        for data, num_workers, batch_size, return_incomplete, shuffle, seed_for_shuffle in itertools.product(
                data_list, worker_list, batch_size_list, (True, False), (True, False), seed_list):
            expected_num_batches, expected_num_items = _expected_counts(len(data), batch_size, return_incomplete)

            with self.subTest(num_items=len(data), num_workers=num_workers, batch_size=batch_size,
                              return_incomplete=return_incomplete, shuffle=shuffle,
                              seed_for_shuffle=seed_for_shuffle):
                dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, seed_for_shuffle,
                                     return_incomplete=return_incomplete, shuffle=shuffle, infinite=False)

                # shut the workers down after every configuration instead of relying on garbage collection
                with _running(MultiThreadedAugmenter(dl, None, num_workers, 5, None, False, wait_time=0)) as mt:
                    mt._start()

                    for _ in range(epochs):
                        batches = list(mt)
                        all_return = [item for batch in batches for item in batch]

                        self.assertEqual(len(all_return), expected_num_items)
                        self.assertEqual(len(batches), expected_num_batches)
                        self.assertEqual(len(np.unique(all_return)), expected_num_items)


if __name__ == "__main__":
    from multiprocessing import freeze_support
    freeze_support()
    unittest.main()

# ---- /LICENSE ----
#                                  Apache License
#                            Version 2.0, January 2004
#                         http://www.apache.org/licenses/
#
#    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
#
#    1. Definitions.
#
#       "License" shall mean the terms and conditions for use, reproduction,
#       and distribution as defined by Sections 1 through 9 of this document.
#
#       "Licensor" shall mean the copyright owner or entity authorized by
#       the copyright owner that is granting the License.
#
#       "Legal Entity" shall mean the union of the acting entity and all
#       other entities that control, are controlled by, or are under common
#       control with that entity. For the purposes of this definition,
#       "control" means (i) the power, direct or indirect, to cause the
#       direction or management of such entity, whether by contract or
#       otherwise, or (ii) ownership of fifty percent (50%) or more of the
#       outstanding shares, or (iii) beneficial ownership of such entity.
#
#       "You" (or "Your") shall mean an individual or Legal Entity
#       exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2021] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 190 | AND Applied Computer Vision Lab, Helmholtz Imaging Platform] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /batchgenerators/dataloading/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # and Applied Computer Vision Lab, Helmholtz Imaging Platform 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from builtins import object
import warnings
from collections import OrderedDict
from warnings import warn
import numpy as np

from batchgenerators.dataloading.dataset import Dataset


class DataLoaderBase(object):
    """ Derive from this class and override generate_train_batch. If you don't want to use this you can use any
    generator.
    You can modify this class however you want. How the data is presented as batch is your responsibility. You can
    sample randomly, cycle through the training examples or sample the data according to a specific pattern. Just make
    sure to use our default data structure!
    {'data':your_batch_of_shape_(b, c, x, y(, z)),
     'seg':your_batch_of_shape_(b, c, x, y(, z)),
     'anything_else1':whatever,
     'anything_else2':whatever2,
     ...}

    (seg is optional)

    Deprecated: constructing an instance emits a DeprecationWarning. Use SlimDataLoaderBase instead.

    Args:
        data (anything): Your dataset. Stored as member variable self._data

        BATCH_SIZE (int): batch size. Stored as member variable self.BATCH_SIZE

        num_batches (int): How many batches will be generated before raising StopIteration. None=unlimited. Careful
        when using MultiThreadedAugmenter: Each process will produce num_batches batches.

        seed (False, None, int): seed to seed the numpy rng with. False = no seeding

    """
    def __init__(self, data, BATCH_SIZE, num_batches=None, seed=False):
        # NOTE: simplefilter modifies the process-wide warning filters, not only this module's.
        warnings.simplefilter("once", DeprecationWarning)
        warn("This DataLoader will soon be removed. Migrate everything to SlimDataLoaderBase now!", DeprecationWarning)
        # NOTE(review): binding __metaclass__ as a local inside __init__ has no effect (it only ever worked as a
        # class-body attribute in Python 2), so @abstractmethod below is NOT enforced for this class.
        __metaclass__ = ABCMeta
        self._data = data
        self.BATCH_SIZE = BATCH_SIZE
        if num_batches is not None:
            warn("We currently strongly discourage using num_batches != None! That does not seem to work properly")
        self._num_batches = num_batches
        self._seed = seed
        self._was_initialized = False
        if self._num_batches is None:
            # effectively "unlimited": __next__ compares the batch counter against this bound
            self._num_batches = 1e100
        self._batches_generated = 0
        self.thread_id = 0

    def reset(self):
        """Optionally reseed the global numpy RNG (unless seed is False) and restart the batch counter."""
        if self._seed is not False:
            np.random.seed(self._seed)
        self._was_initialized = True
        self._batches_generated = 0

    def set_thread_id(self, thread_id):
        """Store the id of the worker this loader runs in."""
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise StopIteration once num_batches batches have been produced."""
        if not self._was_initialized:
            self.reset()
        if self._batches_generated >= self._num_batches:
            # force a reset (and reseed) on the next iteration
            self._was_initialized = False
            raise StopIteration
        minibatch = self.generate_train_batch()
        self._batches_generated += 1
        return minibatch

    @abstractmethod
    def generate_train_batch(self):
        '''override this
        Generate your batch from self._data .Make sure you generate the correct batch size (self.BATCH_SIZE)
        '''
        pass


class SlimDataLoaderBase(object):
    def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None):
        """
        Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality.

        You must derive from this class to implement your own DataLoader. You must override self.generate_train_batch()

        If you use our MultiThreadedAugmenter you will need to also set and use number_of_threads_in_multithreaded. See
        multithreaded_dataloading in examples!

        :param data: will be stored in self._data. You can use it to generate your batches in self.generate_train_batch()
        :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch()
        :param number_of_threads_in_multithreaded: will be stored in self.number_of_threads_in_multithreaded.
        None per default. If you wish to iterate over all your training data only once per epoch, you must coordinate
        your Dataloaders and you will need this information
        """
        # NOTE(review): binding __metaclass__ as a local inside __init__ has no effect, so @abstractmethod below is
        # NOT enforced for this class.
        __metaclass__ = ABCMeta
        self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded
        self._data = data
        self.batch_size = batch_size
        self.thread_id = 0

    def set_thread_id(self, thread_id):
        """Store the id of the worker this loader runs in (used by subclasses to split work between workers)."""
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        """Return whatever generate_train_batch() produces; iteration stops when it raises StopIteration."""
        return self.generate_train_batch()

    @abstractmethod
    def generate_train_batch(self):
        '''override this
        Generate your batch from self._data .Make sure you generate the correct batch size (self.batch_size)
        '''
        pass


class DataLoader(SlimDataLoaderBase):
    def __init__(self, data, batch_size, num_threads_in_multithreaded=1, seed_for_shuffle=None, return_incomplete=False,
                 shuffle=True, infinite=False, sampling_probabilities=None):
        """

        :param data: will be stored in self._data for use in generate_train_batch
        :param batch_size: will be used by get_indices to return the correct number of indices
        :param num_threads_in_multithreaded: num_threads_in_multithreaded necessary for synchronization of dataloaders
        when using multithreaded augmenter
        :param seed_for_shuffle: for reproducibility. Seeds self.rs, which is only used to shuffle the indices between
        epochs (infinite=False). Infinite sampling draws from the global numpy RNG instead.
        :param return_incomplete: whether or not to return batches that are incomplete. Only applies if infinite=False.
        If your data has len of 34 and your batch size is 32 then return_incomplete=False will make this loader
        return only one batch of shape 32 (omitting 2 of your training examples). If return_incomplete=True a second
        batch with batch size 2 will be returned.
        :param shuffle: if True, the order of the indices will be shuffled between epochs. Only applies if infinite=False
        :param infinite: if True, each batch contains randomly (uniformly) sampled indices. An unlimited number of
        batches is returned. If False, DataLoader will iterate over the data only once
        :param sampling_probabilities: only applies if infinite=True. If sampling_probabilities is not None, the
        probabilities will be used by np.random.choice to sample the indexes for each batch. Important:
        sampling_probabilities must have as many entries as there are samples in your dataset AND
        sampling_probabilities needs to sum to 1
        """
        super(DataLoader, self).__init__(data, batch_size, num_threads_in_multithreaded)
        self.infinite = infinite
        self.shuffle = shuffle
        self.return_incomplete = return_incomplete
        self.seed_for_shuffle = seed_for_shuffle
        self.rs = np.random.RandomState(self.seed_for_shuffle)
        self.current_position = None
        self.was_initialized = False
        self.last_reached = False
        self.sampling_probabilities = sampling_probabilities

        # when you derive, make sure to set this! We can't set it here because we don't know what data will be like
        self.indices = None

    def reset(self):
        """Rewind this worker to its first batch and (optionally) reshuffle self.indices in place."""
        assert self.indices is not None

        # Workers interleave over the index list: worker k starts at the k-th batch-sized slot.
        self.current_position = self.thread_id * self.batch_size

        self.was_initialized = True

        # no need to shuffle if we are returning infinite random samples
        if not self.infinite and self.shuffle:
            # self.rs is seeded identically in every worker (same seed_for_shuffle), which presumably keeps the
            # shuffled order consistent across workers so their interleaved slots stay disjoint
            self.rs.shuffle(self.indices)

        self.last_reached = False

    def get_indices(self):
        """
        Return the indices for the next batch of this worker.

        infinite=True: batch_size indices drawn with replacement (weighted by sampling_probabilities if given).
        infinite=False: the next batch_size consecutive entries of self.indices owned by this worker; raises
        StopIteration (after resetting) once the epoch is exhausted.
        """
        # if self.infinite, this is easy
        if self.infinite:
            # NOTE: uses the global numpy RNG, not self.rs. self.rs is seeded identically in every worker, so using
            # it here would presumably make all workers draw the same batches.
            return np.random.choice(self.indices, self.batch_size, replace=True, p=self.sampling_probabilities)

        if self.last_reached:
            # the previous call already returned the final (incomplete) batch
            self.reset()
            raise StopIteration

        if not self.was_initialized:
            self.reset()

        indices = []

        for _ in range(self.batch_size):
            if self.current_position < len(self.indices):
                indices.append(self.indices[self.current_position])

                self.current_position += 1
            else:
                self.last_reached = True
                break

        if len(indices) > 0 and ((not self.last_reached) or self.return_incomplete):
            # skip the slots belonging to the other workers so each worker sees a disjoint subset
            self.current_position += (self.number_of_threads_in_multithreaded - 1) * self.batch_size
            return indices
        else:
            self.reset()
            raise StopIteration

    @abstractmethod
    def generate_train_batch(self):
        '''
        make use of self.get_indices() to know what indices to work on!
        :return:
        '''
        pass


def default_collate(batch):
    '''
    heavily inspired by the default_collate function of pytorch

    Combine a list of samples into a batch, dispatching on the type of the FIRST element:
      - np.ndarray            -> np.vstack(batch) (concatenates along axis 0, it does not add a new axis)
      - int / numpy integer   -> np.ndarray of dtype int32 (bool is an int subclass and lands here too)
      - float / numpy float   -> np.ndarray of dtype float32. np.float64 is a subclass of Python float, so it is
                                 down-cast to float32 as well.
      - dict / OrderedDict    -> dict mapping each key of batch[0] to the recursively collated values
      - tuple / list          -> list with one recursively collated entry per position
      - str                   -> batch returned unchanged
    :param batch: non-empty list of samples (batch[0] is inspected, so an empty list raises IndexError)
    :return: the collated batch
    :raises TypeError: if the type of batch[0] is not supported
    '''
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.vstack(batch)
    elif isinstance(first, (int, np.integer)):
        return np.array(batch).astype(np.int32)
    elif isinstance(first, (float, np.floating)):
        return np.array(batch).astype(np.float32)
    elif isinstance(first, (dict, OrderedDict)):
        return {key: default_collate([d[key] for d in batch]) for key in first}
    elif isinstance(first, (tuple, list)):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    elif isinstance(first, str):
        return batch
    else:
        # report the element type that could not be collated, not the type of the containing list
        raise TypeError('unknown type for batch: %s' % type(first))


class DataLoaderFromDataset(DataLoader):
    def __init__(self, data, batch_size, num_threads_in_multithreaded, seed_for_shuffle=1, collate_fn=default_collate,
                 return_incomplete=False, shuffle=True, infinite=False, sampling_probabilities=None):
        '''
        A simple dataloader that can take a Dataset as data.
        It is not super efficient because I cannot make too many hard assumptions about what data_dict will contain.
        If you know what you need, implement your own!
        :param data: a Dataset instance; samples are fetched with data[i]
        :param batch_size: number of samples per batch
        :param num_threads_in_multithreaded: number of workers, needed so the workers iterate over disjoint indices
        :param seed_for_shuffle: seed for shuffling the indices between epochs
        :param collate_fn: callable that turns a list of samples into a batch (default: default_collate)
        :param return_incomplete: see DataLoader
        :param shuffle: see DataLoader
        :param infinite: see DataLoader
        :param sampling_probabilities: see DataLoader (only used if infinite=True)
        '''
        super(DataLoaderFromDataset, self).__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
                                                    return_incomplete=return_incomplete, shuffle=shuffle,
                                                    infinite=infinite, sampling_probabilities=sampling_probabilities)
        self.collate_fn = collate_fn
        assert isinstance(self._data, Dataset)
        self.indices = np.arange(len(data))

    def generate_train_batch(self):
        """Fetch the samples for the next indices from the Dataset and collate them into one batch."""
        indices = self.get_indices()

        batch = [self._data[i] for i in indices]

        return self.collate_fn(batch)